当前位置:网站首页>Pytorch学习笔记(一)安装与常用函数的使用
Pytorch学习笔记(一)安装与常用函数的使用
2022-07-26 13:29:00 【小胡今天有变强吗】
文章目录
安装Pytorch
在Anaconda环境中创建pytorch环境
conda create -n pytorch python=3.6
激活环境
conda activate pytorch
查看包列表
pip list
pytorch官网:
https://pytorch.org/
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-wMEEx9Ga-1658802973736)(C:\Users\Husheng\Desktop\学习笔记\image-20220721105827500.png)]](/img/ef/c15f4af71b02b2c566e404cf82fad5.png)
安装命令:
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
出现True说明pytorch可以使用GPU。
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-guSm5J4A-1658802973737)(C:\Users\Husheng\Desktop\学习笔记\image-20220712220003299.png)]](/img/0c/57e8f43dcac3bf681b7f4935f3faed.png)
在pytorch中安装jupyter
conda install nb_conda
启动jupyter
jupyter notebook
shift + 回车,表示跳转到另一个代码块,并且运行上一个代码块。
两个常用函数
(1)dir() 工具箱以及工具箱的分隔区中有什么东西;
(2)help() 每个工具是如何使用的,工具的使用方法。
三种方式编码的区别
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-6HjdTzgI-1658802973737)(C:\Users\Husheng\Desktop\学习笔记\image-20220721131337757.png)]](/img/40/b9087f6de8df6b3111626c4d09384f.png)
Dataset实战
from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir)
self.img_path = os.listdir(self.path)
def __getitem__(self, idx):
img_name = self.img_path[idx]
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
img = Image.open(img_item_path)
label = self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-MiOH3EHJ-1658802973738)(C:\Users\Husheng\Desktop\学习笔记\image-20220721170630153.png)]](/img/3a/3dcd63471b3fcbbf73e5b9b8622a66.png)
TensorBoard的使用
安装tensorboard
pip install tensorboard
报错:
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-TVY5sptY-1658802973738)(C:\Users\Husheng\Desktop\学习笔记\image-20220721172628638.png)]](/img/3f/aaad2046a1936e4b69314f2fe2dd14.png)
解决:使用管理员身份运行Anaconda Prompt,重新安装。
安装运行报错:![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-maB8KXuL-1658802973738)(C:\Users\Husheng\Desktop\学习笔记\image-20220721173651485.png)]](/img/5d/0d11edeb5933b710730703ac287d6d.png)
解决:setuptools版本有问题,将其卸载重装。
pip uninstall setuptools
pip install setuptools==59.5.0
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-uGXu9wRs-1658802973738)(C:\Users\Husheng\Desktop\学习笔记\image-20220721174213013.png)]](/img/0a/7bc6decb7f61844e85f293d3e47c8d.png)
–logdir=logs中间不能加空格
指定端口:
tensorboard --logdir=logs --port=6007
如果进行绘制了y = 2x,又绘制了3x,会自动拟合图像,解决方法是删除logs下的所有文件,重新运行代码,重新运行tensorboard --logdir=logs --port=6007
显示图片:
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from PIL import Image
writer = SummaryWriter("logs")
img_path = "data/train/ants_image/0013035.jpg"
img_PIL = Image.open(img_path)
img_array = np.array(img_PIL)
print(type(img_array))
print(img_array.shape)
writer.add_image("train", img_array, 1, dataformats='HWC')
# y = x
for i in range(100):
writer.add_scalar("y = 2x", 3*i, i)
writer.close()
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-cFyyh1MN-1658802973739)(C:\Users\Husheng\Desktop\学习笔记\image-20220721225136606.png)]](/img/f3/b97661812068a688549a8f1d288910.png)
Transforms

