.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plot_memory.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_memory.py:


==================================
Plotting memory consumption
==================================

This example compares the memory used by a ResNet and by a Momentum ResNet
as a function of their depth.

.. GENERATED FROM PYTHON SOURCE LINES 9-22

.. code-block:: default


    # Authors: Michael Sander, Pierre Ablin
    # License: MIT

    import torch
    import torch.nn as nn
    from momentumnet import MomentumNet
    import matplotlib.pyplot as plt
    from memory_profiler import memory_usage
    import numpy as np

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


.. GENERATED FROM PYTHON SOURCE LINES 23-25

Fix random seed for reproducible figures
##########################################

.. GENERATED FROM PYTHON SOURCE LINES 25-28

.. code-block:: default


    torch.manual_seed(1)


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none


.. GENERATED FROM PYTHON SOURCE LINES 29-31

Parameters of the simulation
#############################

.. GENERATED FROM PYTHON SOURCE LINES 31-92

.. code-block:: default


    # Depths at which memory is measured: np.arange(1, 200, 100) gives 1 and 101
    Depths = np.arange(1, 200, 100)
    hidden = 1  # hidden width of each residual function
    d = 2  # dimension of the data

    # Residual function of each layer: Linear -> Tanh -> Linear
    function = nn.Sequential(nn.Linear(d, hidden), nn.Tanh(), nn.Linear(hidden, d))
    function_res = nn.Sequential(
        nn.Linear(d, hidden), nn.Tanh(), nn.Linear(hidden, d)
    )
    # Large input so that the memory taken by activations is noticeable
    X = torch.rand(2, 200000, d)


    def train(net):
        # One forward and one backward pass; memory is profiled while this runs
        Loss = (net(X) ** 2).mean()
        Loss.backward()


    Mem_list_mom = []
    for n_iters in Depths:
        # Momentum ResNet. Repeating the list reuses the same module, so all
        # layers share weights; use_backprop=False selects the memory-efficient
        # backward pass instead of standard backpropagation.
        mom_net = MomentumNet(
            [
                function,
            ]
            * n_iters,
            gamma=1 - 1 / (50 * n_iters),
            init_speed=0,
            use_backprop=False,
        )
        # Peak memory (MiB) sampled by memory_profiler during train(mom_net)
        used_mem = np.max(memory_usage((train, (mom_net,))))
        Mem_list_mom.append(used_mem)

    Mem_list_res = []
    for n_iters in Depths:
        # gamma=0 turns the MomentumNet into a standard ResNet
        # (x_{n+1} = x_n + f(x_n)), trained with standard backpropagation
        res_net = MomentumNet(
            [
                function_res,
            ]
            * n_iters,
            gamma=0.0,
            init_speed=0,
            use_backprop=True,
        )
        used_mem = np.max(memory_usage((train, (res_net,))))
        Mem_list_res.append(used_mem)

    plt.figure(figsize=(8, 4))
    plt.plot(Depths, Mem_list_res, label="ResNet", linewidth=4, color="darkblue")
    plt.plot(Depths, Mem_list_mom, label="MomentumNet", linewidth=4, color="red")
    y_ = plt.ylabel("Memory (MiB)")
    x_ = plt.xlabel("Depth")
    plt.legend()
    plt.show()


.. image-sg:: /auto_examples/images/sphx_glr_plot_memory_001.png
   :alt: plot memory
   :srcset: /auto_examples/images/sphx_glr_plot_memory_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 3 minutes 14.555 seconds)


.. _sphx_glr_download_auto_examples_plot_memory.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example


  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_memory.py `


  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_memory.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_
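
The memory-efficient backward pass (``use_backprop=False``) relies, as we understand it, on the Momentum ResNet forward pass being invertible. The following is a minimal NumPy sketch of that idea, not the ``momentumnet`` implementation. It assumes the Momentum ResNet update :math:`v_{n+1} = \gamma v_n + (1 - \gamma) f(x_n)`, :math:`x_{n+1} = x_n + v_{n+1}`, and uses hypothetical weights ``W1``, ``b1``, ``W2``, ``b2`` and hypothetical helpers ``f``, ``forward`` and ``inverse``. It runs the layers forward while keeping only the final state, then recovers the input (up to round-off) by undoing the layers in reverse order, so activations could be recomputed rather than stored.

.. code-block:: python

    import numpy as np

    # Minimal sketch (assumption): Momentum ResNet update
    #   v_{n+1} = gamma * v_n + (1 - gamma) * f(x_n)
    #   x_{n+1} = x_n + v_{n+1}
    # and its closed-form inverse. Not the momentumnet implementation.

    rng = np.random.default_rng(0)
    d, hidden = 2, 1
    # Hypothetical weights for a residual function shaped like `function`
    W1 = rng.standard_normal((hidden, d))
    b1 = rng.standard_normal(hidden)
    W2 = rng.standard_normal((d, hidden))
    b2 = rng.standard_normal(d)


    def f(x):
        # Linear -> Tanh -> Linear, applied to each row of x
        return np.tanh(x @ W1.T + b1) @ W2.T + b2


    def forward(x, v, gamma, depth):
        # Only the final (x, v) is returned; intermediate states are discarded
        for _ in range(depth):
            v = gamma * v + (1 - gamma) * f(x)
            x = x + v
        return x, v


    def inverse(x, v, gamma, depth):
        # Undo the layers from last to first (requires gamma > 0)
        for _ in range(depth):
            x = x - v
            v = (v - (1 - gamma) * f(x)) / gamma
        return x, v


    depth = 101
    gamma = 1 - 1 / (50 * depth)  # same schedule as in the example
    x0 = rng.standard_normal((5, d))
    v0 = np.zeros_like(x0)  # corresponds to init_speed=0
    xN, vN = forward(x0, v0, gamma, depth)
    x_rec, v_rec = inverse(xN, vN, gamma, depth)
    # Reconstruction error: small, limited by floating-point round-off
    print(np.max(np.abs(x_rec - x0)))

With ``gamma=0.0``, the setting used for the ResNet baseline above, ``inverse`` would divide by zero: this closed-form inversion does not apply to a plain ResNet, consistent with that baseline being run with ``use_backprop=True``.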