当前位置:网站首页>基于PyTorch利用CNN对自己的数据集进行分类
基于PyTorch利用CNN对自己的数据集进行分类
2022-07-07 15:40:00 【AI炮灰】
main.py文件
#-*- coding: utf-8 -*-
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
from pathlib import Path
from model import CNN_MODEL
import torch.nn as nn
import torch.optim as optim
import cv2
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
import os
from PIL import Image
# Number of passes over the training data; also the default for train(epochs=...).
epoch = 200

# Preprocessing pipeline applied to every loaded image: convert the array to a
# PIL image, resize to 224x224, collapse to a single grayscale channel, turn it
# into a tensor and normalise with mean 0.5 / std 0.5.
train_transformser = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(size=[224, 224]),
        transforms.Grayscale(1),
        # transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(0.5, 0.5),
    ]
)

# Disabled evaluation-time pipeline, kept as an inert string literal.
'''
test_transformser = transforms.Compose([
#transforms.ToPILImage(),
transforms.Resize([224,224]),
transforms.Grayscale(num_output_channels=1),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5),std=(0.5))
])
'''
def prepareData(dir_path):
    """Collect image paths and integer labels from a one-folder-per-class tree.

    Every immediate sub-directory of ``dir_path`` is treated as one class, and
    every ``*.jpg`` file directly inside it as one sample of that class. Files
    placed directly in ``dir_path`` and files with other extensions are
    ignored.

    Args:
        dir_path: Root directory of the dataset (``str`` or ``Path``).

    Returns:
        A tuple ``(images_list, labels_list)`` of equal length. The first
        holds image paths as strings; the second holds, for each image, the
        index of its class in the alphabetically sorted list of class folder
        names.
    """
    dir_path = Path(dir_path)

    # iterdir() yields entries in arbitrary, filesystem-dependent order. Sort
    # the class names so a given folder always maps to the same label index.
    classes = sorted(
        entry.name for entry in dir_path.iterdir() if entry.is_dir()
    )

    images_list = []
    labels_list = []
    for index, name in enumerate(classes):
        # Sorted as well, so the sample order is reproducible between runs.
        for img_path in sorted((dir_path / name).glob('*.jpg')):
            images_list.append(str(img_path))
            labels_list.append(index)
    return images_list, labels_list
class MyDataSet(Dataset):
    """Map-style dataset over images stored one folder per class.

    Paths and labels are gathered once at construction time by
    ``prepareData``; the image files themselves are read on each
    ``__getitem__`` call.
    """

    def __init__(self, dir_path, transform=None):
        """Index the images found under ``dir_path``.

        Args:
            dir_path: Root directory containing one sub-directory per class.
            transform: Optional callable applied to each loaded RGB image
                array. Defaults to the module-level ``train_transformser``.
        """
        self.dir_path = dir_path
        # Resolved here rather than in the signature so the module-level
        # pipeline is only looked up when a dataset is actually built.
        self.transform = train_transformser if transform is None else transform
        self.images, self.labels = prepareData(self.dir_path)

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.images)

    def __getitem__(self, index):
        """Load, preprocess and return the sample at ``index``.

        Returns:
            A dict with keys ``'image'`` (the transformed image) and
            ``'label'`` (the integer class index).

        Raises:
            OSError: If the image file cannot be read or decoded.
        """
        img_path = self.images[index]
        label = self.labels[index]

        img = cv2.imread(img_path)
        if img is None:
            # imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than deep inside the
            # transform pipeline.
            raise OSError('Failed to read image: %s' % img_path)

        # OpenCV decodes to BGR, while the PIL-based transforms assume RGB.
        # Without this swap the grayscale conversion weights the red and blue
        # channels the wrong way round.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        img = self.transform(img)
        return {'image': img, 'label': label}
