当前位置:网站首页>【pytorch 記錄】pytorch的變量parameter、buffer。self.register_buffer()、self.register_parameter()

【pytorch 記錄】pytorch的變量parameter、buffer。self.register_buffer()、self.register_parameter()

2022-06-13 02:53:00 magic_ll

在pytorch中模型需要保存下來的參數包括:

  • parameter:反向傳播需要被 optimizer 更新的,可以被訓練。
  • buffer:反向傳播不需要被 optimizer 更新,不可被訓練。

這兩種參數都會分別保存到 一個OrderDict 的變量中,最終由 model.state_module() 返回進行保存。

1 nn.Module的介紹

需要先說明下:直接torch.randn(1, 2) 這種定義的變量,沒有綁定在pytorch的網絡中,訓練結束後也就沒有在保存在模型中。當我們想要將一些變量保存(如yolov5中的anchor),可以用作簡單的後處理,就需要將這種變量注册到網絡中,可以使用的api為:self.register_buffer() :不可被訓練;self.register_parameter()nn.parameter.Parameter()nn.Parameter():可以被訓練。

對於pytorch定義網絡時,都要繼承與 nn.Module。到源碼中看到該類的初始化中,成員變量如下,這裏我們關心是綠色選中區域,這三個成員都是 OrderedDict() 類型的
在這裏插入圖片描述

成員變量:

  • _buffers:由self.register_buffer() 定義,requires_grad默認為False,不可被訓練。
  • _parasmeter:self.register_parameter()、nn.parameter.Parameter()、nn.Parameter() 定義的變量都存放在該屬性下,且定義的參數的 requires_grad 默認為 True。
  • _module:nn.Sequential()、nn.conv() 等定義的網絡結構中的結構存放在該屬性下。

成員函數:

  • self.state_dict():OrderedDict 類型。保存神經網絡的推理參數,包括parameter、buffer
  • self.name_parameters():為迭代器。self._moduleself._parameters中所有的可訓練參數的名字+tensor。包括 BN的 bn.weight、bn.bias。
  • self.parameters():與self.name_parameters()一樣,但不包含名字
  • self.name_buffers():為迭代器。網絡中所有的不可訓練參數和自己注册的buffer 中的參數的名字+tensor。包括 BN的 bn.running_mean、bn.running_var、bn.num_batches_tracked。
  • self.buffers():與self.name_buffers()一樣,但不包含名字
  • net.named_modules():為迭代器。self._module中定義的網絡結構的名字+層
  • net.modules()

2 代碼示例

import torch  
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        """=======case1: self._modules======="""
        self.conv = nn.Conv2d(1, 1, 3, 1, 1)
        self.TEST_1 = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(1, 1, 3, bias=False)),
            ('fc', nn.Linear(1, 2, bias=False))
        ]))

        """=======case2: self._buffers======="""
        self.register_buffer('TEST_2', torch.randn(1, 3))

        """=======case3: model._parameters======="""
        self.register_parameter('TEST_30', nn.Parameter(torch.randn(1, 4)))
        self.TEST_31 = nn.parameter.Parameter(torch.tensor(1.0))
        self.TEST_32 = nn.Parameter(torch.tensor(2.0))

        """=======case4======="""
        self.TEST_4 = torch.randn(1, 2)

    def forward(self, x):
       return x

model = Model() 
print()
print(f'=========================================model._modules:\n{
       model._modules}\n') 
print(f'=========================================model._buffers:\n{
       model._buffers}\n') 
print(f'=========================================model._parameters:\n{
       model._parameters}\n')
print(f'=========================================model.state_dict():\n{
       model.state_dict()}\n')

其實debug方式查看會更便捷。直接打印也沒有問題。
在這裏插入圖片描述

如果要打印介紹的成員函數的內容,則有:

named_buffers = [param for param in model.named_buffers()]
print(f'===================================named_buffers:\n{
       named_buffers}\n')

named_parameters = [param for param in model.named_parameters()]
print(f'===================================named_parameters:\n{
       named_parameters}\n')

named_modules = [param for param in model.named_modules()]
print(f'===================================named_modules:\n{
       named_modules}\n')
原网站

版权声明
本文为[magic_ll]所创,转载请带上原文链接,感谢
https://yzsam.com/2022/164/202206130252266190.html