当前位置:网站首页>AlexNet (PyTorch Implementation)
AlexNet (PyTorch Implementation)
2022-07-27 00:17:00 【Ap21ril】
AlexNet (PyTorch Implementation)
1. model.py
import torch.nn as nn
import torch
class AlexNet(nn.Module):
    """AlexNet-style image classifier.

    The network is a stack of five convolutions (with three max-pooling
    stages) followed by a three-layer fully connected classifier.

    Args:
        num_classes: Number of output logits produced by the last linear layer.
        init_weights: When true, re-initialise conv and linear layers with
            ``_initialize_weights`` instead of keeping PyTorch's defaults.
    """

    def __init__(self, num_classes=1000, init_weights=False):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),   # input[3, 224, 224] output[48, 55, 55]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                   # output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),            # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                   # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),           # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),           # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),           # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                   # output[128, 6, 6]
        )
        # Parameter-free pooling that pins the feature map to 6x6 so the
        # classifier's fixed 128 * 6 * 6 input also works for inputs other
        # than 224x224. For a 6x6 map it is an exact identity, and it adds
        # no entries to the state dict, so existing checkpoints still load.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Return class logits of shape ``[N, num_classes]`` for a batch ``x``."""
        x = self.features(x)               # [N, 128, 6, 6] for 224x224 input
        x = self.avgpool(x)                # always [N, 128, 6, 6]
        x = torch.flatten(x, start_dim=1)  # [N, 4608]
        return self.classifier(x)

    def _initialize_weights(self):
        """Kaiming-normal init for convs, N(0, 0.01) for linears, zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
2. train.py
import json
import os
import time
import torch
import torch.nn as nn
import numpy as np
from matplotlib import pyplot as plt
from torch import optim
from torchvision import datasets,transforms,utils
from model import AlexNet
def main():
    """Train AlexNet on the flower dataset and keep the best checkpoint.

    Reads images from ``<cwd>/data_set/flower_data/train`` and ``.../val``,
    writes the index-to-class-name mapping to ``class_indices.json``, trains
    for a fixed number of epochs, and saves the weights of the epoch with the
    highest validation accuracy to ``./AlexNet.pth``.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Using {} device'.format(device))

    # Pre-processing pipelines for each split.
    data_transform = {
        "train": transforms.Compose([
            # Data augmentation
            transforms.RandomResizedCrop(224),   # random crop, resized to 224x224
            transforms.RandomHorizontalFlip(),   # random horizontal flip
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]),
        "val": transforms.Compose([
            transforms.Resize((224, 224)),       # validation images are resized to 224x224
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]),
    }

    # Locate the dataset relative to the current working directory.
    data_root = os.path.abspath(os.path.join(os.getcwd()))
    print(data_root)
    image_path = os.path.join(data_root, 'data_set', 'flower_data')
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, 'train'),
                                         transform=data_transform['train'])

    # class name -> index, e.g. {'daisy': 0, 'dandelion': 1, ...}
    flower_list = train_dataset.class_to_idx
    # index -> class name, e.g. {0: 'daisy', 1: 'dandelion', ...}
    cla_dict = {val: key for key, val in flower_list.items()}
    # Persist the mapping so the prediction script can translate indices.
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=0)
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, 'val'),
                                            transform=data_transform['val'])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=4, shuffle=False,
                                                  num_workers=0)

    # Build the model, loss and optimiser.
    net = AlexNet(num_classes=5, init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    save_path = './AlexNet.pth'
    best_acc = 0.0
    epochs = 10
    num_batches = len(train_loader)
    for epoch in range(epochs):
        # ---- training ----
        net.train()
        running_loss = 0.0
        t1 = time.perf_counter()
        for step, data in enumerate(train_loader, start=0):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            running_loss += loss_value

            # Single-line progress bar, redrawn in place via '\r'.
            rate = (step + 1) / num_batches
            a = '*' * int(rate * 50)
            b = '.' * int((1 - rate) * 50)
            print('\rtrain loss: :{:^3.0f}%[{}->{}]{:.3f}'.format(int(rate * 100), a, b, loss_value), end="")
        print()
        print(time.perf_counter() - t1)

        # ---- validation ----
        net.eval()
        correct = 0
        with torch.no_grad():
            for data_test in validate_loader:
                val_images, val_labels = data_test
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                correct += (predict_y == val_labels.to(device)).sum().item()
        acc_val = correct / val_num

        # Keep only the best-performing weights.
        if acc_val > best_acc:
            best_acc = acc_val
            torch.save(net.state_dict(), save_path)

        # Average over the number of batches. The zero-based last index
        # ``step`` would overstate the loss and divide by zero when the
        # loader yields a single batch.
        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
              (epoch + 1, running_loss / max(num_batches, 1), acc_val))
