当前位置:网站首页>机器学习之卷积神经网络使用cifar10数据集和alexnet网络模型训练分类模型,安装labelimg,以及报错ERROR
机器学习之卷积神经网络使用cifar10数据集和alexnet网络模型训练分类模型,安装labelimg,以及报错ERROR
2022-06-28 15:49:00 【华为云】
使用cifar10数据集和alexnet网络模型训练分类模型
下载cifar10数据集

代码:
import torchvisionimport torchtransform = torchvision.transforms.Compose( [torchvision.transforms.ToTensor(), torchvision.transforms.Resize(224)])train_set = torchvision.datasets.CIFAR10(root='./',download=False,train=True,transform=transform)test_set = torchvision.datasets.CIFAR10(root='./',download=False,train=False,transform=transform)train_loader = torch.utils.data.DataLoader(train_set,batch_size=8,shuffle=True)test_loader = torch.utils.data.DataLoader(test_set,batch_size=8,shuffle=True)class Alexnet(torch.nn.Module): #1080 2080 def __init__(self,num_classes=10): super(Alexnet,self).__init__() net = torchvision.models.alexnet(pretrained=False) #迁移学习 net.classifier = torch.nn.Sequential() self.features = net self.classifier = torch.nn.Sequential( torch.nn.Dropout(0.3), torch.nn.Linear(256 * 6 * 6, 4096), torch.nn.ReLU(inplace=True), torch.nn.Dropout(0.3), torch.nn.Linear(4096, 4096), torch.nn.ReLU(inplace=True), torch.nn.Linear(4096, num_classes), ) def forward(self,x): x = self.features(x) x = x.view(x.size(0),-1) x = self.classifier(x) return xdevice = torch.device('cpu')net = Alexnet().to(device)loss_func = torch.nn.CrossEntropyLoss().to(device)optim = torch.optim.Adam(net.parameters(),lr=0.001)net.train()for epoch in range(10): for step,(x,y) in enumerate(train_loader): # 28*28*1 32*32*3 x,y = x.to(device),y.to(device) output = net(x) loss = loss_func(output,y) optim.zero_grad() loss.backward() optim.step() print("epoch:",epoch,'loss:',loss)安装labelimg,以及报错
目标检测标注工具:labelimg
安装 pip install labelimg
使用 labelimg

报错
ERROR: spyder 4.1.4 requires pyqtwebengine<5.13; python_version >= “3”, which is not installed. ERROR: spyder 4.1.4 has requirement pyqt5<5.13; python_version >= “3”, but you’ll have pyqt5 5.15.6 which is incompatible
版本不匹配问题
打开Anaconda Prompt
使用命令安装Spyder
pip install spyder==4.1.4
或者
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple/ spyder==4.1.4

使用 labelimg
在安装环境下找到labelimg.exe复制到桌面
打开
打开一张图片

边栏推荐
猜你喜欢

首次失败后,爱美客第二次冲刺港交所上市,财务负责人变动频繁

C语言基础语法

【LeetCode】13、罗马数字转整数

Xinchuang operating system -- kylin kylin desktop operating system (project 10 security center)

Basic grammar of C language

3. caller service call - dapr

In depth learning foundation summary

MIPS assembly language learning-01-sum of two numbers, environment configuration and how to run

5 minutes to make a bouncing ball game

SaaS application management platform solution in the education industry: help enterprises realize the integration of operation and management
随机推荐
成功迁移到云端需要采取的步骤
看界面控件DevExpress WinForms如何创建一个虚拟键盘
openGauss内核:SQL解析过程分析
关于针对tron API签名广播时使用curl的json解析问题解决方案及针对json.loads方法的问题记录
In depth learning foundation summary
隐私计算 FATE - 离线预测
Realization of a springboard machine
国债与定期存款哪个更安全 两者之间有何区别
Flutter dart语言特点总结
Practice of curve replacing CEPH in Netease cloud music
Sample explanation of batch inserting data using MySQL bulkloader
24岁秃头程序员教你微服务交付下如何持续集成交付,学不会砍我
Qt 界面库
【推荐系统】多任务学习之ESMM模型(更新ing)
全球陆续拥抱Web3.0,多国已明确开始抢占先机
Flutter简单实现多语言国际化
扩充C盘(将D盘的内存分给C盘)
Basic grammar of C language
数组中的第K大元素[堆排 + 建堆的实际时间复杂度]
Go zero micro Service Practice Series (VII. How to optimize such a high demand)