安装opencv失败:![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-I5LKEQJT-1658802973743)(C:\Users\Husheng\Desktop\学习笔记\image-20220721232654305.png)]](/img/2b/2257ea19f20b10a825cecda170aa90.png)
这个错误也解决了好久,试了好几种方法(上一篇博客有专门解决了这个问题:https://blog.csdn.net/hshudoudou/article/details/125930549?spm=1001.2014.3001.5502),最后解决方法是指定版本安装:
代码:
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
img_path = "dataset/train/ants/0013035.jpg"
img = Image.open(img_path)
# print(img)
writer = SummaryWriter("logs")
# 1. transforms如何使用
tensor_trans = transforms.ToTensor()
tensor_img = tensor_trans(img)
writer.add_image("Tensor_img", tensor_img)
writer.close()
常见的Transforms

from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
writer = SummaryWriter("logs")
img = Image.open("imgages/blue.jpg")
print(img)
#ToTensor的使用
trans_totensor = transforms.ToTensor()
img_tensor = trans_totensor(img)
writer.add_image("ToTensor", img_tensor)
#Normalize归一化
print(img_tensor[0][0][0])
trans_norm = transforms.Normalize([6, 3, 2], [9, 3, 5])
img_norm = trans_norm(img_tensor)
print(img_norm[0][0][0])
writer.add_image("Normalize", img_norm, 2)
#Resize
print(img.size)
trans_resize = transforms.Resize((512, 512))
# img PIL -> resize -> img_size PIL
img_resize = trans_resize(img)
# img_resize PIL -> totensor -> img_resize tensor
img_resize = trans_totensor(img_resize)
writer.add_image("Resize", img_resize, 0)
print(img_resize)
# Compose - resize - 2
trans_resize_2 = transforms.Resize(512)
# PIL -> PIL -> tensor
trans_compose = transforms.Compose([trans_resize_2, trans_totensor])
img_resize_2 = trans_compose(img)
writer.add_image("Resize", img_resize_2, 2)
# RandomCrop
trans_random = transforms.RandomCrop((500, 1000))
trans_compose_2 = transforms.Compose([trans_random, trans_totensor])
for i in range(10):
img_crop = trans_compose_2(img)
writer.add_image("RandomCropHW", img_crop, i)
writer.close()

使用transforms的注意点:
关注输入和输出,多看官方文档,关注方法需要哪些参数。
torchvision中的数据集的使用

import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)
# print(test_set[0])
# print(test_set.classes)
# img, target = test_set[0]
# print(img)
# print(target)
# print(test_set.classes[target])
# img.show()
#
# print(test_set[0])
writer = SummaryWriter("p10")
for i in range(10):
img, target = test_set[i]
writer.add_image("test_set", img, i)
writer.close()
DataLoader的使用
import torchvision
from torch.utils.data import DataLoader
# 准备的测试数据集
from torch.utils.tensorboard import SummaryWriter
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
# 测试数据集中第一张图片及target
img, target = test_data[0]
print(img.shape)
print(target)
writer = SummaryWriter("dataloader")
# step = 0
# for data in test_loader:
# imgs, targets = data
# # print(imgs.shape)
# # print(targets)
# writer.add_images("test_data", imgs, step)
# step += 1
for epoch in range(2):
step = 0
for data in test_loader:
imgs, targets = data
# print(imgs.shape)
# print(targets)
writer.add_images("Epoch:{}".format(epoch), imgs, step)
step += 1
writer.close()
dataloader相当于如何从dataset中取出数据,dataset:数据集,batch_size:每一组的最大数量,shuffle:是否打乱,num_workers:线程数量,drop_last:是否舍弃最后不被整除的图片。
参考资料
边栏推荐
- ROS2学习(1)ROS2简述
- 官宣!艾德韦宣集团与百度希壤达成深度共创合作
- Click El dropdown item/@click.native
- Golang端口扫描设计
- Feixin, which lasted 15 years and had 500million users, was completely dead
- LeetCode 217. 存在重复元素
- 3D modeling and rendering based on B é zier curve
- Implementation of SAP ABAP daemon
- 白帽子揭秘:互联网千亿黑产吓退马斯克
- LeetCode 2119. 反转两次的数字
猜你喜欢

Solution 5g technology helps build smart Parks

El table implements editable table

Can I take your subdomain? Exploring Same-Site Attacks in the Modern Web

12 brand management of commodity system in gulimall background management

Feixin, which lasted 15 years and had 500million users, was completely dead

Comparison between SIGMOD function and softmax function

Dimension disaster dimension disaster suspense

Photoshop(CC2020)未完

Niuke brush sql---2

云智技术论坛工业专场 明天见!
随机推荐
Concept and handling of exceptions
421. 数组中两个数的最大异或值
This article explains the FS file module and path module in nodejs in detail
key&key_ Len & ref & filtered (4) - MySQL execution plan (50)
B+树索引使用(7)匹配列前缀,匹配值范围(十九)
How to write the introduction of GIS method journals and papers?
Oom caused by improper use of multithreading
Comparison between SIGMOD function and softmax function
如何构建以客户为中心的产品蓝图:来自首席技术官的建议
Codeforces round 810 (Div. 2) [competition record]
Team research and development from ants' foraging process (Reprint)
MySQL data directory (3) -- table data structure MyISAM (XXVI)
终极套娃 2.0 | 云原生交付的封装
Tianjin emergency response Bureau and central enterprises in Tianjin signed an agreement to deepen the construction of emergency linkage mechanism
Learn about Pinia state getters actions plugins
panic: Error 1045: Access denied for user ‘root‘@‘117.61.242.215‘ (using password: YES)
Probability theory and mathematical statistics
官宣!艾德韦宣集团与百度希壤达成深度共创合作
冒泡排序的时间复杂度分析
Exploration on cache design optimization of community like business