def train(model, criterion, optimizer, trainloader, traindataSetLen, testloader, testdataSetLen, epochs=epoch, log_interval=50, learning_rate=0.001):
    """Train ``model`` and evaluate it on the test loader after every epoch.

    Both loaders must yield dict batches with an ``'image'`` tensor and a
    ``'label'`` tensor. Batches are moved to the device of the model's first
    parameter before the forward pass.

    Args:
        model: The ``nn.Module`` to optimise.
        criterion: Loss callable taking ``(output, label)``.
        optimizer: Optimiser already bound to ``model``'s parameters.
        trainloader: Iterable of training batches.
        traindataSetLen: Number of training samples; used as the denominator
            for the reported training loss and accuracy.
        testloader: Iterable of evaluation batches.
        testdataSetLen: Number of evaluation samples; used as the denominator
            for the reported test loss and accuracy.
        epochs: Number of passes over ``trainloader``.
        log_interval: Unused; kept for backward compatibility.
        learning_rate: Unused; the step size is whatever ``optimizer`` was
            constructed with. Kept for backward compatibility.

    Returns:
        A dict with the keys ``'train_loss_history'``, ``'test_loss_history'``,
        ``'train_acc_history'`` and ``'test_acc_history'``, each a list with
        one float per epoch.
    """
    print('--------- Train Start ---------')

    # Follow the model's device so a model moved to the GPU by the caller
    # does not crash on CPU batches. Parameter-less models fall back to CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    train_loss_history = []
    test_loss_history = []
    train_acc_history = []
    test_acc_history = []

    # NOTE: the running losses below are sums of per-batch loss values, yet
    # they are divided by the number of *samples*. With a mean-reduced
    # criterion the reported figure is therefore roughly the mean loss divided
    # by the batch size. This scaling is kept so existing plots and logs stay
    # comparable.
    for epoch_idx in range(epochs):
        print('epoch:[%d]' % epoch_idx)

        model.train()
        train_running_loss = 0.0
        train_correct = 0.0
        for data in trainloader:
            img = data['image'].to(device)
            label = data['label'].to(device)

            optimizer.zero_grad()
            output = model(img)
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()

            # Predicted class = index of the largest output along dim 1.
            _, pred = torch.max(output, 1)
            train_correct += (pred == label).sum().item()
            train_running_loss += loss.item()

        print('[%d] train loss: %.4f , train Accuracy: %.4f' % (epoch_idx + 1, train_running_loss / traindataSetLen, train_correct / traindataSetLen))
        train_loss_history.append(train_running_loss / traindataSetLen)
        train_acc_history.append(train_correct / traindataSetLen)

        print('--------- Test Start ---------')
        model.eval()
        test_running_loss = 0.0
        test_correct = 0.0
        # No gradients are needed for evaluation; skipping graph construction
        # saves memory and time.
        with torch.no_grad():
            for data in testloader:
                img = data['image'].to(device)
                label = data['label'].to(device)
                output = model(img)
                loss = criterion(output, label)
                test_running_loss += loss.item()
                _, pred = torch.max(output, 1)
                test_correct += (pred == label).sum().item()

        print('[%d] Test loss: %.4f , Accuracy: %.4f' % (epoch_idx + 1, test_running_loss / testdataSetLen, test_correct / testdataSetLen))
        test_loss_history.append(test_running_loss / testdataSetLen)
        test_acc_history.append(test_correct / testdataSetLen)

    print('----- Train Finished -----')
    return {
        'train_loss_history': train_loss_history,
        'test_loss_history': test_loss_history,
        'train_acc_history': train_acc_history,
        'test_acc_history': test_acc_history,
    }
# Font for the Chinese plot text, loaded from SimHei.ttf relative to the
# working directory.
font = FontProperties(fname="SimHei.ttf", size=14)
plt.rcParams['font.family'] = ['SimHei']

DIR_PATH = 'F:\\radarData\\test'

# Report which device is available. The model is not moved to it here, so
# training runs wherever CNN_MODEL() places its parameters by default.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

print('--------- Prepare Data Start---------')
# Load the data.
dataset = MyDataSet(DIR_PATH)
# Split 80% / 20% into training and test subsets.
train_dataset_size = int(len(dataset) * 0.8)
test_dataset_size = len(dataset) - train_dataset_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_dataset_size, test_dataset_size])
train_dataset_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
# Evaluation does not depend on sample order, so the test loader is not shuffled.
test_dataset_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)
print('--------- Prepare Data End---------')

model = CNN_MODEL()
# Learning rate used by the optimiser.
learning_rate = 0.0001
# Loss function.
criterion = nn.CrossEntropyLoss()
# Stochastic gradient descent.
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

# Pass the same learning rate the optimiser was built with, instead of an
# unrelated literal, so the call does not advertise a value that is not used.
result = train(model, criterion, optimizer, train_dataset_loader, train_dataset_size, test_dataset_loader, test_dataset_size, epochs=epoch, log_interval=50, learning_rate=learning_rate)

# Loss curves, one point per epoch.
plt.figure()
plt.plot(result['train_loss_history'], label='训练损失值')
plt.plot(result['test_loss_history'], label='测试损失值')
plt.xlabel('训练批次', fontsize=13, fontproperties=font)
plt.ylabel('损失值', fontsize=13, fontproperties=font)
plt.ylim(0, 1.2)
plt.title('训练与测试损失值', fontsize=13, fontproperties=font)
# The legend labels are Chinese too, so render them with the same explicit
# font as the axis labels and title.
plt.legend(loc='upper right', prop=font)
plt.savefig("./epoch_loss.png")
plt.show()

