.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/drop_in_replacement_advanced.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_drop_in_replacement_advanced.py:


======================================================
From ResNets to Momentum ResNets 3)
======================================================

This example illustrates, on a more complex model, how to replace an existing
ResNet with a MomentumNet.

Michael E. Sander, Pierre Ablin, Mathieu Blondel, Gabriel Peyre. Momentum
Residual Neural Networks. Proceedings of the 38th International Conference on
Machine Learning, PMLR 139:9276-9287.

.. GENERATED FROM PYTHON SOURCE LINES 16-22

.. code-block:: default

    # Authors: Michael Sander, Pierre Ablin
    # License: MIT

    import torch
    from momentumnet import transform_to_momentumnet

.. GENERATED FROM PYTHON SOURCE LINES 23-25

We will use a Vision Transformer model
#######################################

.. GENERATED FROM PYTHON SOURCE LINES 27-29

The model was introduced in https://arxiv.org/abs/2010.11929, and the code is
adapted from https://github.com/lucidrains/vit-pytorch.

.. GENERATED FROM PYTHON SOURCE LINES 29-44

.. code-block:: default

    from vit_pytorch import ViT

    # A ViT for 256x256 images cut into 32x32 patches, with 6 transformer blocks
    v = ViT(
        image_size=256,
        patch_size=32,
        num_classes=1000,
        dim=1024,
        depth=6,
        heads=16,
        mlp_dim=2048,
        dropout=0.1,
        emb_dropout=0.1,
    )

.. GENERATED FROM PYTHON SOURCE LINES 45-47

We first replace ``v.transformer`` with its ``layers`` attribute, the container
of transformer blocks, to be consistent with our forward rule.

.. GENERATED FROM PYTHON SOURCE LINES 47-50

.. code-block:: default

    v.transformer = v.transformer.layers

.. GENERATED FROM PYTHON SOURCE LINES 51-53

We then flatten these blocks into a single ``torch.nn.Sequential``, so that
each layer of each block becomes a separate, consecutive entry.

.. GENERATED FROM PYTHON SOURCE LINES 53-61
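
The ``Sequential`` form matters because a Momentum ResNet applies its residual
functions one after the other while carrying a velocity :math:`v`. In the paper
cited above, each function :math:`f` performs the update
:math:`v \leftarrow \gamma v + (1 - \gamma) f(x)` followed by
:math:`x \leftarrow x + v`, and this update can be inverted exactly, which is
what allows a memory-efficient backward pass. The snippet below is a minimal
sketch of that rule on a hypothetical toy model, not the ``momentumnet``
implementation: the helpers ``momentum_forward`` and ``momentum_backward``, the
linear layers standing in for the transformer blocks, and the zero initial
velocity are all illustrative assumptions. Running the layers forward and then
inverting them recovers the input and a zero velocity up to floating-point
error.

.. code-block:: python

    import torch
    from torch import nn


    def momentum_forward(layers, x, gamma=0.9):
        """Hypothetical helper: momentum forward rule, starting from zero velocity."""
        v = torch.zeros_like(x)
        for f in layers:
            v = gamma * v + (1 - gamma) * f(x)  # velocity update
            x = x + v  # position update
        return x, v


    def momentum_backward(layers, x, v, gamma=0.9):
        """Hypothetical helper: invert momentum_forward, last layer first."""
        for f in reversed(list(layers)):
            x = x - v  # undo the position update
            v = (v - (1 - gamma) * f(x)) / gamma  # undo the velocity update
        return x, v


    # Hypothetical toy model: 3 blocks of 2 linear layers, flattened into a
    # Sequential in the same way as the transformer blocks
    torch.manual_seed(0)
    blocks = nn.ModuleList(
        [nn.ModuleList([nn.Linear(8, 8), nn.Linear(8, 8)]) for _ in range(3)]
    )
    layers = nn.Sequential(*[layer for block in blocks for layer in block]).double()

    x0 = torch.randn(4, 8, dtype=torch.float64)
    with torch.no_grad():
        x_out, v_out = momentum_forward(layers, x0)
        x_rec, v_rec = momentum_backward(layers, x_out, v_out)

    print(len(layers))  # 6
    print(torch.allclose(x_rec, x0))  # True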

.. code-block:: default

    # Each block holds several layers (attention, then feed-forward):
    # collect them in order and wrap them in a single Sequential
    v_modules = []
    for block in v.transformer:
        for layer in block:
            v_modules.append(layer)
    v.transformer = torch.nn.Sequential(*v_modules)

.. GENERATED FROM PYTHON SOURCE LINES 62-63

We can now transform the model into its momentum version. The sub-module
listed in ``["transformer"]`` is converted into a Momentum ResNet with momentum
term ``gamma=0.9``.

.. GENERATED FROM PYTHON SOURCE LINES 63-72

.. code-block:: default

    mv = transform_to_momentumnet(
        v,
        ["transformer"],  # name of the sub-module to convert
        gamma=0.9,  # momentum term
        keep_first_layer=False,
        use_backprop=False,
        is_residual=True,
    )


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.000 seconds)


.. _sphx_glr_download_auto_examples_drop_in_replacement_advanced.py:


.. only :: html

    .. container:: sphx-glr-footer
        :class: sphx-glr-footer-example

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: drop_in_replacement_advanced.py `

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: drop_in_replacement_advanced.ipynb `


.. only:: html

    .. rst-class:: sphx-glr-signature

        `Gallery generated by Sphinx-Gallery `_