Using a Multi-GPU Server
2022-06-30 02:36:00 【MallocLu】
Reference article: https://blog.csdn.net/qq_42255269/article/details/123427094?spm=1001.2014.3001.5506
Example program
The simplest possible setup: a VGG16 model doing 10-class classification on the CIFAR10 dataset.
from torch import nn
import torchvision
import argparse
from torch.utils.data import DataLoader
import torch
class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.vgg16 = torchvision.models.vgg16()
        self.fc = nn.Linear(1000, 10)

    def forward(self, x):
        out = self.vgg16(x)
        out = self.fc(out)
        return out


if __name__ == '__main__':
    # parse the batch_size argument
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int)
    args = parser.parse_args()
    device = 'cuda'
    # set up the dataset
    train_data = torchvision.datasets.CIFAR10(root="data", train=True, transform=torchvision.transforms.ToTensor(), download=True)
    train_dataloader = DataLoader(train_data, batch_size=args.batch_size)
    # build the model
    net = MyNet().to(device=device)
    # loss function
    loss_fn = nn.CrossEntropyLoss()
    # optimizer
    optimizer = torch.optim.SGD(net.parameters(), lr=1e-2)
    for i in range(1000):
        for step, data in enumerate(train_dataloader):
            imgs, targets = data
            imgs = imgs.to(device=device)
            targets = targets.to(device=device)
            outputs = net(imgs)
            loss = loss_fn(outputs, targets)
            # backpropagate and update the weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print("epoch:{} step:{} loss:{}".format(i, step, loss.item()))
Single-GPU run on a single-GPU machine
python main.py --batch_size 16
Single-GPU run on a multi-GPU machine
# run on GPU 0
CUDA_VISIBLE_DEVICES=0 python main.py --batch_size 16
# run on GPU 1
CUDA_VISIBLE_DEVICES=1 python main.py --batch_size 16
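To confirm that CUDA_VISIBLE_DEVICES took effect, you can check inside Python which devices the process actually sees (a minimal sketch; the exact device name depends on your machine):
import torch
# with CUDA_VISIBLE_DEVICES=1, only one device is visible and it is re-indexed as cuda:0
print(torch.cuda.device_count())      # expected output: 1
print(torch.cuda.get_device_name(0))  # name of the physical GPU that was exposed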
Multi-GPU run on a multi-GPU machine
nn.DataParallel (not recommended)
Pros: simple to use; only minor code changes are required (3 modifications in total)
Cons: it only speeds up training but does not allow a larger batch_size (data must first be loaded onto the main GPU before being scattered to the other GPUs, so the main GPU's memory limits the batch_size)
# batch_size should be an integer multiple of the number of GPUs, e.g. with 2 GPUs, --batch_size 512 gives 256 samples per GPU per step
python main.py --batch_size 512
from torch import nn
import torchvision
import argparse
from torch.utils.data import DataLoader
import torch
class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.vgg16 = torchvision.models.vgg16()
        self.fc = nn.Linear(1000, 10)

    def forward(self, x):
        out = self.vgg16(x)
        out = self.fc(out)
        return out


if __name__ == '__main__':
    # parse the batch_size argument
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int)
    args = parser.parse_args()
    # change 1: list the GPUs to use; the first one is the main GPU
    gpus = [0, 1]
    device = 'cuda:{}'.format(gpus[0])
    # set up the dataset
    train_data = torchvision.datasets.CIFAR10(root="data", train=True, transform=torchvision.transforms.ToTensor(), download=True)
    train_dataloader = DataLoader(train_data, batch_size=args.batch_size)
    # change 2: wrap the model with nn.DataParallel
    net = nn.DataParallel(MyNet().to(device=device), device_ids=gpus, output_device=gpus[0])
    # loss function
    loss_fn = nn.CrossEntropyLoss()
    # optimizer
    optimizer = torch.optim.SGD(net.parameters(), lr=1e-2)
    for i in range(1000):
        for step, data in enumerate(train_dataloader):
            imgs, targets = data
            # change 3: move data to the main GPU
            imgs = imgs.to(device=device, non_blocking=True)
            targets = targets.to(device=device, non_blocking=True)
            outputs = net(imgs)
            loss = loss_fn(outputs, targets)
            # backpropagate and update the weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print("epoch:{} step:{} loss:{}".format(i, step, loss.item()))
torch.distributed (recommended)
Pros: increases both training speed and the maximum batch_size
Cons: slightly more complex to use, and shuffle must not be enabled on the DataLoader (data partitioning is handled by the DistributedSampler instead)
New version: torch.distributed.run
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.run --nproc_per_node=2 main.py --batch_size 512
from torch import nn
import torchvision
import argparse
from torch.utils.data import DataLoader
import torch
import torch.distributed as dist
import os
class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.vgg16 = torchvision.models.vgg16()
        self.fc = nn.Linear(1000, 10)

    def forward(self, x):
        out = self.vgg16(x)
        out = self.fc(out)
        return out


if __name__ == '__main__':
    # parse the batch_size argument
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int)
    args = parser.parse_args()
    # change 1: read the local rank set by torch.distributed.run
    local_rank = int(os.environ["LOCAL_RANK"])
    # change 2: initialize the process group (backend and port used for GPU communication)
    dist.init_process_group(backend='nccl')
    # change 3: each process uses its own GPU
    device = 'cuda:{}'.format(local_rank)
    # set up the dataset
    train_data = torchvision.datasets.CIFAR10(root="data", train=True, transform=torchvision.transforms.ToTensor(), download=True)
    # change 4: use a DistributedSampler; do not set shuffle=True on the DataLoader
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
    train_dataloader = DataLoader(train_data, batch_size=args.batch_size, sampler=train_sampler)
    # change 5: wrap the model with DistributedDataParallel
    net = nn.parallel.DistributedDataParallel(MyNet().to(device), device_ids=[local_rank])
    # loss function
    loss_fn = nn.CrossEntropyLoss()
    # optimizer
    optimizer = torch.optim.SGD(net.parameters(), lr=1e-2)
    for i in range(1000):
        for step, data in enumerate(train_dataloader):
            imgs, targets = data
            # change 6: move data to this process's GPU
            imgs = imgs.to(device=device, non_blocking=True)
            targets = targets.to(device=device, non_blocking=True)
            outputs = net(imgs)
            loss = loss_fn(outputs, targets)
            # backpropagate and update the weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print("epoch:{} step:{} loss:{}".format(i, step, loss.item()))
Old version: torch.distributed.launch
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 main.py --batch_size 512
from torch import nn
import torchvision
import argparse
from torch.utils.data import DataLoader
import torch
import torch.distributed as dist
class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.vgg16 = torchvision.models.vgg16()
        self.fc = nn.Linear(1000, 10)

    def forward(self, x):
        out = self.vgg16(x)
        out = self.fc(out)
        return out


if __name__ == '__main__':
    # parse the batch_size argument
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int)
    # change 1: torch.distributed.launch passes --local_rank to each process
    parser.add_argument('--local_rank', default=-1, type=int,
                        help='node rank for distributed training')
    args = parser.parse_args()
    # change 2: initialize the process group (backend and port used for GPU communication)
    dist.init_process_group(backend='nccl')
    # change 3: each process uses its own GPU
    device = 'cuda:{}'.format(args.local_rank)
    # set up the dataset
    train_data = torchvision.datasets.CIFAR10(root="data", train=True, transform=torchvision.transforms.ToTensor(), download=True)
    # change 4: use a DistributedSampler; do not set shuffle=True on the DataLoader
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
    train_dataloader = DataLoader(train_data, batch_size=args.batch_size, sampler=train_sampler)
    # change 5: wrap the model with DistributedDataParallel
    net = nn.parallel.DistributedDataParallel(MyNet().to(device), device_ids=[args.local_rank])
    # loss function
    loss_fn = nn.CrossEntropyLoss()
    # optimizer
    optimizer = torch.optim.SGD(net.parameters(), lr=1e-2)
    for i in range(1000):
        for step, data in enumerate(train_dataloader):
            imgs, targets = data
            # change 6: move data to this process's GPU
            imgs = imgs.to(device=device, non_blocking=True)
            targets = targets.to(device=device, non_blocking=True)
            outputs = net(imgs)
            loss = loss_fn(outputs, targets)
            # backpropagate and update the weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print("epoch:{} step:{} loss:{}".format(i, step, loss.item()))