.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/drop_in_replacement.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_drop_in_replacement.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_drop_in_replacement.py:


======================================================
From ResNets to Momentum ResNets 2)
======================================================

This example shows, on two simple models, how to replace an existing
ResNet with a MomentumNet.

Michael E. Sander, Pierre Ablin, Mathieu Blondel, Gabriel Peyre.
Momentum Residual Neural Networks. Proceedings of the 38th International
Conference on Machine Learning, PMLR 139:9276-9287, 2021.

.. GENERATED FROM PYTHON SOURCE LINES 16-22

.. code-block:: default


    # Authors: Michael Sander, Pierre Ablin
    # License: MIT

    import torch
    from momentumnet import transform_to_momentumnet

.. GENERATED FROM PYTHON SOURCE LINES 23-25

A torchvision model
####################

A torchvision ResNet-18 is converted with a single call to
``transform_to_momentumnet``.

.. GENERATED FROM PYTHON SOURCE LINES 25-31

.. code-block:: default


    from torchvision.models import resnet18

    resnet = resnet18()
    # gamma is the momentum parameter; use_backprop=False selects the
    # memory-saving backward pass based on the invertibility of the network
    mresnet18 = transform_to_momentumnet(resnet, gamma=0.9, use_backprop=False)

.. GENERATED FROM PYTHON SOURCE LINES 32-33

A Momentum ResNet naturally extends the original ResNet: with
``gamma=0.0``, the transformed network computes the same outputs as the
original one.

.. GENERATED FROM PYTHON SOURCE LINES 33-42

.. code-block:: default


    x = torch.rand((64, 3, 7, 7))
    resnet = resnet18()
    mresnet = transform_to_momentumnet(resnet, gamma=0.0)
    # With gamma = 0, both models should give exactly the same outputs,
    # so the summed squared difference printed below should be zero
    print(((resnet(x) - mresnet(x)) ** 2).sum())

.. GENERATED FROM PYTHON SOURCE LINES 43-45

A Natural Language Transformer model
#####################################

For a ``torch.nn.Transformer``, the residual layers to transform are
passed by name through ``sub_layers``.

.. GENERATED FROM PYTHON SOURCE LINES 45-54

.. code-block:: default


    transformer = torch.nn.Transformer(num_encoder_layers=6, num_decoder_layers=6)
    mtransformer = transform_to_momentumnet(
        transformer,
        # Specify the sublayers to transform
        sub_layers=["encoder.layers", "decoder.layers"],
        gamma=0.9,
        use_backprop=False,
        # Also transform the first layer of each stack
        keep_first_layer=False,
    )


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.000 seconds)


.. _sphx_glr_download_auto_examples_drop_in_replacement.py:


.. only:: html

    .. container:: sphx-glr-footer
        :class: sphx-glr-footer-example

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: drop_in_replacement.py <drop_in_replacement.py>`

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: drop_in_replacement.ipynb <drop_in_replacement.ipynb>`


.. only:: html

    .. rst-class:: sphx-glr-signature

        `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
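
The ``gamma=0.0`` check in the torchvision example follows from the momentum update rule. As we read the Momentum ResNets paper cited at the top of this example, each residual step :math:`x_{n+1} = x_n + f(x_n, \theta_n)` is replaced by

.. math::

    v_{n+1} = \gamma v_n + (1 - \gamma) f(x_n, \theta_n), \qquad x_{n+1} = x_n + v_{n+1}.

With :math:`\gamma = 0` the velocity is simply :math:`v_{n+1} = f(x_n, \theta_n)`, so the original ResNet is recovered. For :math:`\gamma > 0` a step can be undone from its outputs alone, :math:`x_n = x_{n+1} - v_{n+1}` and :math:`v_n = (v_{n+1} - (1 - \gamma) f(x_n, \theta_n)) / \gamma`. This invertibility is what lets Momentum ResNets avoid storing activations, and we assume it is what ``use_backprop=False`` relies on. The following is a minimal pure-Python sketch of these two facts, not the ``momentumnet`` implementation: the scalar residual functions, the zero initial velocity and the helper names ``resnet_forward``, ``momentum_forward`` and ``momentum_inverse`` are hypothetical.

```python
import math

# Hypothetical helpers for illustration only; they are NOT part of the
# momentumnet API and use scalars instead of tensors.


def resnet_forward(x, residuals):
    """Plain residual updates: x_{n+1} = x_n + f_n(x_n)."""
    for f in residuals:
        x = x + f(x)
    return x


def momentum_forward(x, residuals, gamma, v=0.0):
    """Momentum residual updates; returns the final (position, velocity).

    The zero initial velocity is an assumption of this sketch.
    """
    for f in residuals:
        v = gamma * v + (1.0 - gamma) * f(x)
        x = x + v
    return x, v


def momentum_inverse(x, v, residuals, gamma):
    """Recover the inputs (x_0, v_0) from the outputs; needs gamma > 0."""
    if gamma <= 0.0:
        raise ValueError("the momentum update is only invertible for gamma > 0")
    for f in reversed(residuals):
        x = x - v                               # undo x_{n+1} = x_n + v_{n+1}
        v = (v - (1.0 - gamma) * f(x)) / gamma  # undo the velocity update
    return x, v


# Toy scalar residual functions standing in for the network's residual blocks.
residuals = [math.sin, math.tanh, lambda z: 0.5 * z]
x0 = 0.3

# gamma = 0: the momentum network reproduces the plain ResNet exactly.
x_res = resnet_forward(x0, residuals)
x_mom, _ = momentum_forward(x0, residuals, gamma=0.0)
print(x_res == x_mom)  # True

# gamma = 0.9: the input is reconstructed from the outputs alone.
x_out, v_out = momentum_forward(x0, residuals, gamma=0.9)
x_rec, v_rec = momentum_inverse(x_out, v_out, residuals, gamma=0.9)
print(abs(x_rec - x0) < 1e-9, abs(v_rec) < 1e-9)
```

Dividing by :math:`\gamma` in the inverse amplifies rounding errors when :math:`\gamma` is small, which is why the sketch only checks the reconstruction up to a tolerance.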