Training and recognition of handwritten digits through the LeNet-5 network built by PyTorch
2022-06-28 10:15:00 【fengbingchun】
This post calls PyTorch's interfaces to implement a LeNet-5 network, trains a model on the MNIST dataset, and finally uses the generated model for prediction. It consists of two main parts: training and prediction.
1. Training part:
(1). Load the MNIST dataset by calling the interfaces in the TorchVision module, resize each image to 32*32, and set the mini-batch size to 32;
(2). Fix the initial values of the network parameters, so that every retraining starts from the same state, which makes problems easier to reproduce and locate;
(3). Design the LeNet-5 network and instantiate a network object, overriding the __init__ and forward functions; the layers used include Conv2d, AvgPool2d, and Linear, with Tanh as the activation function;
(4). Specify the optimization algorithm, here Adam;
(5). Specify the loss function, here CrossEntropyLoss, which expects raw logits (see the sketch after this list);
(6). Train with epochs set to 10, reporting the result of each epoch;
(7). Save the model; saving the state_dict is recommended.
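Note that nn.CrossEntropyLoss applies log-softmax internally, which is why the training code below passes it the logits output rather than the softmax probabilities. A minimal self-contained sketch of the equivalence:
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0]])  # raw network output for one sample
target = torch.tensor([0])                 # ground-truth class index
# CrossEntropyLoss(logits) == NLLLoss(log_softmax(logits))
loss1 = nn.CrossEntropyLoss()(logits, target)
loss2 = F.nll_loss(F.log_softmax(logits, dim=1), target)
print(loss1.item(), loss2.item())  # the two values match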
The code snippet is as follows:
# Imports used by the snippets below
import os
from datetime import datetime
import cv2
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def load_mnist_dataset(img_size, batch_size):
    '''Download and load the MNIST dataset
    img_size: image size (width and height are equal)
    batch_size: mini-batch size
    '''
    # Resize the PIL image first, then convert it to a tensor
    transforms_ = transforms.Compose([transforms.Resize(size=(img_size, img_size)), transforms.ToTensor()])
    '''Download the MNIST dataset
    root: directory where the MNIST dataset is stored
    train: optional, defaults to True; if True, create the dataset from MNIST/processed/training.pt; if False, from MNIST/processed/test.pt
    transform: optional, defaults to None; receives a PIL image and transforms it
    target_transform: optional, defaults to None
    download: optional, defaults to False; if True, download the dataset from the internet into the directory specified by root
    '''
    train_dataset = datasets.MNIST(root="mnist_data", train=True, transform=transforms_, target_transform=None, download=True)
    valid_dataset = datasets.MNIST(root="mnist_data", train=False, transform=transforms_, target_transform=None, download=False)
    # Load the MNIST dataset: shuffle=True reshuffles the data at every epoch
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, valid_loader, train_dataset, valid_dataset
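A minimal sketch of what the returned loaders yield (assuming the imports at the top of the snippet); one batch looks like this:
train_loader, valid_loader, _, _ = load_mnist_dataset(img_size=32, batch_size=32)
X, y = next(iter(train_loader))
print(X.shape)  # torch.Size([32, 1, 32, 32]): [batch, channels, height, width]
print(y.shape)  # torch.Size([32]): one label per image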
class LeNet5(nn.Module):
    '''Build the LeNet-5 network'''
    def __init__(self, n_classes: int) -> None:
        super(LeNet5, self).__init__()  # call the constructor of the parent class Module
        # n_classes: number of classes
        # nn.Sequential: a sequential container; Modules are added in the order they are passed to the constructor, and the whole container can be treated as a single module
        self.feature_extractor = nn.Sequential(  # input 32*32
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=0),  # convolution layer, 28*28*6
            nn.Tanh(),  # Tanh activation: squashes values into the range (-1, 1)
            nn.AvgPool2d(kernel_size=2, stride=None, padding=0),  # average pooling layer, 14*14*6
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0),  # 10*10*16
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=None, padding=0),  # 5*5*16
            nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5, stride=1, padding=0),  # 1*1*120
            nn.Tanh()
        )
        self.classifier = nn.Sequential(  # input 1*1*120
            nn.Linear(in_features=120, out_features=84),  # fully connected layer, 84
            nn.Tanh(),
            nn.Linear(in_features=84, out_features=n_classes)  # 10
        )
    # LeNet5 inherits from nn.Module; once forward is defined, backward is implemented automatically by Autograd
    # Instantiating a LeNet5 object and passing it the input x automatically calls the forward function
    def forward(self, x: Tensor):
        x = self.feature_extractor(x)
        x = torch.flatten(input=x, start_dim=1)  # flatten the input from start_dim on; with start_dim=1 the batch dimension is kept and the rest is flattened
        logits = self.classifier(x)
        probs = F.softmax(input=logits, dim=1)  # softmax: maps each element into (0, 1) so that the elements of each row sum to 1
        return logits, probs
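A small sketch that pushes a dummy tensor through the network to check the feature-map sizes annotated in the comments above:
model = LeNet5(n_classes=10)
x = torch.randn(1, 1, 32, 32)  # one dummy 32*32 single-channel image
logits, probs = model(x)
print(logits.shape)      # torch.Size([1, 10])
print(probs.sum(dim=1))  # each softmax row sums to 1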
def validate(valid_loader, model, criterion, device):
    '''Function for the validation step of the training loop'''
    model.eval()  # set the network to evaluation mode
    running_loss = 0
    for X, y_true in valid_loader:
        X = X.to(device)  # move the data to the specified device (cpu or gpu)
        y_true = y_true.to(device)
        # Forward pass and record loss
        y_hat, _ = model(X)  # forward propagation: Module's __call__ method invokes the forward method of the given network (here LeNet5)
        loss = criterion(y_hat, y_true)  # compute the loss; likewise, __call__ invokes the forward method of the given loss class (here CrossEntropyLoss)
        running_loss += loss.item() * X.size(0)  # the loss is averaged per batch, so weight it by the batch size
    epoch_loss = running_loss / len(valid_loader.dataset)
    return model, epoch_loss
def get_accuracy(model, data_loader, device):
    '''Function for computing the accuracy of the predictions over the entire data_loader'''
    correct_pred = 0
    n = 0
    with torch.no_grad():  # temporarily set the requires_grad flag of every Tensor in the block to False, so no gradients (autograd) are computed
        model.eval()  # set the network to evaluation mode
        for X, y_true in data_loader:
            X = X.to(device)  # move the data to the specified device (cpu or gpu)
            y_true = y_true.to(device)
            _, y_prob = model(X)  # y_prob.size(): torch.Size([32, 10]): [batch, classes]
            # torch.max(input): returns the maximum of all elements in the Tensor
            # torch.max(input, dim): returns the maxima along dimension dim together with their indices
            #   dim=0: the element with the largest value in each column, plus its index
            #   dim=1: the element with the largest value in each row, plus its index
            _, predicted_labels = torch.max(y_prob, 1)
            n += y_true.size(0)
            correct_pred += (predicted_labels == y_true).sum()
    return correct_pred.float() / n
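A small sketch of the torch.max(input, dim) behavior described in the comments above:
t = torch.tensor([[0.1, 0.7, 0.2],
                  [0.8, 0.1, 0.1]])
values, indices = torch.max(t, 1)
print(values)   # tensor([0.7000, 0.8000]): the largest value in each row
print(indices)  # tensor([1, 0]): the column index of each row's maximum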
def train(train_loader, model, criterion, optimizer, device):
    '''Function for the training step of the training loop'''
    model.train()  # set the network to training mode
    running_loss = 0
    for X, y_true in train_loader:  # first calls DataLoader's __iter__ function, then the loop repeatedly calls the iterator's __next__ function
        # X.size() (shape [n,c,h,w]): torch.Size([32, 1, 32, 32]); y_true.size(): torch.Size([32]); n is batch_size
        optimizer.zero_grad()  # reset the gradients in the optimizer to 0; call it before computing the gradients of the next mini-batch, otherwise the new gradients accumulate into the existing ones
        # Move the Tensors to the specified device (cpu or gpu)
        X = X.to(device)
        y_true = y_true.to(device)
        y_hat, _ = model(X)  # forward propagation: Module's __call__ method invokes the forward method of the given network (here LeNet5)
        # y_hat.size(): torch.Size([32, 10]); _.size(): torch.Size([32, 10])
        loss = criterion(y_hat, y_true)  # compute the loss; likewise, __call__ invokes the forward method of the given loss class (here CrossEntropyLoss)
        running_loss += loss.item() * X.size(0)
        loss.backward()  # back propagation: Autograd automatically computes the gradients of the scalar loss
        optimizer.step()  # update the network parameters; the optimizer adjusts each parameter using the gradient stored in .grad
    epoch_loss = running_loss / len(train_loader.dataset)
    return model, optimizer, epoch_loss
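A minimal sketch of why optimizer.zero_grad() is needed before each mini-batch: gradients accumulate in .grad until they are explicitly zeroed:
w = torch.ones(1, requires_grad=True)
(w * 2).backward()
print(w.grad)   # tensor([2.])
(w * 2).backward()
print(w.grad)   # tensor([4.]): the second backward added to the first
w.grad.zero_()  # this is what optimizer.zero_grad() does for every parameter
print(w.grad)   # tensor([0.])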
def training_loop(model, criterion, optimizer, train_loader, valid_loader, epochs, device, print_every=1):
    '''Function defining the entire training loop
    model: network object
    criterion: loss function object
    optimizer: optimization algorithm object
    train_loader: training dataset loader
    valid_loader: validation dataset loader
    epochs: number of passes over the entire training dataset
    device: whether to run on the cpu or the gpu
    print_every: print the training results every print_every epochs
    '''
    train_losses = []
    valid_losses = []
    for epoch in range(0, epochs):
        model, optimizer, train_loss = train(train_loader, model, criterion, optimizer, device)
        train_losses.append(train_loss)
        # Evaluate on the validation dataset after each epoch
        with torch.no_grad():  # temporarily set the requires_grad flag of every Tensor in the block to False, so no gradients (autograd) are computed
            model, valid_loss = validate(valid_loader, model, criterion, device)
            valid_losses.append(valid_loss)
        if epoch % print_every == (print_every - 1):
            train_acc = get_accuracy(model, train_loader, device=device)
            valid_acc = get_accuracy(model, valid_loader, device=device)
            print(f'{datetime.now().time().replace(microsecond=0)}:'
                  f' Epoch: {epoch}', f'Train loss: {train_loss:.4f}', f'Valid loss: {valid_loss:.4f}',
                  f'Train accuracy: {100 * train_acc:.2f}', f'Valid accuracy: {100 * valid_acc:.2f}')
    return model, optimizer, (train_losses, valid_losses)
def train_and_save_model():
    print("#### start training ... ####")
    print("1. load mnist dataset")
    train_loader, valid_loader, _, _ = load_mnist_dataset(img_size=32, batch_size=32)
    print("2. fixed random init value")
    # Fix the random initialization; without a seed the network is initialized randomly at each run and the results are not reproducible; with a seed every initialization is the same
    torch.manual_seed(seed=42)
    #print("value:", torch.rand(1), torch.rand(1), torch.rand(1))  # across multiple runs the output values are the same each time, in [0, 1)
    print("3. instantiate lenet net object")
    model = LeNet5(n_classes=10).to('cpu')  # run on the CPU
    print("4. specify the optimization algorithm: Adam")
    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)  # define the optimization algorithm: Adam is a gradient-based optimizer
    print("5. specify the loss function: CrossEntropyLoss")
    criterion = nn.CrossEntropyLoss()  # define the loss function: cross-entropy loss
    print("6. repeated training")
    model, _, _ = training_loop(model, criterion, optimizer, train_loader, valid_loader, epochs=10, device='cpu')  # epochs: number of passes over the entire dataset
    print("7. save model")
    model_name = "../../../data/Lenet-5.pth"
    #torch.save(model, model_name)  # save the entire model; load it back with model = torch.load(...)
    torch.save(model.state_dict(), model_name)  # recommended: save only the model's trained parameters; load them back with model.load_state_dict(torch.load(...))
The results are shown below:

2. Handwritten digit image recognition part:
(1). Load the model; load_state_dict is recommended, matching the state_dict used when the model was saved;
(2). Set the network to evaluation mode;
(3). Prepare the test images: 10 images in total, one for each digit from 0 to 9, as shown in the figure below. Note: the training images have a black background while the test images have a white background, so the test images need to be inverted before recognition (see the sketch after this list):

(4). Recognize each image in turn.
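A sketch of the inversion mentioned in step (3), using a hypothetical all-white image; for 8-bit images cv2.bitwise_not(img) is equivalent to 255 - img:
import cv2
import numpy as np

img = np.full((32, 32), 255, dtype=np.uint8)  # hypothetical all-white grayscale image
inverted = cv2.bitwise_not(img)               # flips each pixel: white (255) becomes black (0)
print(inverted[0, 0])                         # 0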
The code snippet is shown below:
def list_files(filepath, filetype):
    '''Traverse the given directory and collect files with the given suffix'''
    paths = []
    for root, dirs, files in os.walk(filepath):
        for file in files:
            if file.lower().endswith(filetype.lower()):
                paths.append(os.path.join(root, file))
    return paths

def get_image_label(image_name, image_name_suffix):
    '''Get the ground-truth label of a test image from its file name'''
    index = image_name.rfind("/")
    if index == -1:
        print(f"Error: image name {image_name} is not supported")
    sub = image_name[index+1:]
    label = sub[:len(sub)-len(image_name_suffix)]
    return label
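A usage sketch of get_image_label with a hypothetical path; the label is simply the file name with the suffix stripped:
print(get_image_label("handwritten_digits/7.png", ".png"))  # prints: 7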
def image_predict():
    print("#### start predicting ... ####")
    print("1. load model")
    model_name = "../../../data/Lenet-5.pth"
    model = LeNet5(n_classes=10).to('cpu')  # instantiate a network object
    model.load_state_dict(torch.load(model_name))  # load the saved parameters
    print("2. set net to evaluate mode")
    model.eval()
    print("3. prepare test images")
    image_path = "../../../data/image/handwritten_digits/"
    image_name_suffix = ".png"
    images_name = list_files(image_path, image_name_suffix)
    print("4. image recognition")
    with torch.no_grad():
        for image_name in images_name:
            #print("image name:", image_name)
            label = get_image_label(image_name, image_name_suffix)
            img = cv2.imread(image_name, cv2.IMREAD_GRAYSCALE)
            img = cv2.resize(img, (32, 32))
            # MNIST images have a black background, while the test images have a white background, so they must be inverted before recognition
            img = cv2.bitwise_not(img)
            #print("img shape:", img.shape)
            # Convert the OpenCV image to a PyTorch tensor
            transform = transforms.ToTensor()
            tensor = transform(img)  # tensor shape: torch.Size([1, 32, 32])
            tensor = tensor.unsqueeze(0)  # tensor shape: torch.Size([1, 1, 32, 32])
            #print("tensor shape:", tensor.shape)
            _, y_prob = model(tensor)
            _, predicted_label = torch.max(y_prob, 1)
            print(f"predicted label: {predicted_label.item()}, ground truth label: {label}")
The execution result is shown in the figure below:
