PyTorch parameter freezing, detach, and clone
2022-06-10 11:33:00 【MallocLu】
detach
A tensor returned by detach() shares the underlying data memory with the original tensor: once the original tensor is updated in the computation graph (for example by an optimizer step after backpropagation), the value seen through the detached tensor changes as well.
import torch
from torch import optim
from torch.nn import Parameter
x = torch.tensor(1.)
a = torch.tensor(1., requires_grad=True)
b = Parameter(torch.tensor(2.))
y = a**2 * x + b * x
z = y**2 + 2*y
optimizer = optim.SGD([a, b], lr=0.01)
ta = a.detach()
tb = b.detach()
print('before:', a, b, ta, tb)
print()
optimizer.zero_grad()
z.backward()
optimizer.step()
print('after:', a, b, ta, tb)
# before: tensor(1., requires_grad=True) Parameter containing:
# tensor(2., requires_grad=True) tensor(1.) tensor(2.)
#
# after: tensor(0.8400, requires_grad=True) Parameter containing:
# tensor(1.9200, requires_grad=True) tensor(0.8400) tensor(1.9200)
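The shared storage can be seen even more directly without an optimizer. A minimal sketch (the tensor values here are made up for illustration): an in-place update of the original tensor is immediately visible through the detached one.
import torch
t = torch.tensor([1., 2., 3.], requires_grad=True)
d = t.detach()        # shares storage with t, excluded from the computation graph
with torch.no_grad():
    t += 1            # in-place update, just like what optimizer.step() does
print(d)              # tensor([2., 3., 4.]) -- the detached tensor changed as well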
clone
clone() allocates new memory, so when the original tensor is updated in the computation graph (e.g. by an optimizer step after backpropagation), the cloned tensor's value does not change.
If the new tensor is cloned from a tensor with requires_grad=True or from a Parameter, it gets grad_fn=<CloneBackward>, which means it can still propagate gradients as an intermediate node; the clone behaves like an identity mapping.
import torch
from torch import optim
from torch.nn import Parameter
x = torch.tensor(1.)
a = torch.tensor(1., requires_grad=True)
b = Parameter(torch.tensor(2.))
y = a**2 * x + b * x
z = y**2 + 2*y
optimizer = optim.SGD([a, b], lr=0.01)
ta = a.clone()
tb = b.clone()
print('before:', a, b, ta, tb)
print()
optimizer.zero_grad()
z.backward()
optimizer.step()
print('after:', a, b, ta, tb)
# before: tensor(1., requires_grad=True) Parameter containing:
# tensor(2., requires_grad=True) tensor(1., grad_fn=<CloneBackward>) tensor(2., grad_fn=<CloneBackward>)
#
# after: tensor(0.8400, requires_grad=True) Parameter containing:
# tensor(1.9200, requires_grad=True) tensor(1., grad_fn=<CloneBackward>) tensor(2., grad_fn=<CloneBackward>)
import torch
from torch import optim
from torch.nn import Parameter
x = torch.tensor(1.)
a = torch.tensor(1., requires_grad=True)
b = Parameter(torch.tensor(2.))
y = a**2 * x + b * x
z = y**2 + 2*y
# z -> z2 acts as an identity mapping, so gradients can still propagate through it
z2 = z.clone()
optimizer = optim.SGD([a, b], lr=0.01)
print('before:', a, b)
print()
optimizer.zero_grad()
z2.backward()
optimizer.step()
print('after:', a, b)
# before: tensor(1., requires_grad=True) Parameter containing:
# tensor(2., requires_grad=True)
#
# after: tensor(0.8400, requires_grad=True) Parameter containing:
# tensor(1.9200, requires_grad=True)
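If what you want is a frozen snapshot that neither shares memory with the original nor stays in the computation graph, the two calls are usually combined as detach().clone() (the storeParam code further below uses exactly this pattern). A small sketch comparing the three variants; the variable names are just for illustration.
import torch
a = torch.tensor(1., requires_grad=True)
view_only = a.detach()          # shares memory, out of the graph
in_graph = a.clone()            # new memory, still in the graph (grad_fn=<CloneBackward>)
snapshot = a.detach().clone()   # new memory, out of the graph
with torch.no_grad():
    a -= 0.5                    # simulate an optimizer update
print(view_only, in_graph, snapshot)
# tensor(0.5000) tensor(1., grad_fn=<CloneBackward>) tensor(1.)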
lambda
from functools import reduce
foo = [2, 18, 9, 22, 17, 24, 8, 12, 27]
# In a lambda expression, the parameters appear before the colon (there can be several, separated by commas); the expression to the right of the colon is the return value.
# paired with filter
print(filter(lambda x: x % 3 == 0, foo))
print(list(filter(lambda x: x % 3 == 0, foo)))
print()
# paired with map
print(map(lambda x: x * 2 + 10, foo))
print(list(map(lambda x: x * 2 + 10, foo)))
print()
# paired with reduce
# reduce accumulates: the lambda first computes x=2 + y=18 = 20, then x=20 + y=9 = 29, and so on
print(reduce(lambda x, y: x + y, foo))
# <filter object at 0x000002206C252A88>
# [18, 9, 24, 12, 27]
#
# <map object at 0x000002206C1FF608>
# [14, 46, 28, 54, 44, 58, 26, 34, 64]
#
# 139
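Another common pairing, shown here only as a sketch (it is not part of the recorded output above), is using a lambda as the key of sorted; it reuses the foo list defined above.
# paired with sorted: the lambda is the sort key, here sorting in descending order
print(sorted(foo, key=lambda x: -x))
# [27, 24, 22, 18, 17, 12, 9, 8, 2]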
Freezing part of the parameters with detach
Drawback: it can only freeze all of the parameters that come before the detach() call.
# With out = out.detach() removed from Net.forward, the result is:
# layer1.weight False
# layer1.bias False
# layer2.weight False
# layer2.bias False
# i.e. every parameter changed (was optimized).
# With out = out.detach() added to Net.forward, the result is:
# layer1.weight True
# layer1.bias True
# layer2.weight False
# layer2.bias False
# i.e. the parameters in self.layer1 did not change, because the tensor returned by
# out = out.detach() cannot propagate gradients: backpropagation stops at that tensor
# and cannot continue any further back, so all parameters before it stay frozen.
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layer1 = nn.Linear(10, 5)
        self.layer2 = nn.Linear(5, 3)

    def forward(self, x):
        out = self.layer1(x)
        out = out.detach()  # backpropagation stops here, so layer1 gets no gradient
        out = F.relu(self.layer2(out))
        return out

net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01)
input = torch.randn(8, 10)
# store the value of every parameter before training
storeParam = {}
for name, param in net.named_parameters():
    storeParam[name] = param.detach().clone()
for i in range(100):
    out = net(input)
    loss = F.mse_loss(out, torch.zeros(8, 3))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
# compare whether each parameter is still equal to its value before training
for name, param in net.named_parameters():
    print(f"{name} {torch.equal(param, storeParam[name])}")
Freezing individual parameters with requires_grad = False
# Setting self.layer1.weight.requires_grad = False freezes only that one parameter.
# If loss has requires_grad=False, backward() cannot be called on it; if an intermediate
# (non-leaf) tensor does not require grad (e.g. after detach), backpropagation is cut off
# there, so the parameters before it are not updated; if a leaf tensor has
# requires_grad=False, it needs no gradient and therefore is never updated itself.
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layer1 = nn.Linear(10, 5)
        self.layer1.weight.requires_grad = False  # freeze only this parameter
        self.layer2 = nn.Linear(5, 3)

    def forward(self, x):
        out = self.layer1(x)
        out = F.relu(self.layer2(out))
        return out

net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01)
input = torch.randn(8, 10)
# store the value of every parameter before training
storeParam = {}
for name, param in net.named_parameters():
    storeParam[name] = param.detach().clone()
for i in range(100):
    out = net(input)
    loss = F.mse_loss(out, torch.zeros(8, 3))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
# compare whether each parameter is still equal to its value before training
for name, param in net.named_parameters():
    print(f"{name} {torch.equal(param, storeParam[name])}")
# layer1.weight True
# layer1.bias False
# layer2.weight False
# layer2.bias False
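A common companion to this approach, sketched here as a continuation of the code above rather than part of the original run: pass only the parameters that still require gradients to the optimizer, which also ties back to the lambda/filter section earlier.
# keep only the trainable parameters
trainable = list(filter(lambda p: p.requires_grad, net.parameters()))
print(len(trainable))    # 3 -- layer1.bias, layer2.weight, layer2.bias
optimizer = optim.SGD(trainable, lr=0.01)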