
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plot_separation_nested_rings.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_plot_separation_nested_rings.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_separation_nested_rings.py:


===================================================
Separation of nested rings using a Momentum ResNet.
===================================================

This example trains a Momentum ResNet to separate two nested rings, then
plots how the points move as they pass through the successive layers.


Michael E. Sander, Pierre Ablin, Mathieu Blondel,
Gabriel Peyré. Momentum Residual Neural Networks.
Proceedings of the 38th International Conference
on Machine Learning, PMLR 139:9276-9287, 2021.

.. GENERATED FROM PYTHON SOURCE LINES 16-30

.. code-block:: default


    # Authors: Michael Sander, Pierre Ablin
    # License: MIT
    import matplotlib.pyplot as plt
    import torch
    from torch import nn
    import numpy as np
    import torch.optim as optim
    from momentumnet import MomentumNet
    from momentumnet.toy_datasets import make_data

    # Fix the seed so that the run is reproducible
    torch.manual_seed(1)


.. GENERATED FROM PYTHON SOURCE LINES 31-33

Parameters of the simulation
#############################

.. GENERATED FROM PYTHON SOURCE LINES 33-57

.. code-block:: default


    hidden = 16  # width of the hidden layer of the residual function
    n_iters = 10  # number of layers
    N = 1000

    # Residual function: a small MLP mapping 2D points to 2D updates
    function = nn.Sequential(nn.Linear(2, hidden), nn.Tanh(), nn.Linear(hidden, 2))


    # Network: the same module is repeated, so all layers share their weights
    mresnet = MomentumNet(
        [
            function,
        ]
        * n_iters,
        gamma=0.99,
    )


    criterion = nn.CrossEntropyLoss()

    n_epochs = 30
    # Constant learning rate of 0.5 for every epoch
    lr_list = np.ones(n_epochs) * 0.5

    optimizer = optim.Adam(mresnet.parameters(), lr=lr_list[0])


.. GENERATED FROM PYTHON SOURCE LINES 58-60

Training
##################################

.. GENERATED FROM PYTHON SOURCE LINES 60-74

.. code-block:: default



    for i in range(n_epochs):
        # Set the learning rate for this epoch
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr_list[i]
        optimizer.zero_grad()
        # Draw a fresh batch of 2000 points at every epoch
        x, y = make_data(
            2000,
        )
        pred = mresnet(x)
        loss = criterion(pred, y)
        loss.backward()
        optimizer.step()


.. GENERATED FROM PYTHON SOURCE LINES 75-77

Plot the results
#############################################

.. GENERATED FROM PYTHON SOURCE LINES 77-96

.. code-block:: default


    x_, y_ = make_data(500)
    fig, axis = plt.subplots(1, n_iters + 1, figsize=(n_iters + 1, 1))

    # Panel i shows the points after the first i layers of the trained network
    for i in range(n_iters + 1):
        mom_net = MomentumNet(
            [
                function,
            ]
            * i,
            gamma=0.99,
            init_speed=0,
        )
        with torch.no_grad():
            pred_ = mom_net(x_)
        axis[i].scatter(pred_[:, 0], pred_[:, 1], c=y_ + 3, s=1)
        axis[i].axis("off")
    plt.show()



.. image-sg:: /auto_examples/images/sphx_glr_plot_separation_nested_rings_001.png
   :alt: plot separation nested rings
   :srcset: /auto_examples/images/sphx_glr_plot_separation_nested_rings_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  40.875 seconds)


.. _sphx_glr_download_auto_examples_plot_separation_nested_rings.py:

.. only:: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example


  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_separation_nested_rings.py <plot_separation_nested_rings.py>`


  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_separation_nested_rings.ipynb <plot_separation_nested_rings.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
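
The network in the example above stacks momentum residual layers. As a rough, framework-free sketch of the update rule described in the cited paper (a velocity ``v`` is updated as ``v = gamma * v + (1 - gamma) * f(x)`` and the state as ``x = x + v``), the snippet below applies it to a single 2D point and then inverts it exactly. The names ``residual``, ``momentum_forward`` and ``momentum_inverse`` are hypothetical and are not part of the ``momentumnet`` API, and the fixed ``residual`` function only stands in for the learned MLP; whether ``MomentumNet`` implements the rule internally in exactly this way is an assumption.

.. code-block:: python

    import math


    def residual(x, w=1.5):
        """Hypothetical stand-in for the learned residual function f."""
        return [math.tanh(w * x[1]), math.tanh(-w * x[0])]


    def momentum_forward(x, layers, gamma=0.99, v=None):
        """Apply v <- gamma * v + (1 - gamma) * f(x), then x <- x + v, per layer."""
        # The velocity starts at zero, mirroring init_speed=0 in the example
        v = [0.0] * len(x) if v is None else list(v)
        x = list(x)
        for f in layers:
            fx = f(x)
            v = [gamma * vi + (1 - gamma) * fi for vi, fi in zip(v, fx)]
            x = [xi + vi for xi, vi in zip(x, v)]
        return x, v


    def momentum_inverse(x, v, layers, gamma=0.99):
        """Undo momentum_forward by walking through the layers in reverse."""
        x, v = list(x), list(v)
        for f in reversed(layers):
            # Recover the previous state first, then the previous velocity
            x = [xi - vi for xi, vi in zip(x, v)]
            fx = f(x)
            v = [(vi - (1 - gamma) * fi) / gamma for vi, fi in zip(v, fx)]
        return x, v


    layers = [residual] * 10  # the same function repeated, as in the example
    x0 = [0.3, -0.7]
    x_out, v_out = momentum_forward(x0, layers)
    x_back, v_back = momentum_inverse(x_out, v_out, layers)

Running the inverse on the forward output recovers the starting point and a zero velocity up to floating-point error. With an empty list of layers the forward pass returns its input unchanged, which corresponds to the first panel of the plot (``i = 0``).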