当前位置:网站首页>使用TensorBoard可视化训练过程
使用TensorBoard可视化训练过程
2022-07-25 11:37:00 【Alexa2077】
一,可视化训练过程
训练过程的可视化在深度学习模型训练中扮演着重要的角色。学习的过程是一个优化的过程,我们需要找到最优的点作为训练过程的输出产物。
希望可视化loss,输入数据(尤其是图片)、模型结构、参数分布等,这些对于我们在debug中查找问题来源非常重要(比如输入数据和我们想象的是否一致)。
TensorBoard作为一款可视化工具能够满足上面提到的各种需求。TensorBoard由TensorFlow团队开发,最早和TensorFlow配合使用,后来广泛应用于各种深度学习框架的可视化中来。
1,安装
在已安装PyTorch的环境下使用pip安装即可,也可以使用PyTorch自带的tensorboard工具,此时不需要额外安装tensorboard。
pip install tensorboardX
2,可视化逻辑
可以将TensorBoard看做一个记录员,它可以记录我们指定的数据,包括模型每一层的feature map,权重,以及训练loss等等。TensorBoard将记录下来的内容保存在一个用户指定的文件夹里,程序不断运行中TensorBoard会不断记录,记录下的内容可以通过网页的形式加以可视化。
3,配置与启动
在使用TensorBoard前,我们需要先指定一个文件夹供TensorBoard保存记录下来的数据。然后调用tensorboard中的SummaryWriter作为上述记录员:
from tensorboardX import SummaryWriter
writer = SummaryWriter('./runs')
上面的操作实例化SummaryWritter为变量writer,并指定writer的输出目录为当前目录下的"runs"目录。也就是说,之后tensorboard记录下来的内容都会保存在runs。
如果使用PyTorch自带的tensorboard,则采用如下方式import:
from torch.utils.tensorboard import SummaryWriter
是否可以手动往runs文件夹里添加数据用于可视化,或者把runs文件夹里的数据放到其他机器上可视化呢?答案是可以的。只要数据被记录,你可以将这个数据分享给其他人,其他人在安装了tensorboard的情况下就会看到你分享的数据。
启动tensorboard也很简单,在命令行中输入
tensorboard --logdir=/path/to/logs/ --port=xxxx
其中“path/to/logs/“是指定的保存tensorboard记录结果的文件路径(等价于上面的“./runs”,port是外部访问TensorBoard的端口号,可以通过访问ip:port访问tensorboard,这一操作和jupyter notebook的使用类似。如果不是在服务器远程使用的话则不需要配置port。
有时,为了tensorboard能够不断地在后台运行,也可以使用nohup命令或者tmux工具来运行tensorboard。
4,模型结构可视化
首先定义模型结构:
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3)
self.pool = nn.MaxPool2d(kernel_size = 2,stride = 2)
self.conv2 = nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5)
self.adaptive_pool = nn.AdaptiveMaxPool2d((1,1))
self.flatten = nn.Flatten()
self.linear1 = nn.Linear(64,32)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(32,1)
self.sigmoid = nn.Sigmoid()
def forward(self,x):
x = self.conv1(x)
x = self.pool(x)
x = self.conv2(x)
x = self.pool(x)
x = self.adaptive_pool(x)
x = self.flatten(x)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
y = self.sigmoid(x)
return y
model = Net()
print(model)
输出如下:
Net(
(conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
(adaptive_pool): AdaptiveMaxPool2d(output_size=(1, 1))
(flatten): Flatten(start_dim=1, end_dim=-1)
(linear1): Linear(in_features=64, out_features=32, bias=True)
(relu): ReLU()
(linear2): Linear(in_features=32, out_features=1, bias=True)
(sigmoid): Sigmoid()
)
可视化模型的思路就是给定一个输入数据,前向传播后得到模型的结构,再通过TensorBoard进行可视化,使用add_graph:
writer.add_graph(model, input_to_model = torch.rand(1, 3, 224, 224))
writer.close()
结果如下:
5,图像可视化
可视化方法:
- 对于单张图片的显示使用add_image
- 对于多张图片的显示使用add_images
- 有时需要使用torchvision.utils.make_grid将多张图片拼成一张图片后,用writer.add_image显示
例程:使用torchvision的CIFAR10数据集为例
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
transform_train = transforms.Compose(
[transforms.ToTensor()])
transform_test = transforms.Compose(
[transforms.ToTensor()])
train_data = datasets.CIFAR10(".", train=True, download=True, transform=transform_train)
test_data = datasets.CIFAR10(".", train=False, download=True, transform=transform_test)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64)
images, labels = next(iter(train_loader))
# 仅查看一张图片
writer = SummaryWriter('./pytorch_tb')
writer.add_image('images[0]', images[0])
writer.close()
# 将多张图片拼接成一张图片,中间用黑色网格分割
# create grid of images
writer = SummaryWriter('./pytorch_tb')
img_grid = torchvision.utils.make_grid(images)
writer.add_image('image_grid', img_grid)
writer.close()
# 将多张图片直接写入
writer = SummaryWriter('./pytorch_tb')
writer.add_images("images",images,global_step = 0)
writer.close()
6,连续变量可视化
TensorBoard可以用来可视化连续变量(或时序变量)的变化过程,通过add_scalar实现:
writer = SummaryWriter('./pytorch_tb')
for i in range(500):
x = i
y = x**2
writer.add_scalar("x", x, i) #日志中记录x在第step i 的值
writer.add_scalar("y", y, i) #日志中记录y在第step i 的值
writer.close()
如果想在同一张图中显示多个曲线,则需要分别建立存放子路径(使用SummaryWriter指定路径即可自动创建,但需要在tensorboard运行目录下),同时在add_scalar中修改曲线的标签使其一致即可:
writer1 = SummaryWriter('./pytorch_tb/x')
writer2 = SummaryWriter('./pytorch_tb/y')
for i in range(500):
x = i
y = x*2
writer1.add_scalar("same", x, i) #日志中记录x在第step i 的值
writer2.add_scalar("same", y, i) #日志中记录y在第step i 的值
writer1.close()
writer2.close()
7,参数分布可视化
当我们需要对参数(或向量)的变化,或者对其分布进行研究时,可以方便地用TensorBoard来进行可视化,通过add_histogram实现。下面给出一个例子
import torch
import numpy as np
# 创建正态分布的张量模拟参数矩阵
def norm(mean, std):
t = std * torch.randn((100, 20)) + mean
return t
writer = SummaryWriter('./pytorch_tb/')
for step, mean in enumerate(range(-10, 10, 1)):
w = norm(mean, 1)
writer.add_histogram("w", w, step)
writer.flush()
writer.close()
8,服务器端使用TensorBoard
一般情况下,我们会连接远程的服务器来对模型进行训练,但由于服务器端是没有浏览器的(纯命令模式),因此,我们需要进行相应的配置,才可以在本地浏览器,使用tensorboard查看服务器运行的训练过程。 本文提供以下几种方式进行,其中前两种方法都是建立SSH隧道,实现远程端口到本机端口的转发,最后一种方法适用于没有下载Xshell等SSH连接工具的用户。
MobaXterm:
在MobaXterm点击Tunneling
选择New SSH tunnel,我们会出现以下界面
1,对新建的SSH通道做以下设置,第一栏我们选择Local port forwarding,< Remote Server>我们填写localhost,< Remote port>填写6006,tensorboard默认会在6006端口进行显示,我们也可以根据 tensorboard --logdir=/path/to/logs/ --port=xxxx的命令中的port进行修改,< SSH server> 填写我们连接服务器的ip地址,填写我们连接的服务器的用户名,填写端口号(通常为22),< forwarded port>填写的是本地的一个端口号,以便我们后面可以对其进行访问。
2,设定好之后,点击Save,然后Start。在启动tensorboard,这样我们就可以在本地的浏览器输入http://localhost:6006/对其进行访问了
xshell:
1,Xshell的连接方法与MobaXterm的连接方式本质上是一样的,具体操作如下:
2,连接上服务器后,打开当前会话属性,会出现下图,我们选择隧道,点击添加

