.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plot_dynamics_1D.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_dynamics_1D.py:


==================================
Plotting the dynamics in 1D
==================================

This example compares the dynamics of a ResNet and a Momentum ResNet trained
to learn a one-dimensional mapping whose trajectories must cross.
The ResNet trajectories fail to cross, whereas the Momentum ResNet learns
the desired mapping.

Michael E. Sander, Pierre Ablin, Mathieu Blondel, Gabriel Peyre.
Momentum Residual Neural Networks. Proceedings of the 38th International
Conference on Machine Learning, PMLR 139:9276-9287, 2021.

.. GENERATED FROM PYTHON SOURCE LINES 18-32

.. code-block:: default


    # Authors: Michael Sander, Pierre Ablin
    # License: MIT
    import copy

    import torch
    import torch.nn as nn
    import numpy as np
    import matplotlib.pyplot as plt
    import torch.optim as optim

    from momentumnet import MomentumNet
    from momentumnet.toy_datasets import make_data_1D


.. GENERATED FROM PYTHON SOURCE LINES 33-35

Fix the random seed for reproducible figures
############################################

.. GENERATED FROM PYTHON SOURCE LINES 35-38

.. code-block:: default


    torch.manual_seed(1)


.. GENERATED FROM PYTHON SOURCE LINES 39-41

Parameters of the simulation
#############################

.. GENERATED FROM PYTHON SOURCE LINES 41-48

.. code-block:: default


    hidden = 16  # width of the hidden layer of the residual function
    n_iters = 15  # number of residual layers (network depth)
    gamma = 0.99  # momentum term of the Momentum ResNet
    d = 1  # dimension of the inputs


.. GENERATED FROM PYTHON SOURCE LINES 49-51

Defining the functions for the forward pass
############################################

.. GENERATED FROM PYTHON SOURCE LINES 51-55
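Each layer applies the residual function :math:`f` defined in the next code
block. In the Momentum ResNet of the paper cited above, layer :math:`n` updates
a position :math:`x_n` and a velocity :math:`v_n` as in the first equation
below. This summary assumes that ``MomentumNet`` follows the same rule, with
``init_speed=0`` setting :math:`v_0 = 0`. Setting :math:`\gamma = 0` gives the
second equation, the plain ResNet update. This is consistent with the ResNet
in this example being a ``MomentumNet`` with ``gamma=0.0``.

.. math::

    v_{n+1} = \gamma v_n + (1 - \gamma) f(x_n), \qquad x_{n+1} = x_n + v_{n+1}, \qquad v_0 = 0,

    \gamma = 0 \quad \Longrightarrow \quad x_{n+1} = x_n + f(x_n).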
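.. code-block:: default


    # Residual function f: R^d -> R^d, a one-hidden-layer MLP.
    function = nn.Sequential(nn.Linear(d, hidden), nn.Tanh(), nn.Linear(hidden, d))
    # Independent copy for the ResNet: same initial weights, trained separately.
    function_res = copy.deepcopy(function)


.. GENERATED FROM PYTHON SOURCE LINES 56-58

Defining our models
####################

.. GENERATED FROM PYTHON SOURCE LINES 58-76

.. code-block:: default


    # [function] * n_iters repeats a reference to the same module,
    # so all layers share one set of weights.
    mom_net = MomentumNet(
        [
            function,
        ]
        * n_iters,
        gamma=gamma,
        init_speed=0,
    )
    # With gamma=0 the momentum update reduces to a standard ResNet.
    res_net = MomentumNet(
        [
            function_res,
        ]
        * n_iters,
        gamma=0.0,
        init_speed=0,
    )


.. GENERATED FROM PYTHON SOURCE LINES 77-79

Training our models to learn an order-reversing mapping
########################################################

The target :math:`h(x) = -x^3` is decreasing, so it reverses the order of
the inputs. The trajectories carrying each input :math:`x` to :math:`h(x)`
must therefore cross.

.. GENERATED FROM PYTHON SOURCE LINES 79-111

.. code-block:: default


    def h(x):
        # Target mapping.
        return -(x ** 3)


    def Loss(pred, x):
        # Mean squared error between the network output and the target h(x).
        return ((pred - h(x)) ** 2).mean()


    # Momentum ResNet: 301 SGD steps, each on a new batch of 200 points.
    optimizer = optim.SGD(mom_net.parameters(), lr=0.01)
    for i in range(301):
        optimizer.zero_grad()
        x = make_data_1D(200)
        pred = mom_net(x)
        loss = Loss(pred, x)
        loss.backward()
        optimizer.step()

    # ResNet: same setup, but with 2001 SGD steps.
    optimizer = optim.SGD(res_net.parameters(), lr=0.01)
    for i in range(2001):
        optimizer.zero_grad()
        x = make_data_1D(200)
        pred = res_net(x)
        loss = Loss(pred, x)
        loss.backward()
        optimizer.step()


.. GENERATED FROM PYTHON SOURCE LINES 112-114

Plotting the output
####################

.. GENERATED FROM PYTHON SOURCE LINES 114-191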
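The plotting code below runs one forward pass per depth :math:`i`. Each time, it rebuilds a
network from :math:`i` copies of the shared layer. A single pass that records
every intermediate state traces the same trajectories.

The following NumPy sketch does this under stated assumptions:

* The update :math:`v \leftarrow \gamma v + (1 - \gamma) f(x)`,
  :math:`x \leftarrow x + v` with :math:`v_0 = 0` is assumed to match
  ``MomentumNet`` with ``init_speed=0``, where :math:`\gamma = 0` gives a ResNet.
* ``momentum_trajectory`` and ``order_reversed`` are hypothetical helpers, not
  part of ``momentumnet``.
* The linear residual functions are toy stand-ins for the trained MLP.

``order_reversed`` tests whether the outputs come out in the reverse order of
the inputs, as the target :math:`h(x) = -x^3` requires.

.. code-block:: python

    import numpy as np


    def momentum_trajectory(x0, f, n_layers, gamma):
        """Hypothetical helper: positions x_0, ..., x_n of every input, row by row.

        Assumed update (gamma = 0 gives a ResNet):
            v <- gamma * v + (1 - gamma) * f(x),   x <- x + v,   with v_0 = 0.
        """
        x = np.asarray(x0, dtype=float)
        v = np.zeros_like(x)  # zero initial velocity, like init_speed=0
        traj = [x.copy()]
        for _ in range(n_layers):
            v = gamma * v + (1.0 - gamma) * f(x)
            x = x + v
            traj.append(x.copy())
        return np.stack(traj)  # shape (n_layers + 1, len(x0))


    def order_reversed(traj):
        """Hypothetical helper: True if the outputs come out in reverse input order."""
        return np.array_equal(np.argsort(traj[0]), np.argsort(traj[-1])[::-1])


    x0 = np.linspace(-1.0, 1.0, 5)

    # Toy ResNet (gamma = 0) with f(x) = -0.5 x: every layer halves x, order is kept.
    res_traj = momentum_trajectory(x0, lambda x: -0.5 * x, n_layers=3, gamma=0.0)
    print(order_reversed(res_traj))  # False

    # Toy Momentum ResNet (gamma = 0.9) with f(x) = -5 x: the velocity built up in
    # the first layer carries every nonzero input across 0 in the second layer.
    mom_traj = momentum_trajectory(x0, lambda x: -5.0 * x, n_layers=2, gamma=0.9)
    print(order_reversed(mom_traj))  # True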
.. code-block:: default


    n_plot = 8  # number of trajectories to draw
    # Evenly spaced inputs in [-1, 1], shaped (n_plot, d).
    x = np.linspace(-1, 1, n_plot)
    x_ = torch.tensor(x).view(-1, d).float()
    x_axis = np.arange(0, n_iters + 1)  # depths 0, 1, ..., n_iters

    # Momentum ResNet trajectories.
    plt.figure(figsize=(3, 4))
    plt.gca().set_prop_cycle(
        plt.cycler("color", plt.cm.jet(np.linspace(0.8, 0.95, n_plot)))
    )
    preds = np.zeros((n_iters + 1, n_plot))
    preds[0] = x_[:, 0].numpy()
    for i in range(1, n_iters + 1):
        # All layers share `function`, so i copies of it reproduce
        # the first i layers of the trained network.
        mom_net = MomentumNet(
            [
                function,
            ]
            * i,
            gamma=gamma,
            init_speed=0,
        )
        with torch.no_grad():
            pred_ = mom_net(x_)
        preds[i] = pred_[:, 0].numpy()
    # Each column of preds is one trajectory: position against depth.
    plt.plot(preds, x_axis, "-x", lw=2.5)
    plt.xticks([], [])
    plt.yticks([], [])
    plt.title("Momentum ResNet")
    plt.ylabel("Depth")
    plt.xlabel("Input")
    plt.show()

    # ResNet trajectories, computed the same way.
    plt.figure(figsize=(3, 4))
    plt.gca().set_prop_cycle(
        plt.cycler("color", plt.cm.jet(np.linspace(0.0, 0.1, n_plot)))
    )
    preds_res = np.zeros((n_iters + 1, n_plot))
    preds_res[0] = x_[:, 0].numpy()
    for i in range(1, n_iters + 1):
        res_net = MomentumNet(
            [
                function_res,
            ]
            * i,
            gamma=0.0,
            init_speed=0,
        )
        with torch.no_grad():
            pred_ = res_net(x_)
        preds_res[i] = pred_[:, 0].numpy()
    plt.plot(preds_res, x_axis, "-x", lw=2.5)
    plt.xticks([], [])
    plt.yticks([], [])
    plt.title("ResNet")
    plt.ylabel("Depth")
    plt.xlabel("Input")
    plt.show()



.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /auto_examples/images/sphx_glr_plot_dynamics_1D_001.png
          :alt: Momentum ResNet
          :srcset: /auto_examples/images/sphx_glr_plot_dynamics_1D_001.png
          :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/images/sphx_glr_plot_dynamics_1D_002.png
          :alt: ResNet
          :srcset: /auto_examples/images/sphx_glr_plot_dynamics_1D_002.png
          :class: sphx-glr-multi-img





.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 27.402 seconds)


.. _sphx_glr_download_auto_examples_plot_dynamics_1D.py:


.. only :: html

    .. container:: sphx-glr-footer
        :class: sphx-glr-footer-example


        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: plot_dynamics_1D.py `


        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: plot_dynamics_1D.ipynb `


.. only:: html

    .. rst-class:: sphx-glr-signature

        `Gallery generated by Sphinx-Gallery `_