# Run training only when this file is executed as a script.
if __name__ == "__main__":
    main()
3. predict.py
import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model import AlexNet
def main():
    """Classify ``rose.jpg`` with the trained AlexNet and plot the result.

    Loads the class mapping from ``./class_indices.json`` and the weights
    from ``./AlexNet.pth``, prints the probability of every class, and shows
    the image titled with the most likely class.

    Raises:
        FileNotFoundError: If the image, the class-index JSON or the weights
            file is missing.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # Load the image. Explicit raises are used instead of ``assert`` so the
    # checks survive ``python -O``.
    img_path = "rose.jpg"
    if not os.path.exists(img_path):
        raise FileNotFoundError("file: '{}' does not exist.".format(img_path))
    img = Image.open(img_path)
    plt.imshow(img)

    # Convert to RGB so grayscale/RGBA files match the 3-channel Normalize.
    # The transform yields [C, H, W].
    img = data_transform(img.convert("RGB"))
    # Add the batch dimension: [1, C, H, W].
    img = torch.unsqueeze(img, dim=0)

    # Read the index -> class-name mapping (keys are strings in JSON).
    json_path = './class_indices.json'
    if not os.path.exists(json_path):
        raise FileNotFoundError("file: '{}' does not exist.".format(json_path))
    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # Create the model.
    model = AlexNet(num_classes=5).to(device)

    # Load the weights. ``map_location`` lets a checkpoint saved on a GPU be
    # restored on a CPU-only machine.
    weights_path = "./AlexNet.pth"
    if not os.path.exists(weights_path):
        raise FileNotFoundError("file: '{}' does not exist.".format(weights_path))
    model.load_state_dict(torch.load(weights_path, map_location=device))

    model.eval()
    with torch.no_grad():
        # Drop the batch dimension and turn logits into probabilities.
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = int(torch.argmax(predict))

    print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
                                               predict[predict_cla].item())
    plt.title(print_res)
    for i, prob in enumerate(predict):
        print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
                                                prob.item()))
    plt.show()
# Run prediction only when this file is executed as a script.
if __name__ == "__main__":
    main()
# import torch
# from model import AlexNet
# from PIL import Image
# from torchvision import transforms
# import matplotlib.pyplot as plt
# import json
# def main():
# data_transform = transforms.Compose([
# transforms.Resize((224,224)),
# transforms.ToTensor(),
# transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
# ])
#
# # Loading pictures
# img = Image.open('rose.jpg')
# plt.imshow(img)
# plt.show()
# img = data_transform(img)
# # Expand a dimension , add to batch dimension
# img = torch.unsqueeze(img,dim=0)
#
# # analysis json
# try:
# json_file = open('./class_indices.json','r')
# class_indict = json.load(json_file)
# except Exception as e:
# print(e)
# exit(-1)
#
# # Creating models
# model = AlexNet(num_classes=5)
#
# model_weight_path = './AlexNet.pth'
# model.load_state_dict(torch.load(model_weight_path))
# model.eval()
# with torch.no_grad():
# output = torch.squeeze(model(img))
# predict = torch.softmax(output,dim=0)
# predict_cla = torch.argmax(predict).numpy()
#
# print(class_indict[str(predict_cla)],predict[predict_cla].item())
# plt.show()
#
# if __name__=='__main__':
# main()
边栏推荐
- Force deduction 155 questions, minimum stack
- 卷积神经网络——LeNet(pytorch实现)
- 解析网页的完整回顾
- [netding Cup 2018] Fakebook records
- 第1章 需求分析与ssm环境准备
- MySQL数据库复杂操作:数据库约束,查询/连接表操作
- Practice of intelligent code reconstruction of Zhongyuan bank
- Deep learning of parameter adjustment skills
- Double. isNaN(double var)
- 爬虫解析网页的 对象.元素名方法
猜你喜欢
随机推荐
RecBole使用1
Transpose convolution correlation
yolov5在jetson nano上的部署 deepstream
[Gorm] model relationship -hasone
文件上传到OSS文件服务器
Design of electronic scale based on 51 single chip microcomputer
AlexNet(Pytorch实现)
Double. isNaN(double var)
14_ Basic list
LeetCode题目——数组篇
Several search terms
C and pointer Chapter 18 runtime environment 18.7 problems
Chapter 2 develop user traffic interceptors
LeetCode——哈希表篇
In depth interpretation of the investment logic of the consortium's participation in the privatization of Twitter
Dynamic memory management
push to origin/master was rejected 错误解决方法
深度学习调参技巧
第2章 开发用户流量拦截器
JUnit、JMockit、Mockito、PowerMockito





![[Gorm] model relationship -hasone](/img/90/3069059ddd09dc538c10f76d659b08.png)