3,按照下方图进行选择,其中目标主机代表的是服务器,源主机代表的是本地,端口的选择根据实际情况而定。
4,启动tensorboard,在本地127.0.0.1:6006 或者 localhost:6006进行访问
SSL:
1,该方法是将服务器的6006端口重定向到自己机器上来,我们可以在本地的终端里输入以下代码:其中16006代表映射到本地的端口,6006代表的是服务器上的端口
ssh -L 16006:127.0.0.1:6006 [email protected]_server_ip
2,在服务上使用默认的6006端口正常启动tensorboard
tensorboard --logdir=xxx --port=6006
3,在本地的浏览器输入地址
127.0.0.1:16006 或者 localhost:16006
其实TensorBoard的逻辑还是很简单的,它的基本逻辑就是文件的读写逻辑,写入想要可视化的数据,然后TensorBoard自己会读出来。
注:本文参与DataWhale-深入浅出pytorch小组学习项目!
边栏推荐
- Transformer variants (spark transformer, longformer, switch transformer)
- Zuul网关使用
- 容错机制记录
- R语言ggplot2可视化:可视化散点图并为散点图中的部分数据点添加文本标签、使用ggrepel包的geom_text_repel函数避免数据点之间的标签互相重叠(为数据点标签添加线段、指定线段的角度
- 和特朗普吃了顿饭后写下了这篇文章
- PHP curl post length required error setting header header
- Week303 of leetcode (20220724)
- keepalived实现mysql的高可用
- R语言使用ggpubr包的ggarrange函数将多幅图像组合起来、使用ggexport函数将可视化图像保存为jpeg格式(width参数指定宽度、height参数指定高度、res参数指定分辨率)
- 【AI4Code】《Pythia: AI-assisted Code Completion System》(KDD 2019)
猜你喜欢

3.2.1 什么是机器学习?

【AI4Code】《CodeBERT: A Pre-Trained Model for Programming and Natural Languages》 EMNLP 2020

【GCN-RS】Region or Global? A Principle for Negative Sampling in Graph-based Recommendation (TKDE‘22)

Location analysis of recording an online deadlock

2.1.2 机器学习的应用
![[dark horse morning post] eBay announced its shutdown after 23 years of operation; Wei Lai throws an olive branch to Volkswagen CEO; Huawei's talented youth once gave up their annual salary of 3.6 mil](/img/d7/4671b5a74317a8f87ffd36be2b34e1.jpg)
[dark horse morning post] eBay announced its shutdown after 23 years of operation; Wei Lai throws an olive branch to Volkswagen CEO; Huawei's talented youth once gave up their annual salary of 3.6 mil

Meta learning (meta learning and small sample learning)

【GCN多模态RS】《Pre-training Representations of Multi-modal Multi-query E-commerce Search》 KDD 2022

通信总线协议一 :UART

Application and innovation of low code technology in logistics management
随机推荐
[high concurrency] a lock faster than read-write lock in high concurrency scenarios. I'm completely convinced after reading it!! (recommended Collection)
马斯克的“灵魂永生”:一半炒作,一半忽悠
logstash
NLP knowledge - pytorch, back propagation, some small pieces of notes for predictive tasks
Eureka注册中心开启密码认证-记录
scrapy 爬虫框架简介
R语言可视化散点图、使用ggrepel包的geom_text_repel函数避免数据点之间的标签互相重叠(设置min.segment.length参数为Inf不添加标签线段)
Resttemplate and ribbon are easy to use
利用wireshark对TCP抓包分析
Dr. water 2
Pycharm connects to the remote server SSH -u reports an error: no such file or directory
From cloud native to intelligent, in-depth interpretation of the industry's first "best practice map of live video technology"
【AI4Code】《Contrastive Code Representation Learning》 (EMNLP 2021)
嵌套事务 UnexpectedRollbackException 分析与事务传播策略
[high concurrency] Why is the simpledateformat class thread safe? (six solutions are attached, which are recommended for collection)
Transformer variants (routing transformer, linformer, big bird)
Those young people who left Netease
PHP curl post length required error setting header header
R语言ggplot2可视化:使用ggpubr包的ggviolin函数可视化小提琴图、设置add参数在小提琴内部添加抖动数据点以及均值标准差竖线(jitter and mean_sd)
和特朗普吃了顿饭后写下了这篇文章