Classifying Your Own Dataset with a CNN in PyTorch
2022-07-07 15:40:00 【AI炮灰】
main.py
# -*- coding: utf-8 -*-
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from pathlib import Path
from model import CNN_MODEL
import torch.nn as nn
import torch.optim as optim
import cv2
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties

EPOCHS = 200

# Preprocessing: convert the OpenCV array to a PIL image, resize to 224x224,
# reduce to a single grayscale channel, then normalize to the [-1, 1] range.
train_transformser = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize([224, 224]),
    transforms.Grayscale(num_output_channels=1),
    # transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])
'''
# A separate test-time transform was left commented out in the original post;
# train_transformser above is applied to both splits instead.
test_transformser = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])
'''
def prepareData(dir_path):
    """Scan dir_path: each subdirectory is one class, each *.jpg inside it one sample."""
    dir_path = Path(dir_path)
    classes = [category.name for category in dir_path.iterdir() if category.is_dir()]
    images_list = []
    labels_list = []
    for index, name in enumerate(classes):
        class_path = dir_path / name
        for img_path in class_path.glob('*.jpg'):
            images_list.append(str(img_path))
            labels_list.append(index)
    return images_list, labels_list
class MyDataSet(Dataset):
    def __init__(self, dir_path):
        self.dir_path = dir_path
        self.images, self.labels = prepareData(self.dir_path)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = self.images[index]
        label = self.labels[index]
        img = cv2.imread(img_path)  # ndarray in BGR order; ToPILImage + Grayscale handle the conversion
        img = train_transformser(img)
        return {'image': img, 'label': label}
def train(model, criterion, optimizer, trainloader, traindataSetLen, testloader, testdataSetLen, epochs=EPOCHS):
    print('--------- Train Start ---------')
    train_loss_history = []
    test_loss_history = []
    train_acc_history = []
    test_acc_history = []
    for epoch in range(epochs):
        print('epoch: [%d]' % (epoch + 1))
        # ----- training pass -----
        model.train()
        train_running_loss = 0.0
        train_accuracy = 0.0
        for data in trainloader:
            img = data['image'].to(device)
            label = data['label'].to(device)
            optimizer.zero_grad()
            output = model(img)
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()
            _, pred = torch.max(output, 1)
            train_accuracy += (pred == label).sum().item()
            train_running_loss += loss.item()
        print('[%d] train loss: %.4f , train Accuracy: %.4f'
              % (epoch + 1, train_running_loss / traindataSetLen, train_accuracy / traindataSetLen))
        train_loss_history.append(train_running_loss / traindataSetLen)
        train_acc_history.append(train_accuracy / traindataSetLen)

        # ----- evaluation pass -----
        print('--------- Test Start ---------')
        model.eval()
        test_running_loss = 0.0
        test_accuracy = 0.0
        with torch.no_grad():  # no gradients needed while evaluating
            for data in testloader:
                img = data['image'].to(device)
                label = data['label'].to(device)
                output = model(img)
                loss = criterion(output, label)
                test_running_loss += loss.item()
                _, pred = torch.max(output, 1)
                test_accuracy += (pred == label).sum().item()
        print('[%d] Test loss: %.4f , Accuracy: %.4f'
              % (epoch + 1, test_running_loss / testdataSetLen, test_accuracy / testdataSetLen))
        test_loss_history.append(test_running_loss / testdataSetLen)
        test_acc_history.append(test_accuracy / testdataSetLen)
    print('----- Train Finished -----')
    return {
        'train_loss_history': train_loss_history,
        'test_loss_history': test_loss_history,
        'train_acc_history': train_acc_history,
        'test_acc_history': test_acc_history
    }
# SimHei font so the Chinese labels in the plots render correctly
font = FontProperties(fname="SimHei.ttf", size=14)
plt.rcParams['font.family'] = ['SimHei']

DIR_PATH = 'F:\\radarData\\test'

# GPU if available, otherwise CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

print('--------- Prepare Data Start---------')
# Load the data, then split it 80/20 into train and test sets
dataset = MyDataSet(DIR_PATH)
train_dataset_size = int(len(dataset) * 0.8)
test_dataset_size = len(dataset) - train_dataset_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_dataset_size, test_dataset_size])
train_dataset_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
test_dataset_loader = DataLoader(test_dataset, batch_size=2, shuffle=True)
print('--------- Prepare Data End---------')

model = CNN_MODEL().to(device)
learning_rate = 0.0001
# Cross-entropy loss (expects raw logits) with plain stochastic gradient descent
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

result = train(model, criterion, optimizer, train_dataset_loader, train_dataset_size,
               test_dataset_loader, test_dataset_size, epochs=EPOCHS)

plt.figure()
plt.plot(result['train_loss_history'], label='训练损失值')
plt.plot(result['test_loss_history'], label='测试损失值')
plt.xlabel('训练批次', fontsize=13, fontproperties=font)
plt.ylabel('损失值', fontsize=13, fontproperties=font)
plt.ylim(0, 1.2)
plt.title('训练与测试损失值', fontsize=13, fontproperties=font)
plt.legend(loc='upper right')
plt.savefig("./epoch_loss.png")
plt.show()

plt.figure()
plt.plot(result['train_acc_history'], label='训练准确率')
plt.plot(result['test_acc_history'], label='测试准确率')
plt.xlabel('训练批次', fontsize=13, fontproperties=font)
plt.ylabel('准确率', fontsize=13, fontproperties=font)
plt.ylim(0, 1.2)
plt.title('训练与测试准确率', fontsize=13, fontproperties=font)
plt.legend(loc='lower right')
plt.savefig("./epoch_acc.png")
plt.show()
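Note that the script trains and plots but never saves the weights. Below is a minimal follow-up sketch, not part of the original post, of how the trained model could be persisted and reused for single-image inference with the same preprocessing; the checkpoint name cnn_model.pth and the sample image path are assumptions.

# Hypothetical addition: persist the trained weights (file name is an assumption)
torch.save(model.state_dict(), 'cnn_model.pth')

# Reload them later for single-image inference on the CPU
infer_model = CNN_MODEL()
infer_model.load_state_dict(torch.load('cnn_model.pth', map_location='cpu'))
infer_model.eval()

img = cv2.imread('sample.jpg')                 # hypothetical test image
tensor = train_transformser(img).unsqueeze(0)  # same preprocessing, plus a batch dimension
with torch.no_grad():
    pred = infer_model(tensor).argmax(dim=1).item()
print('predicted class index:', pred)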
model.py
import torch.nn as nn

class CNN_MODEL(nn.Module):
    def __init__(self):
        super(CNN_MODEL, self).__init__()
        # Commonly used layers fall into convolution, pooling, activation,
        # recurrent, normalization, and loss-function layers.
        self.layer1 = nn.Sequential(
            # stride is the convolution step size
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=2, padding=1),
            # BatchNorm2d normalizes the activations so that large values
            # do not destabilize the network before ReLU
            nn.BatchNorm2d(32),
            # ReLU activation; inplace=True overwrites the input tensor to save memory
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)
        )
        # A 224x224 input shrinks to 3x3 after the three conv/pool blocks, hence 128*3*3 features
        self.layer4 = nn.Sequential(
            nn.Linear(in_features=128 * 3 * 3, out_features=2048),
            nn.Dropout(0.2),  # plain Dropout: these are 1-D feature vectors, not 2-D maps
            nn.ReLU(inplace=True)
        )
        self.layer5 = nn.Sequential(
            nn.Linear(in_features=2048, out_features=2048),
            nn.Dropout(0.2),
            nn.ReLU(inplace=True)
        )
        # No Softmax here: nn.CrossEntropyLoss in main.py expects raw logits
        self.layer6 = nn.Linear(in_features=2048, out_features=10)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 128*3*3)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        return x
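As a quick sanity check on the in_features value above (a sketch added here, not part of the original post), a dummy grayscale batch can be pushed through the three convolutional blocks to confirm the flattened size:

# Verify that a 1x224x224 input flattens to 128*3*3 = 1152 features
import torch
m = CNN_MODEL()
dummy = torch.zeros(1, 1, 224, 224)        # (batch, channels, height, width)
out = m.layer3(m.layer2(m.layer1(dummy)))
print(out.shape)                           # torch.Size([1, 128, 3, 3])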