Current location: Home > AlexNet (PyTorch Implementation)
AlexNet (PyTorch Implementation)
2022-07-27 00:17:00 【Ap21ril】
AlexNet (PyTorch Implementation)
1. model.py
import torch.nn as nn
import torch
class AlexNet(nn.Module):
    """AlexNet with half the original channel widths.

    Conv channels are 48/128/192/192/128 and the two hidden fully connected
    layers have 2048 units. The convolutional trunk maps a
    ``[N, 3, 224, 224]`` batch to ``[N, 128, 6, 6]`` feature maps. An
    adaptive average pool then fixes the spatial size at 6x6, so other input
    resolutions are accepted too, as long as the trunk output is non-empty.
    A dropout MLP finally produces ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
        init_weights: If True, re-initialise conv weights with Kaiming-normal
            (fan_out, ReLU) and linear weights with N(0, 0.01); biases are
            set to zero.
    """

    def __init__(self, num_classes=1000, init_weights=False):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224] output[48, 55, 55]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        # Pools the trunk output to a fixed 6x6 grid so the classifier's
        # 128 * 6 * 6 input size holds for other input resolutions too.
        # For 224x224 input the trunk already emits 6x6, so this is an exact
        # identity there. It has no parameters, so existing state_dicts
        # (e.g. AlexNet.pth checkpoints) load unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Return logits of shape ``[N, num_classes]`` for a 3-channel batch ``x``."""
        x = self.features(x)               # [N, 128, 6, 6] for 224x224 input
        x = self.avgpool(x)                # [N, 128, 6, 6] for any valid input size
        x = torch.flatten(x, start_dim=1)  # [N, 4608]
        return self.classifier(x)

    def _initialize_weights(self):
        """Kaiming-normal conv weights, N(0, 0.01) linear weights, zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                # Guard like the Conv2d branch: Linear(bias=False) has bias None.
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
2. train.py
import json
import os
import time
import torch
import torch.nn as nn
import numpy as np
from matplotlib import pyplot as plt
from torch import optim
from torchvision import datasets,transforms,utils
from model import AlexNet
def _train_one_epoch(net, loader, loss_function, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        net: Model to train; it is switched to training mode here.
        loader: Sized iterable of ``(images, labels)`` batches.
        loss_function: Criterion called as ``loss_function(outputs, labels)``.
        optimizer: Optimizer over ``net``'s parameters.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Sum of the per-batch losses divided by the number of batches
        (0.0 if ``loader`` yields nothing).
    """
    net.train()
    running_loss = 0.0
    num_batches = len(loader)
    for step, (images, labels) in enumerate(loader):
        optimizer.zero_grad()
        outputs = net(images.to(device))
        loss = loss_function(outputs, labels.to(device))
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        running_loss += batch_loss
        # In-place text progress bar: '*' for finished batches, '.' for the rest.
        rate = (step + 1) / num_batches
        done = '*' * int(rate * 50)
        todo = '.' * int((1 - rate) * 50)
        print('\rtrain loss: :{:^3.0f}%[{}->{}]{:.3f}'.format(int(rate * 100), done, todo, batch_loss), end="")
    print()
    # Divide by the batch count, not the last 0-based step index
    # (which would be off by one and zero for a single batch).
    return running_loss / num_batches if num_batches else 0.0


