当前位置:网站首页>Pytorch学习笔记(一)安装与常用函数的使用
Pytorch学习笔记(一)安装与常用函数的使用
2022-07-26 13:29:00 【小胡今天有变强吗】
文章目录
安装Pytorch
在Anaconda环境中创建pytorch环境
conda create -n pytorch python=3.6
激活环境
conda activate pytorch
查看包列表
pip list
pytorch官网:
https://pytorch.org/
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-wMEEx9Ga-1658802973736)(C:\Users\Husheng\Desktop\学习笔记\image-20220721105827500.png)]](/img/ef/c15f4af71b02b2c566e404cf82fad5.png)
安装命令:
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
出现True说明pytorch可以使用GPU。
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-guSm5J4A-1658802973737)(C:\Users\Husheng\Desktop\学习笔记\image-20220712220003299.png)]](/img/0c/57e8f43dcac3bf681b7f4935f3faed.png)
在pytorch中安装jupyter
conda install nb_conda
启动jupyter
jupyter notebook
shift + 回车,表示跳转到另一个代码块,并且运行上一个代码块。
两个常用函数
(1)dir() 工具箱以及工具箱的分隔区中有什么东西;
(2)help() 每个工具是如何使用的,工具的使用方法。
三种方式编码的区别
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-6HjdTzgI-1658802973737)(C:\Users\Husheng\Desktop\学习笔记\image-20220721131337757.png)]](/img/40/b9087f6de8df6b3111626c4d09384f.png)
Dataset实战
from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir)
self.img_path = os.listdir(self.path)
def __getitem__(self, idx):
img_name = self.img_path[idx]
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
img = Image.open(img_item_path)
label = self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-MiOH3EHJ-1658802973738)(C:\Users\Husheng\Desktop\学习笔记\image-20220721170630153.png)]](/img/3a/3dcd63471b3fcbbf73e5b9b8622a66.png)
TensorBoard的使用
安装tensorboard
pip install tensorboard
报错:
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-TVY5sptY-1658802973738)(C:\Users\Husheng\Desktop\学习笔记\image-20220721172628638.png)]](/img/3f/aaad2046a1936e4b69314f2fe2dd14.png)
解决:使用管理员身份运行Anaconda Prompt,重新安装。
安装运行报错:![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-maB8KXuL-1658802973738)(C:\Users\Husheng\Desktop\学习笔记\image-20220721173651485.png)]](/img/5d/0d11edeb5933b710730703ac287d6d.png)
解决:setuptools版本有问题,将其卸载重装。
pip uninstall setuptools
pip install setuptools==59.5.0
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-uGXu9wRs-1658802973738)(C:\Users\Husheng\Desktop\学习笔记\image-20220721174213013.png)]](/img/0a/7bc6decb7f61844e85f293d3e47c8d.png)
–logdir=logs中间不能加空格
指定端口:
tensorboard --logdir=logs --port=6007
如果进行绘制了y = 2x,又绘制了3x,会自动拟合图像,解决方法是删除logs下的所有文件,重新运行代码,重新运行tensorboard --logdir=logs --port=6007
显示图片:
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from PIL import Image
writer = SummaryWriter("logs")
img_path = "data/train/ants_image/0013035.jpg"
img_PIL = Image.open(img_path)
img_array = np.array(img_PIL)
print(type(img_array))
print(img_array.shape)
writer.add_image("train", img_array, 1, dataformats='HWC')
# y = x
for i in range(100):
writer.add_scalar("y = 2x", 3*i, i)
writer.close()
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-cFyyh1MN-1658802973739)(C:\Users\Husheng\Desktop\学习笔记\image-20220721225136606.png)]](/img/f3/b97661812068a688549a8f1d288910.png)
Transforms

