
2021 Li Hongyi Machine Learning (2): PyTorch

2022-07-05 02:38:00 Three ears 01

1 Basics

1.1 Creating tensors

import torch
import numpy as np
x = torch.tensor([[1, -1], [-1, 1]])
y = torch.from_numpy(np.array([[1, -1], [-1, 1]]))
x, y
(tensor([[ 1, -1],
         [-1,  1]]),
 tensor([[ 1, -1],
         [-1,  1]], dtype=torch.int32))
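A short follow-up on the two constructors above: torch.tensor always copies its input, while torch.from_numpy shares memory with the NumPy array, so in-place changes to the array show up in the tensor. The int32 dtype in the output comes from NumPy's platform-dependent default integer type (typically int32 on Windows with NumPy 1.x, int64 on Linux and macOS). The sketch below passes an explicit dtype so the result is predictable; the variable names are illustrative.

```python
import numpy as np
import torch

# Explicit dtype so the result does not depend on the platform's default integer type
a = np.array([[1, -1], [-1, 1]], dtype=np.int64)

copied = torch.tensor(a)      # torch.tensor makes its own copy of the data
shared = torch.from_numpy(a)  # torch.from_numpy reuses the array's memory

a[0, 0] = 100                 # modify the NumPy array in place

print(copied[0, 0].item())    # 1: the copy is unaffected
print(shared[0, 0].item())    # 100: the shared tensor sees the change
print(shared.dtype)           # torch.int64
```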

1.2 squeeze: removing a dimension of size 1

x = torch.zeros([1, 2, 3])
y = x.squeeze(0)
x, y, x.shape, y.shape
(tensor([[[0., 0., 0.],
          [0., 0., 0.]]]),
 tensor([[0., 0., 0.],
         [0., 0., 0.]]),
 torch.Size([1, 2, 3]),
 torch.Size([2, 3]))
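squeeze also works without an argument, and it only removes dimensions whose size is 1. A minimal sketch of both behaviors (the shape [1, 2, 1, 3] is just an example):

```python
import torch

x = torch.zeros([1, 2, 1, 3])

# With no argument, squeeze removes every dimension of size 1
print(x.squeeze().shape)     # torch.Size([2, 3])

# With a dimension index, only that dimension is removed (if it has size 1)
print(x.squeeze(2).shape)    # torch.Size([1, 2, 3])

# Squeezing a dimension whose size is not 1 leaves the shape unchanged
print(x.squeeze(1).shape)    # torch.Size([1, 2, 1, 3])
```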

1.3 unsqueeze: inserting a dimension of size 1

x = torch.zeros([2, 3])
y = x.unsqueeze(1)  # dim = 1
z = x.unsqueeze(2)  # dim = 2
x, y, z, x.shape, y.shape, z.shape
(tensor([[0., 0., 0.],
         [0., 0., 0.]]),
 tensor([[[0., 0., 0.]],
 
         [[0., 0., 0.]]]),
 tensor([[[0.],
          [0.],
          [0.]],
 
         [[0.],
          [0.],
          [0.]]]),
 torch.Size([2, 3]),
 torch.Size([2, 1, 3]),
 torch.Size([2, 3, 1]))
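Two equivalent ways to insert a dimension, plus a check that unsqueeze and squeeze undo each other. This is a small illustrative sketch, not part of the original notes:

```python
import torch

x = torch.zeros([2, 3])

# Negative indices count from the end: -1 appends a new last dimension
a = x.unsqueeze(-1)          # same as x.unsqueeze(2) for a 2-D tensor
# Indexing with None inserts a size-1 dimension at that position
b = x[:, None, :]            # same as x.unsqueeze(1)

print(a.shape)  # torch.Size([2, 3, 1])
print(b.shape)  # torch.Size([2, 1, 3])

# unsqueeze followed by squeeze on the same dimension gives back the original
print(torch.equal(x.unsqueeze(1).squeeze(1), x))  # True
```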

1.4 transpose: swapping two dimensions

x = torch.zeros([2, 3])
y = x.transpose(0, 1)
x.shape, y.shape
(torch.Size([2, 3]), torch.Size([3, 2]))
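The shapes above hide what happens to the values: element x[i, j] ends up at y[j, i]. A small sketch with non-zero values makes this visible; it also shows that the transposed tensor is a non-contiguous view, which matters if you later call .view() on it (call .contiguous() first):

```python
import torch

x = torch.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
y = x.transpose(0, 1)              # element x[i, j] moves to y[j, i]

print(y)
# tensor([[0, 3],
#         [1, 4],
#         [2, 5]])

# For a 2-D tensor, .T gives the same result
print(torch.equal(y, x.T))  # True

# transpose returns a view with swapped strides, not a contiguous copy
print(y.is_contiguous())    # False
```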

1.5 cat: concatenating tensors along a given dimension

x = torch.zeros(2,1,3)
y = torch.zeros(2,3,3)
z = torch.zeros(2,2,3)
w = torch.cat([x, y, z], dim=1)
w.shape
torch.Size([2, 6, 3])
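In the example above the sizes along dim=1 add up (1 + 3 + 2 = 6), while all other dimensions must match exactly. A minimal sketch of the rule, including the error raised when it is violated:

```python
import torch

x = torch.zeros(2, 1, 3)
y = torch.zeros(2, 3, 3)

# Concatenating along dim=1 works: dimensions 0 and 2 match (2 and 3)
print(torch.cat([x, y], dim=1).shape)  # torch.Size([2, 4, 3])

# Along dim=0 it fails, because dimension 1 differs (1 vs. 3)
try:
    torch.cat([x, y], dim=0)
    failed = False
except RuntimeError as err:
    failed = True
    print("RuntimeError:", err)
```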

1.6 Computing gradients

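A minimal sketch of gradient computation with autograd, assuming the standard workflow: mark a tensor with requires_grad=True, build a scalar from it, call backward(), and read the result from .grad. For z = sum of x_ij^2 the gradient is 2x, which the output confirms:

```python
import torch

# requires_grad=True tells autograd to record operations on x
x = torch.tensor([[1., 0.], [-1., 1.]], requires_grad=True)

z = x.pow(2).sum()   # z = sum of x_ij^2, a scalar
z.backward()         # compute dz/dx and store it in x.grad

print(x.grad)
# tensor([[ 2.,  0.],
#         [-2.,  2.]])
```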

2 Neural networks


2.1 Reading data

The two are nested: one object wraps the other.
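Assuming this subsection refers to PyTorch's torch.utils.data.Dataset and DataLoader (the usual pair, where the DataLoader wraps a Dataset and serves it in mini-batches), a minimal sketch follows. MyDataset and its random in-memory data are hypothetical; a real dataset would read and preprocess a file in __init__.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    """Hypothetical dataset holding random features and targets in memory."""

    def __init__(self, n_samples=100, n_features=8):
        # In practice, read and preprocess the data file here
        self.x = torch.randn(n_samples, n_features)
        self.y = torch.randn(n_samples)

    def __getitem__(self, index):
        # Return one sample: (features, target)
        return self.x[index], self.y[index]

    def __len__(self):
        # Number of samples in the dataset
        return len(self.x)

dataset = MyDataset()
# The DataLoader wraps the Dataset and groups samples into shuffled mini-batches
loader = DataLoader(dataset, batch_size=16, shuffle=True)

x_batch, y_batch = next(iter(loader))
print(x_batch.shape, y_batch.shape)  # torch.Size([16, 8]) torch.Size([16])
print(len(loader))                   # 7 batches: 100 / 16, rounded up
```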

2.2 torch.nn

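import torch
import torch.nn as nn

layer = torch.nn.Linear(32, 64)  # fully connected layer: 32 inputs -> 64 outputs
layer.weight.shape, layer.bias.shape  # weight is stored as (out_features, in_features)
(torch.Size([64, 32]), torch.Size([64]))

nn.Sigmoid()           # sigmoid activation
nn.ReLU()              # ReLU activation
nn.MSELoss()           # mean squared error, mostly used for regression
nn.CrossEntropyLoss()  # cross-entropy, mostly used for classification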

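The layers and losses above are usually combined by subclassing nn.Module. A minimal sketch, assuming a small fully connected regression network; the class name MyModel and the layer sizes (10 -> 32 -> 1) are hypothetical:

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    """Hypothetical fully connected network: 10 inputs -> 32 hidden -> 1 output."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(10, 32),
            nn.Sigmoid(),
            nn.Linear(32, 1),
        )

    def forward(self, x):
        # x has shape (batch_size, 10); the output has shape (batch_size, 1)
        return self.net(x)

model = MyModel()
criterion = nn.MSELoss()

x = torch.randn(4, 10)
target = torch.randn(4, 1)
pred = model(x)                 # calling the module runs forward()
loss = criterion(pred, target)  # mean squared error, a scalar tensor
print(pred.shape, loss.shape)   # torch.Size([4, 1]) torch.Size([])
```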

2.3 torch.optim

SGD:

torch.optim.SGD(params, lr, momentum=0)  # params: usually model.parameters(); lr: learning rate
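One optimization step always follows the same three calls: zero_grad() to clear old gradients, backward() on the loss, then step() to update the parameters. A minimal sketch on a hypothetical tiny linear model with random data, showing that one SGD step lowers the loss on that batch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Linear(3, 1)  # hypothetical tiny model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0)
criterion = nn.MSELoss()

x = torch.randn(8, 3)
y = torch.randn(8, 1)

before = criterion(model(x), y).item()

optimizer.zero_grad()          # clear gradients left over from a previous step
loss = criterion(model(x), y)
loss.backward()                # compute gradients of the loss w.r.t. the parameters
optimizer.step()               # plain SGD update: param -= lr * grad

after = criterion(model(x), y).item()
print(before, after)           # the loss on this batch goes down
```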

3 The full workflow

3.1 Training

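A sketch of the usual PyTorch training loop, run on hypothetical toy data (a noisy line y = 3x + 1) so that it is self-contained; the model, learning rate, batch size and epoch count are chosen only for illustration:

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)

# Hypothetical toy data: y = 3x + 1 plus a little noise
x = torch.randn(200, 1)
y = 3 * x + 1 + 0.1 * torch.randn(200, 1)
train_loader = DataLoader(TensorDataset(x, y), batch_size=20, shuffle=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(1, 1).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

n_epochs = 50
for epoch in range(n_epochs):
    model.train()                     # set the model to training mode
    for xb, yb in train_loader:       # iterate over mini-batches
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()         # reset gradients
        pred = model(xb)              # forward pass
        loss = criterion(pred, yb)    # compute the loss
        loss.backward()               # backward pass: compute gradients
        optimizer.step()              # update the parameters

print(model.weight.item(), model.bias.item())  # close to 3 and 1
```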

3.2 Validation

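Validation differs from training in two ways: the model is switched to evaluation mode with model.eval(), and gradients are not tracked (torch.no_grad()). A minimal sketch, with hypothetical data and an untrained model just to make it runnable; the per-batch loss is weighted by batch size so the result is the mean over all samples:

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)

# Hypothetical validation data and an (untrained) model
x_val = torch.randn(50, 1)
y_val = 3 * x_val + 1
dv_loader = DataLoader(TensorDataset(x_val, y_val), batch_size=10)
model = nn.Linear(1, 1)
criterion = nn.MSELoss()

model.eval()                          # set the model to evaluation mode
total_loss = 0.0
with torch.no_grad():                 # no gradient tracking during validation
    for xb, yb in dv_loader:
        pred = model(xb)
        loss = criterion(pred, yb)
        total_loss += loss.item() * len(xb)       # accumulate the summed loss
avg_loss = total_loss / len(dv_loader.dataset)    # mean loss over all samples
print(avg_loss)
```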

3.3 Testing

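At test time there are usually no labels: the loop only collects predictions. A minimal sketch, again with hypothetical inputs and an untrained model; predictions are moved to the CPU and concatenated into one tensor:

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

# Hypothetical test inputs (no labels) and model
x_test = torch.randn(25, 1)
tt_loader = DataLoader(TensorDataset(x_test), batch_size=10)
model = nn.Linear(1, 1)

model.eval()
preds = []
with torch.no_grad():
    for (xb,) in tt_loader:           # each batch holds a single tensor
        pred = model(xb)
        preds.append(pred.cpu())      # collect predictions on the CPU
preds = torch.cat(preds, dim=0)       # one tensor with all predictions
print(preds.shape)                    # torch.Size([25, 1])
```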

4 Saving and loading

4.1 Save

torch.save(model.state_dict(), path)

4.2 Load

ckpt = torch.load(path)
model.load_state_dict(ckpt)
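Because only the state_dict (the parameters) is saved, loading requires first building a model with the same architecture. A round-trip sketch; the file name is hypothetical, and if the checkpoint was saved on a GPU you can pass map_location="cpu" to torch.load:

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
path = os.path.join(tempfile.gettempdir(), "model.ckpt")  # hypothetical file name

# Save only the parameters (the state_dict), not the whole model object
torch.save(model.state_dict(), path)

# To load: build a model with the same architecture, then copy the weights in
new_model = nn.Linear(4, 2)
ckpt = torch.load(path)
new_model.load_state_dict(ckpt)

print(torch.equal(model.weight, new_model.weight))  # True
```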

Copyright notice
This article was written by [Three ears 01]. Please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/02/202202140912208513.html