def _evaluate(net, loader, device):
    """Return how many samples in ``loader`` ``net`` classifies correctly.

    Each sample's prediction is the index of its largest output logit.
    """
    net.eval()
    correct = 0
    with torch.no_grad():
        for val_images, val_labels in loader:
            outputs = net(val_images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            correct += (predict_y == val_labels.to(device)).sum().item()
    return correct


def main():
    """Train AlexNet on ``<cwd>/data_set/flower_data`` and keep the best weights.

    Expects ``train`` and ``val`` sub-folders in ImageFolder layout (one
    sub-folder per class). Writes the index->class-name mapping to
    ``class_indices.json``. Whenever an epoch's validation accuracy beats the
    best so far, saves the model's state_dict to ``./AlexNet.pth``.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Using {} device'.format(device))

    # Preprocessing: random crop/flip augmentation for training and a plain
    # resize for validation. Both map pixels from [0, 1] to [-1, 1] per channel.
    data_transform = {
        "train": transforms.Compose([
            transforms.RandomResizedCrop(224),   # random region rescaled to 224x224
            transforms.RandomHorizontalFlip(),   # random left-right flip
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]),
        "val": transforms.Compose([
            transforms.Resize((224, 224)),       # match the 224x224 training crop size
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]),
    }

    # The dataset lives under the current working directory.
    data_root = os.getcwd()
    print(data_root)
    image_path = os.path.join(data_root, 'data_set', 'flower_data')
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, 'train'),
                                         transform=data_transform['train'])

    # class_to_idx maps folder name -> index, e.g. {'daisy': 0, ...}.
    # Invert it so predict.py can look class names up by index.
    cla_dict = {idx: name for name, idx in train_dataset.class_to_idx.items()}
    with open('class_indices.json', 'w') as json_file:
        json.dump(cla_dict, json_file, indent=4)

    batch_size = 32
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=0)
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform['val'])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=4, shuffle=False,
                                                  num_workers=0)

    # Size the output layer from the dataset instead of hard-coding 5.
    net = AlexNet(num_classes=len(cla_dict), init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    save_path = './AlexNet.pth'
    best_acc = 0.0
    epochs = 10
    for epoch in range(epochs):
        t1 = time.perf_counter()
        train_loss = _train_one_epoch(net, train_loader, loss_function, optimizer, device)
        print(time.perf_counter() - t1)  # wall-clock seconds for the training pass

        acc_val = _evaluate(net, validate_loader, device) / val_num
        if acc_val > best_acc:
            best_acc = acc_val
            torch.save(net.state_dict(), save_path)
        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
              (epoch + 1, train_loss, acc_val))


if __name__ == '__main__':
    main()
3. predict.py
import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model import AlexNet
def _check_exists(path):
    """Raise FileNotFoundError with a readable message if ``path`` is missing."""
    # An explicit raise, unlike assert, is not stripped under ``python -O``.
    if not os.path.exists(path):
        raise FileNotFoundError("file: '{}' does not exist.".format(path))


def main():
    """Classify ``rose.jpg`` with the trained AlexNet and plot the prediction.

    Reads the index->class-name mapping from ``./class_indices.json`` and the
    weights from ``./AlexNet.pth`` (both written by train.py). Prints every
    class probability and shows the image titled with the top-1 class.

    Raises:
        FileNotFoundError: if the image, the class-index JSON or the weight
            file is missing.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Same preprocessing as the validation split used during training.
    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "rose.jpg"
    _check_exists(img_path)
    # Force 3 channels, as ImageFolder's default loader does at training time.
    # Normalize and the first Conv2d expect exactly 3 channels, so grayscale
    # or RGBA files would otherwise fail. The with-block closes the file.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert("RGB")
    plt.imshow(img)
    img = data_transform(img)          # [C, H, W]
    img = torch.unsqueeze(img, dim=0)  # add batch dim -> [1, C, H, W]

    # read class_indict
    json_path = './class_indices.json'
    _check_exists(json_path)
    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # One output per class listed in the JSON (5 for the flower dataset).
    model = AlexNet(num_classes=len(class_indict)).to(device)

    # load model weights
    weights_path = "./AlexNet.pth"
    _check_exists(weights_path)
    # map_location lets weights saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()
    with torch.no_grad():
        # Take the single sample's logits. Indexing, unlike squeeze(),
        # never drops the class dimension.
        output = model(img.to(device))[0].cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).item()

    print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].item())
    plt.title(print_res)
    for i, prob in enumerate(predict):
        print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
                                                prob.item()))
    plt.show()


if __name__ == '__main__':
    main()
# import torch
# from model import AlexNet
# from PIL import Image
# from torchvision import transforms
# import matplotlib.pyplot as plt
# import json
# def main():
# data_transform = transforms.Compose([
# transforms.Resize((224,224)),
# transforms.ToTensor(),
# transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
# ])
#
# # Loading pictures
# img = Image.open('rose.jpg')
# plt.imshow(img)
# plt.show()
# img = data_transform(img)
# # Expand a dimension , add to batch dimension
# img = torch.unsqueeze(img,dim=0)
#
# # analysis json
# try:
# json_file = open('./class_indices.json','r')
# class_indict = json.load(json_file)
# except Exception as e:
# print(e)
# exit(-1)
#
# # Creating models
# model = AlexNet(num_classes=5)
#
# model_weight_path = './AlexNet.pth'
# model.load_state_dict(torch.load(model_weight_path))
# model.eval()
# with torch.no_grad():
# output = torch.squeeze(model(img))
# predict = torch.softmax(output,dim=0)
# predict_cla = torch.argmax(predict).numpy()
#
# print(class_indict[str(predict_cla)],predict[predict_cla].item())
# plt.show()
#
# if __name__=='__main__':
# main()
边栏推荐
猜你喜欢

MySQL optimization

push to origin/master was rejected 错误解决方法

MVC three-tier architecture

LeetCode——哈希表篇

Meeting OA my meeting

Deep learning of parameter adjustment skills

C and pointer Chapter 18 runtime environment 18.1 judgment of runtime environment

第1章 拦截器入门及使用技巧

机器人学台大林教授课程笔记

3 esp8266 nodemcu network server
随机推荐
Practice of data storage scheme in distributed system
13_ conditional rendering
LeetCode题目——数组篇
数据库:MySQL基础+CRUD基本操作
DHCP, VLAN, NAT, large comprehensive experiment
Complete backpack and 01 Backpack
SSRF (server side request forgery) -- Principle & bypass & Defense
The difference between SQL join and related subinquiry
Abstract classes and interfaces (sorting out some knowledge points)
20220720折腾deeplabcut2
Relationship between limit, continuity, partial derivative and total differential of multivariate function (learning notes)
Azure synapse analytics Performance Optimization Guide (3) -- optimize performance using materialized views (Part 2)
Chapter 1 Introduction and use skills of interceptors
Paging plug-in -- PageHelper
Deep learning of parameter adjustment skills
RecBole使用1
Deploy yolov5 error reporting in pycharm
14_ Basic list
[C language] classic recursion problem
Practice of intelligent code reconstruction of Zhongyuan bank