安装opencv失败:![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-I5LKEQJT-1658802973743)(C:\Users\Husheng\Desktop\学习笔记\image-20220721232654305.png)]](/img/2b/2257ea19f20b10a825cecda170aa90.png)
这个错误也解决了好久,试了好几种方法(上一篇博客有专门解决了这个问题:https://blog.csdn.net/hshudoudou/article/details/125930549?spm=1001.2014.3001.5502),最后解决方法是指定版本安装:
代码:
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
img_path = "dataset/train/ants/0013035.jpg"
img = Image.open(img_path)
# print(img)
writer = SummaryWriter("logs")
# 1. transforms如何使用
tensor_trans = transforms.ToTensor()
tensor_img = tensor_trans(img)
writer.add_image("Tensor_img", tensor_img)
writer.close()
常见的Transforms

from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
writer = SummaryWriter("logs")
img = Image.open("imgages/blue.jpg")
print(img)
#ToTensor的使用
trans_totensor = transforms.ToTensor()
img_tensor = trans_totensor(img)
writer.add_image("ToTensor", img_tensor)
#Normalize归一化
print(img_tensor[0][0][0])
trans_norm = transforms.Normalize([6, 3, 2], [9, 3, 5])
img_norm = trans_norm(img_tensor)
print(img_norm[0][0][0])
writer.add_image("Normalize", img_norm, 2)
#Resize
print(img.size)
trans_resize = transforms.Resize((512, 512))
# img PIL -> resize -> img_size PIL
img_resize = trans_resize(img)
# img_resize PIL -> totensor -> img_resize tensor
img_resize = trans_totensor(img_resize)
writer.add_image("Resize", img_resize, 0)
print(img_resize)
# Compose - resize - 2
trans_resize_2 = transforms.Resize(512)
# PIL -> PIL -> tensor
trans_compose = transforms.Compose([trans_resize_2, trans_totensor])
img_resize_2 = trans_compose(img)
writer.add_image("Resize", img_resize_2, 2)
# RandomCrop
trans_random = transforms.RandomCrop((500, 1000))
trans_compose_2 = transforms.Compose([trans_random, trans_totensor])
for i in range(10):
img_crop = trans_compose_2(img)
writer.add_image("RandomCropHW", img_crop, i)
writer.close()

使用transforms的注意点:
关注输入和输出,多看官方文档,关注方法需要哪些参数。
torchvision中的数据集的使用

import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)
# print(test_set[0])
# print(test_set.classes)
# img, target = test_set[0]
# print(img)
# print(target)
# print(test_set.classes[target])
# img.show()
#
# print(test_set[0])
writer = SummaryWriter("p10")
for i in range(10):
img, target = test_set[i]
writer.add_image("test_set", img, i)
writer.close()
DataLoader的使用
import torchvision
from torch.utils.data import DataLoader
# 准备的测试数据集
from torch.utils.tensorboard import SummaryWriter
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
# 测试数据集中第一张图片及target
img, target = test_data[0]
print(img.shape)
print(target)
writer = SummaryWriter("dataloader")
# step = 0
# for data in test_loader:
# imgs, targets = data
# # print(imgs.shape)
# # print(targets)
# writer.add_images("test_data", imgs, step)
# step += 1
for epoch in range(2):
step = 0
for data in test_loader:
imgs, targets = data
# print(imgs.shape)
# print(targets)
writer.add_images("Epoch:{}".format(epoch), imgs, step)
step += 1
writer.close()
dataloader相当于如何从dataset中取出数据,dataset:数据集,batch_size:每一组的最大数量,shuffle:是否打乱,num_workers:线程数量,drop_last:是否舍弃最后不被整除的图片。
参考资料
边栏推荐
- Control the probability of random winning [C | random]
- 深度学习3D人体姿态估计国内外研究现状及痛点
- Implementation of SAP ABAP daemon
- Photoshop(CC2020)未完
- SuperMap iclient for leaflet loads Gauss Kruger projection three-dimensional zonation CGCS2000 geodetic coordinate system WMTs service
- Huawei computer test ~ offset realizes string encryption
- Can I take your subdomain? Exploring Same-Site Attacks in the Modern Web
- Algorithm -- continuous sequence (kotlin)
- Win11+vs2019 configuration yolox
- Leetcode 263. ugly number
猜你喜欢

Unicorn, valued at $1.5 billion, was suddenly laid off, and another track was cold?
Exploration on cache design optimization of community like business

Chat system based on webrtc and websocket

Unicode文件解析方法及存在问题

HCIP第十二天笔记整理(BGP联邦、选路规则)

Seven steps to copywriting script ---- document team collaborative management

历时15年、拥有5亿用户的飞信,彻底死了

File upload and download performance test based on the locust framework
![[upper computer tutorial] Application of integrated stepping motor and Delta PLC (as228t) under CANopen communication](/img/d4/c677de31f73a0e0a4b8b10b91e984a.png)
[upper computer tutorial] Application of integrated stepping motor and Delta PLC (as228t) under CANopen communication

Oom caused by improper use of multithreading
随机推荐
421. 数组中两个数的最大异或值
Brief introduction of reflection mechanism
B+ tree index uses (7) matching column prefix, matching value range (19)
[flower carving hands-on] fun music visualization series small project (12) -- meter tube fast rhythm light
Leetcode 1523. count odd numbers within the interval
vector的一些实用操作
带你熟悉云网络的“电话簿”:DNS
目标检测网络R-CNN 系列
Solve the problem that the remote host cannot connect to the MySQL database
LeetCode 263.丑数
AI-理论-知识图谱1-基础
(Reprint) creation methods of various points in ArcGIS Engine
Huawei computer test ~ offset realizes string encryption
Emotion analysis model based on Bert
B+ tree index use (6) leftmost principle -- MySQL from entry to proficiency (18)
Some practical operations of vector
panic: Error 1045: Access denied for user ‘root‘@‘117.61.242.215‘ (using password: YES)
Unicode文件解析方法及存在问题
B+ tree index use (8) sorting use and precautions (20)
SuperMap iclient for leaflet loads Gauss Kruger projection three-dimensional zonation CGCS2000 geodetic coordinate system WMTs service