[PyTorch notes] PyTorch variables parameter and buffer: self.register_buffer() and self.register_parameter()
2022-06-13 02:53:00 【magic_ll】
In PyTorch, the parameters a model needs to save come in two kinds:
- parameter: updated by the optimizer during backpropagation; trainable.
- buffer: not updated by the optimizer during backpropagation; not trainable.
Each kind is stored in its own OrderedDict, and both are ultimately returned by model.state_dict() for saving.
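A minimal sketch of this split (the `Net` class and its attribute names are illustrative, not from the original): both kinds of state end up in the dictionary returned by `model.state_dict()`.

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # parameter: trainable, updated by the optimizer
        self.weight = nn.Parameter(torch.randn(2, 2))
        # buffer: saved with the model, but never updated by the optimizer
        self.register_buffer('scale', torch.ones(2))

net = Net()
# both the parameter and the buffer appear in the saved state
print(sorted(net.state_dict().keys()))  # ['scale', 'weight']
```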
1 nn.Module
A note on nn.Module first: a variable defined directly with torch.randn(1, 2) is not bound to the PyTorch network, so after training it is not saved with the model. When you want such a variable saved (e.g. the anchors in YOLOv5) for simple post-processing, it has to be registered with the network, using one of these APIs:
- self.register_buffer(): not trainable.
- self.register_parameter(), nn.parameter.Parameter(), nn.Parameter(): trainable.
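The difference can be sketched as follows (the `Head` class and the `anchors` name are hypothetical, echoing the YOLOv5 example above): a plain tensor attribute is lost from the state dict, while a registered buffer is kept but stays non-trainable.

```python
import torch
import torch.nn as nn

class Head(nn.Module):
    def __init__(self):
        super().__init__()
        self.plain = torch.randn(1, 2)                      # not registered: lost on save
        self.register_buffer('anchors', torch.randn(3, 2))  # registered, saved, not trainable
        self.w = nn.Parameter(torch.randn(1, 2))            # registered, saved, trainable

head = Head()
sd = head.state_dict()
print('plain' in sd, 'anchors' in sd, 'w' in sd)         # False True True
print(head.anchors.requires_grad, head.w.requires_grad)  # False True
```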
Every PyTorch network is defined by subclassing nn.Module. Looking at that class's __init__ in the source, it sets up several member variables; the three we care about here are all of type OrderedDict().
Member variables:
- _buffers: filled by self.register_buffer(); requires_grad defaults to False, so the entries are not trainable.
- _parameters: variables defined via self.register_parameter(), nn.parameter.Parameter(), or nn.Parameter() are stored here; their requires_grad defaults to True.
- _modules: submodules defined with nn.Sequential(), nn.Conv2d(), etc. are stored here.
Member functions:
- self.state_dict(): an OrderedDict. Holds the network's inference parameters, i.e. both parameters and buffers.
- self.named_parameters(): an iterator over the name + tensor of every trainable parameter in self._modules and self._parameters, including BN's bn.weight and bn.bias.
- self.parameters(): same as self.named_parameters(), but without the names.
- self.named_buffers(): an iterator over the name + tensor of every non-trainable parameter in the network plus all self-registered buffers, including BN's bn.running_mean, bn.running_var, and bn.num_batches_tracked.
- self.buffers(): same as self.named_buffers(), but without the names.
- net.named_modules(): an iterator over the name + layer of each network structure defined in self._modules.
- net.modules(): same as net.named_modules(), but without the names.
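The BatchNorm split described above can be checked directly (a minimal sketch using a standard `nn.BatchNorm2d` with default settings):

```python
import torch.nn as nn

bn = nn.BatchNorm2d(4)
# the affine weights are trainable parameters
print([name for name, _ in bn.named_parameters()])  # ['weight', 'bias']
# the running statistics are buffers
print([name for name, _ in bn.named_buffers()])
# ['running_mean', 'running_var', 'num_batches_tracked']
```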
2 Code example
```python
from collections import OrderedDict

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        """=======case1: self._modules======="""
        self.conv = nn.Conv2d(1, 1, 3, 1, 1)
        self.TEST_1 = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(1, 1, 3, bias=False)),
            ('fc', nn.Linear(1, 2, bias=False))
        ]))
        """=======case2: self._buffers======="""
        self.register_buffer('TEST_2', torch.randn(1, 3))
        """=======case3: model._parameters======="""
        self.register_parameter('TEST_30', nn.Parameter(torch.randn(1, 4)))
        self.TEST_31 = nn.parameter.Parameter(torch.tensor(1.0))
        self.TEST_32 = nn.Parameter(torch.tensor(2.0))
        """=======case4: plain attribute, not registered======="""
        self.TEST_4 = torch.randn(1, 2)

    def forward(self, x):
        return x

model = Model()
print()
print(f'=========================================model._modules:\n{model._modules}\n')
print(f'=========================================model._buffers:\n{model._buffers}\n')
print(f'=========================================model._parameters:\n{model._parameters}\n')
print(f'=========================================model.state_dict():\n{model.state_dict()}\n')
```
Inspecting these in a debugger is actually more convenient, but printing them directly works just as well.
To print the contents returned by the member functions introduced above:
```python
named_buffers = [param for param in model.named_buffers()]
print(f'===================================named_buffers:\n{named_buffers}\n')
named_parameters = [param for param in model.named_parameters()]
print(f'===================================named_parameters:\n{named_parameters}\n')
named_modules = [param for param in model.named_modules()]
print(f'===================================named_modules:\n{named_modules}\n')
```
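One practical consequence worth noting (a sketch with made-up names `p` and `b`): an optimizer is normally built from model.parameters(), so buffers are invisible to it and are never touched by gradient steps.

```python
import torch
import torch.nn as nn

m = nn.Module()
m.register_parameter('p', nn.Parameter(torch.zeros(2)))
m.register_buffer('b', torch.zeros(2))

# the optimizer only sees parameters(); the buffer 'b' is excluded
opt = torch.optim.SGD(m.parameters(), lr=0.1)
print(len(list(m.parameters())))  # 1
print(len(list(m.buffers())))     # 1
```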