# Accuracy curves, one point per epoch.
plt.figure()
plt.plot(result['train_acc_history'], label='训练准确率')
plt.plot(result['test_acc_history'], label='测试准确率')
plt.xlabel('训练批次', fontsize=13, fontproperties=font)
plt.ylabel('准确率', fontsize=13, fontproperties=font)
plt.ylim(0, 1.2)
plt.title('训练与测试准确率', fontsize=13, fontproperties=font)
plt.legend(loc='lower right', prop=font)
plt.savefig("./epoch_acc.png")
plt.show()
model.py文件
import torch.nn as nn
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchsummary import summary
class CNN_MODEL(nn.Module):
    """Small convolutional classifier for single-channel 224x224 images.

    Three convolutional stages (stride-2 3x3 convolution, batch norm, ReLU,
    2x2 max-pool) reduce a 1x224x224 input to 128x3x3. Three fully connected
    layers then map the flattened features to ``num_classes`` scores.

    ``forward`` returns raw, unnormalised scores (logits). That is the input
    ``nn.CrossEntropyLoss`` expects, since it applies log-softmax internally.
    Apply ``torch.softmax(output, dim=1)`` yourself when probabilities are
    needed at inference time.

    Args:
        num_classes: Number of output classes. Defaults to 10.
    """

    def __init__(self, num_classes=10):
        super(CNN_MODEL, self).__init__()
        # Common layer kinds: convolution, pooling, activation, recurrent,
        # normalisation and loss layers.
        self.layer1 = nn.Sequential(
            # stride is the convolution step size.
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=2, padding=1),
            # BatchNorm2d normalises the activations so that overly large
            # values do not destabilise the network before the ReLU.
            nn.BatchNorm2d(32),
            # With inplace=True, ReLU overwrites its input with the output.
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)
        )
        # The linear layers produce 2-D (batch, features) tensors, so plain
        # element-wise Dropout is used here. Dropout2d is meant for feature
        # maps, and feeding it 2-D input is deprecated.
        self.layer4 = nn.Sequential(
            # 128 channels * 3 * 3 spatial positions for a 224x224 input.
            nn.Linear(in_features=128*3*3, out_features=2048),
            nn.Dropout(0.2),
            nn.ReLU(inplace=True)
        )
        self.layer5 = nn.Sequential(
            nn.Linear(in_features=2048, out_features=2048),
            nn.Dropout(0.2),
            nn.ReLU(inplace=True)
        )
        # No Softmax here: the loss applies log-softmax itself, so a softmax
        # in the model would be applied twice and flatten the gradients.
        self.layer6 = nn.Sequential(
            nn.Linear(in_features=2048, out_features=num_classes)
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for input ``x``."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        # Flatten everything except the batch dimension for the linear layers.
        x = x.view(x.size(0), -1)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        return x
边栏推荐
猜你喜欢
策略模式 - Unity
mysql官网下载:Linux的mysql8.x版本(图文详解)
How to choose the appropriate automated testing tools?
Examen des lois et règlements sur la sécurité de l'information
【可信计算】第十二次课:TPM授权与会话
Please insert the disk into "U disk (H)" & unable to access the disk structure is damaged and cannot be read
百度地图自定义样式向右拖拽导致全球地图经度0度无法正常显示
User defined view essential knowledge, Android R & D post must ask 30+ advanced interview questions
第3章业务功能开发(用户登录)
【网络攻防原理与技术】第1章:绪论
随机推荐
Toast will display a simple prompt message on the program interface
L1-025 正整数A+B(Lua)
百度地图自定义样式向右拖拽导致全球地图经度0度无法正常显示
【OKR目标管理】价值分析
【TPM2.0原理及应用指南】 5、7、8章
Several best practices for managing VDI
DatePickerDialog and TimePickerDialog
The computer cannot add a domain, and the Ping domain name is displayed as the public IP. What is the problem? How to solve it?
The mail server is listed in the blacklist. How to unblock it quickly?
【饭谈】Web3.0到来后,测试人员该何去何从?(十条预言和建议)
【网络攻防原理与技术】第4章:网络扫描技术
无法链接远程redis服务器(解决办法百分百)
Functions and usage of tabhost tab
在窗口上面显示进度条
[fan Tan] after the arrival of Web3.0, where should testers go? (ten predictions and suggestions)
What is cloud computing?
Functions and usage of SearchView
Create dialog style windows with popupwindow
【TPM2.0原理及应用指南】 1-3章
Function and usage of textswitch text switcher