AlexNet (PyTorch Implementation)
2022-07-26 22:37:00 【Ap21ril】
1. model.py
import torch.nn as nn
import torch


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),   # input[3, 224, 224]  output[48, 55, 55]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                   # output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),            # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                   # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),           # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),           # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),           # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                   # output[128, 6, 6]
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        # before flatten: torch.Size([32, 128, 6, 6])
        # print(f'before flatten: {x.shape}')
        x = torch.flatten(x, start_dim=1)
        # print(f'after flatten: {x.shape}')
        # after flatten: torch.Size([32, 4608])
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
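
As a quick sanity check (a minimal sketch, not part of the original scripts), the shape comments above can be verified by pushing a dummy batch through the network:

# Illustrative shape check for model.py
import torch
from model import AlexNet

net = AlexNet(num_classes=5, init_weights=True)
dummy = torch.randn(32, 3, 224, 224)              # a batch of 32 fake RGB images
feats = net.features(dummy)
print(feats.shape)                                # torch.Size([32, 128, 6, 6])
print(torch.flatten(feats, start_dim=1).shape)    # torch.Size([32, 4608])
print(net(dummy).shape)                           # torch.Size([32, 5])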
2. train.py
import json
import os
import time

import torch
import torch.nn as nn
from torch import optim
from torchvision import datasets, transforms

from model import AlexNet


def main():
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Using {} device'.format(device))

    # Pre-processing / augmentation pipelines
    data_transform = {
        "train": transforms.Compose([
            # data augmentation
            transforms.RandomResizedCrop(224),   # randomly crop a 224x224 region from each training image
            transforms.RandomHorizontalFlip(),   # random horizontal flip
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]),
        "val": transforms.Compose([
            transforms.Resize((224, 224)),       # validation images must be resized to 224x224
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    }

    # Locate the dataset
    data_root = os.path.abspath(os.getcwd())                          # dataset root path
    print(data_root)
    image_path = os.path.join(data_root, 'data_set', 'flower_data')   # image directory
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, 'train'),
                                         transform=data_transform['train'])

    # Class-name-to-index mapping
    flower_list = train_dataset.class_to_idx
    # {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # {0: 'daisy', 1: 'dandelion', 2: 'roses', 3: 'sunflowers', 4: 'tulips'}
    # Write the index-to-class dictionary to a JSON file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    # Training set loader
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=0)
    # Validation set loader
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform['val'])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=4, shuffle=False,
                                                  num_workers=0)

    # Build the model
    net = AlexNet(num_classes=5, init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    save_path = './AlexNet.pth'
    best_acc = 0.0
    epochs = 10
    for epoch in range(epochs):
        # Training
        net.train()
        running_loss = 0.0
        t1 = time.perf_counter()
        for step, data in enumerate(train_loader, start=0):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Simple text progress bar
            rate = (step + 1) / len(train_loader)
            a = '*' * int(rate * 50)
            b = '.' * int((1 - rate) * 50)
            print('\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}'.format(int(rate * 100), a, b, loss), end="")
        print()
        print(time.perf_counter() - t1)

        # Validation
        net.eval()
        acc = 0.0
        with torch.no_grad():
            for data_test in validate_loader:
                val_images, val_labels = data_test
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += (predict_y == val_labels.to(device)).sum().item()
            acc_val = acc / val_num
            if acc_val > best_acc:
                best_acc = acc_val
                torch.save(net.state_dict(), save_path)
            print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
                  (epoch + 1, running_loss / len(train_loader), acc_val))


if __name__ == '__main__':
    main()
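
train.py assumes the flower dataset has already been arranged for datasets.ImageFolder, i.e. data_set/flower_data/train and data_set/flower_data/val each contain one sub-folder per class. A small sketch, based on that assumption, to verify the layout before training:

# Layout check for the dataset expected by train.py (illustrative only; the
# folder names are an assumption drawn from the paths used above).
import os

image_path = os.path.join(os.getcwd(), 'data_set', 'flower_data')
for split in ('train', 'val'):
    split_dir = os.path.join(image_path, split)
    classes = sorted(d for d in os.listdir(split_dir)
                     if os.path.isdir(os.path.join(split_dir, d)))
    counts = {c: len(os.listdir(os.path.join(split_dir, c))) for c in classes}
    print(split, counts)   # expected classes: daisy, dandelion, roses, sunflowers, tulips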
3. predict.py
import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import AlexNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "rose.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = AlexNet(num_classes=5).to(device)
    # load model weights
    weights_path = "./AlexNet.pth"
    assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()

    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

        print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                     predict[predict_cla].numpy())
        plt.title(print_res)
        for i in range(len(predict)):
            print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                      predict[i].numpy()))
        plt.show()


if __name__ == '__main__':
    main()
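
If you want the predictions ranked by probability rather than printed in index order, a small hypothetical helper built on torch.topk (not part of the original script) can be dropped in; it takes the softmax output and the class dictionary that main() already produces:

# Hypothetical helper: rank classes by predicted probability.
import torch

def topk_report(predict, class_indict, k=3):
    """Return the k most probable classes as (name, probability) pairs."""
    values, indices = torch.topk(predict, k=min(k, predict.numel()))
    return [(class_indict[str(i.item())], v.item()) for v, i in zip(values, indices)]

# Example with a made-up probability vector and the flower class mapping:
probs = torch.softmax(torch.tensor([1.0, 0.2, 3.1, 0.5, 0.1]), dim=0)
names = {'0': 'daisy', '1': 'dandelion', '2': 'roses', '3': 'sunflowers', '4': 'tulips'}
print(topk_report(probs, names))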