Classifying Your Own Dataset with a CNN in PyTorch
2022-07-07 15:40:00 【AI炮灰】
main.py
# -*- coding: utf-8 -*-
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from pathlib import Path
from model import CNN_MODEL
import torch.nn as nn
import torch.optim as optim
import cv2
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
from PIL import Image  # only needed if the commented-out Image.open path is used
EPOCHS = 200

train_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize([224, 224]),
    transforms.Grayscale(num_output_channels=1),
    # transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])
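# A quick shape check for the pipeline above (illustrative sketch, not part of
# the original script). cv2.imread returns an HxWx3 BGR uint8 array; ToPILImage
# interprets it as RGB, which only slightly shifts the grayscale weighting:
#   import numpy as np
#   dummy = np.zeros((300, 400, 3), dtype=np.uint8)
#   print(train_transforms(dummy).shape)  # torch.Size([1, 224, 224])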
'''
test_transforms = transforms.Compose([
    # transforms.ToPILImage(),
    transforms.Resize([224, 224]),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])
'''
def prepareData(dir_path):
    """Scan dir_path/<class_name>/*.jpg and return image paths with integer labels."""
    dir_path = Path(dir_path)
    classes = []
    for category in dir_path.iterdir():
        if category.is_dir():
            classes.append(category.name)
    images_list = []
    labels_list = []
    for index, name in enumerate(classes):
        class_path = dir_path / name
        if not class_path.is_dir():
            continue
        for img_path in class_path.glob('*.jpg'):
            images_list.append(str(img_path))
            labels_list.append(int(index))
    return images_list, labels_list
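# prepareData assumes a directory layout of dir_path/<class_name>/*.jpg, with one
# subfolder per class mapped to an integer label. Illustrative usage (the path is
# just the example used further below):
#   imgs, labels = prepareData('F:\\radarData\\test')
#   print(len(imgs), sorted(set(labels)))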
class MyDataSet(Dataset):
    def __init__(self, dir_path):
        self.dir_path = dir_path
        self.images, self.labels = prepareData(self.dir_path)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = self.images[index]
        label = self.labels[index]
        # img = Image.open(img_path)
        img = cv2.imread(img_path)  # BGR ndarray; the transform converts it to grayscale
        img = train_transforms(img)
        sample = {'image': img, 'label': label}
        return sample
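# The default collate function batches these dict samples, so each DataLoader
# batch looks like {'image': FloatTensor[B, 1, 224, 224], 'label': LongTensor[B]}
# (illustrative sketch):
#   loader = DataLoader(MyDataSet('F:\\radarData\\test'), batch_size=2)
#   batch = next(iter(loader))
#   print(batch['image'].shape, batch['label'])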
def train(model, criterion, optimizer, trainloader, traindataSetLen, testloader, testdataSetLen, epochs=EPOCHS):
    print('--------- Train Start ---------')
    train_loss_history = []
    test_loss_history = []
    train_acc_history = []
    test_acc_history = []
    for epoch in range(epochs):
        print('epoch:[%d]' % epoch)
        model.train()
        train_running_loss = 0.0
        train_accuracy = 0.0
        for data in trainloader:
            img = data['image'].to(device)
            label = data['label'].to(device)
            output = model(img)
            optimizer.zero_grad()
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()
            _, pred = torch.max(output, 1)
            num_correct = (pred == label).sum()
            train_accuracy += num_correct.item()
            train_running_loss += loss.item()  # per-batch mean loss
        print('[%d] train loss: %.4f , train Accuracy: %.4f' % (epoch + 1, train_running_loss / traindataSetLen, train_accuracy / traindataSetLen))
        train_loss_history.append(train_running_loss / traindataSetLen)
        train_acc_history.append(train_accuracy / traindataSetLen)

        print('--------- Test Start ---------')
        model.eval()
        test_running_loss = 0.0
        test_accuracy = 0.0
        with torch.no_grad():  # no gradients needed during evaluation
            for data in testloader:
                img = data['image'].to(device)
                label = data['label'].to(device)
                output = model(img)
                loss = criterion(output, label)
                test_running_loss += loss.item()
                _, pred = torch.max(output, 1)
                num_correct = (pred == label).sum()
                test_accuracy += num_correct.item()
        print('[%d] Test loss: %.4f , Accuracy: %.4f' % (epoch + 1, test_running_loss / testdataSetLen, test_accuracy / testdataSetLen))
        test_loss_history.append(test_running_loss / testdataSetLen)
        test_acc_history.append(test_accuracy / testdataSetLen)
    print('----- Train Finished -----')
    return {
        'train_loss_history': train_loss_history,
        'test_loss_history': test_loss_history,
        'train_acc_history': train_acc_history,
        'test_acc_history': test_acc_history
    }
font = FontProperties(fname="SimHei.ttf", size=14)
plt.rcParams['font.family'] = ['SimHei']

DIR_PATH = 'F:\\radarData\\test'

# Use GPU if available, otherwise CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

print('--------- Prepare Data Start---------')
# Load the data
dataset = MyDataSet(DIR_PATH)
# Split into 80% train / 20% test
train_dataset_size = int(len(dataset) * 0.8)
test_dataset_size = len(dataset) - train_dataset_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_dataset_size, test_dataset_size])
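# For a reproducible split, random_split also accepts a seeded generator
# (optional sketch, not in the original):
#   train_dataset, test_dataset = torch.utils.data.random_split(
#       dataset, [train_dataset_size, test_dataset_size],
#       generator=torch.Generator().manual_seed(42))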
train_dataset_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
test_dataset_loader = DataLoader(test_dataset, batch_size=2, shuffle=True)
print('--------- Prepare Data End---------')

model = CNN_MODEL().to(device)
# Learning rate
learning_rate = 0.0001
# Loss function (expects raw logits)
criterion = nn.CrossEntropyLoss()
# Stochastic gradient descent
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

result = train(model, criterion, optimizer, train_dataset_loader, train_dataset_size, test_dataset_loader, test_dataset_size, epochs=EPOCHS)
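# (Optional sketch, not in the original) persist the trained weights for later reuse:
#   torch.save(model.state_dict(), 'cnn_model.pth')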
plt.figure()
plt.plot(result['train_loss_history'], label='Train loss')
plt.plot(result['test_loss_history'], label='Test loss')
plt.xlabel('Epoch', fontsize=13, fontproperties=font)
plt.ylabel('Loss', fontsize=13, fontproperties=font)
plt.ylim(0, 1.2)
plt.title('Train and Test Loss', fontsize=13, fontproperties=font)
plt.legend(loc='upper right')
plt.savefig("./epoch_loss.png")
plt.show()

plt.figure()
plt.plot(result['train_acc_history'], label='Train accuracy')
plt.plot(result['test_acc_history'], label='Test accuracy')
plt.xlabel('Epoch', fontsize=13, fontproperties=font)
plt.ylabel('Accuracy', fontsize=13, fontproperties=font)
plt.ylim(0, 1.2)
plt.title('Train and Test Accuracy', fontsize=13, fontproperties=font)
plt.legend(loc='lower right')
plt.savefig("./epoch_acc.png")
plt.show()
model.py
import torch.nn as nn

class CNN_MODEL(nn.Module):
    def __init__(self):
        super(CNN_MODEL, self).__init__()
        # Common layer types: convolution, pooling, activation, recurrent, normalization, loss
        self.layer1 = nn.Sequential(
            # stride is the convolution step size
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=2, padding=1),
            # BatchNorm2d normalizes the activations so that large values do not
            # destabilize the network before the ReLU
            nn.BatchNorm2d(32),
            # ReLU activation; inplace=True overwrites the input tensor with the output
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)
        )
        self.layer4 = nn.Sequential(
            nn.Linear(in_features=128 * 3 * 3, out_features=2048),
            nn.Dropout(0.2),  # nn.Dropout (not Dropout2d) for fully connected features
            nn.ReLU(inplace=True)
        )
        self.layer5 = nn.Sequential(
            nn.Linear(in_features=2048, out_features=2048),
            nn.Dropout(0.2),
            nn.ReLU(inplace=True)
        )
        # No Softmax here: nn.CrossEntropyLoss in main.py expects raw logits
        # and applies log-softmax internally
        self.layer6 = nn.Linear(in_features=2048, out_features=10)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        # print(x.size())
        x = x.view(x.size(0), -1)  # flatten to (batch, 128*3*3)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        return x
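To confirm that layer4's in_features=128*3*3 matches a 224×224 grayscale input (each of the three conv blocks halves the spatial size twice: 224 → 112 → 56 → 28 → 14 → 7 → 3), a minimal sanity check with a dummy batch can be run as a separate script:

import torch
from model import CNN_MODEL

model = CNN_MODEL()
dummy = torch.randn(1, 1, 224, 224)  # one grayscale 224x224 image
out = model(dummy)
print(out.shape)  # expected: torch.Size([1, 10])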