当前位置:网站首页>Pytorch深度学习代码技巧
Pytorch深度学习代码技巧
2022-06-26 14:41:00 【八个牙履】
在实际搭建深度学习网络中遇到很多坑,也在读别人的代码时看到很多技巧,统一做一个记录,也方便自己查阅
参数配置
Argparser库
Argparser库是python自带的库,使用Argparser能让我们像在Linux系统上一样用命令行去设置参数,生成的parse_args对象将所有的参数打包,在多个文件中传递修改参数时非常方便
import argparse
parser = argparse.ArgumentParser(description='MODELname') # 为参数解析器赋予一个名字
register_data_args(parser)
parser.add_argument("--dropout", type=float, default=0.5,
help="dropout probability")
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--lr", type=float, default=1e-2,
help="learning rate")
parser.add_argument("--n-epochs", type=int, default=200,
help="number of training epochs")
parser.add_argument("--n-hidden", type=int, default=16,
help="number of hidden gcn units")
parser.add_argument("--n-layers", type=int, default=1,
help="number of hidden gcn layers")
parser.add_argument("--weight-decay", type=float, default=5e-4,
help="Weight for L2 loss")
parser.add_argument("--aggregator-type", type=str, default="gcn",
help="Aggregator type: mean/gcn/pool/lstm")
config = parser.parse_args()
# 之后就能够用config.xxx代表各个参数了
# 调试,在终端输入 python train.py --n-epoch 100 --lr 1e-3 ....
模型框架
Dataloader
- 要把load data的部分放在getitem函数下面,类的主体只记录train data的路径,这样在训练的时候用一些调一些,不会导致过高的CPU内存占用
- len()函数重写时要注意一定是和getitem中的数据量相匹配
class TrainDataset(Dataset):
def __init__(self,listdir=list_dir):
super(TrainDataset, self).__init__()
self.train_dirs = []
for dir in listdir:
self.train_dirs.append(dir)
...
def __getitem__(self, index):
path = self.train_dirs[index]
data = np.load(path)
...
def __len__(self):
return len(self.train_dirs)
- collate_fn
在构建Dataloader对象时可以设置collate_fn参数,传入数据处理的函数,需要自己写。
trainloader = Dataloader(dataset = train_dataset,shuffle=True,collate_fn=my_func)
collate_fn的作用在于能够自定义数据的获取方式
学习率
- lr_scheduler
torch.optim.lr_scheduler模块提供了一些根据epoch训练次数来调整学习率(learning rate)的方法。一般情况下我们会设置随着epoch的增大而逐渐减小学习率从而达到更好的训练效果。
常用的学习率调整策略有:
StepLR:等间隔调整学习率,每次调整为 lr*gamma,调整间隔为step_size
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size,gamma=0.1,last_epoch=-1,verbose=False)
参数:
1、optimizer:设置的优化器
2、step_size:学习率调整步长,每经过step_size更新一次
3、gamma:学习率调整倍数
4、last_epoch:last_epoch之后恢复lr为initial_lr(如果是训练了很多个epoch后中断了 继续训练 这个值就等于加载的模型的epoch 默认为-1表示从头开始训练,即从epoch=1开始
5、verbose:是否每次改变都输出一次lr的值
MultiStepLR : 当前epoch数满足设定值时,调整学习率。这个方法适合后期调试使用,观察loss曲线,为每个实验制定学习率调整时期
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False)
milestones:一个包含epoch索引的list,列表中的每个索引代表调整学习率的epoch。list中的值必须是递增的。 如 [20, 50, 100] 表示在epoch为20,50,100时调整学习率。
其他参数设置方法相同
ExponentialLR:按指数衰减调整学习率,调整公式:lr = lr*gamma**epoch
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma, last_epoch=-1, verbose=False)
CosineAnnealingLR:余弦退火策略,周期性地调整学习率
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min=0, last_epoch=-1, verbose=False)
参数:1、T_max :学习率调整的周期,每个周期重新将学习率调整回初始值,再衰减,这样的策略有助于跳出马鞍面。通常设置为len(train_dataset)
2、eta_min:衰减到的最小的学习率,defult为0
3、last_epoch:上一个epoch数,这个变量用于指示学习率是否需要调整。当last_epoch符合设定的间隔时就会调整学习率。当设置为-1时,学习率设置为初始值
在每个train step中更新学习率:
scheduler.step()
- warmup
warmup是一种学习率优化方法,最早出现在resnet论文中,在模型训练初期选用较小的学习率,训练一段时间之后(10epoch 或者 10000steps)使用预设的学习率进行训练
使用warmup的原因:
模型训练初期,权重随机化,对数据的理解为0,在第一个epoch中,模型会根据输入的数据进行快速的调参,此时如果采用较大的学习率,有很大的可能使模型学偏,后续需要更多的轮次才能拉回来
模型训练一段时间之后,对数据有一定的先验知识,此时使用较大的学习率模型不容易学偏,可以使用较大的学习率加速训练。
模型使用较大的学习率训练一段时间之后,模型的分布相对比较稳定,此时不宜从数据中再学到新的特点,如果继续使用较大的学习率会破坏模型的稳定性,而使用较小的学习率更获得最优。
warm_up实现
class WarmupLR(_LRScheduler):
"""The WarmupLR scheduler This scheduler is almost same as NoamLR Scheduler except for following difference: NoamLR: lr = optimizer.lr * model_size ** -0.5 * min(step ** -0.5, step * warmup_step ** -1.5) WarmupLR: lr = optimizer.lr * warmup_step ** 0.5 * min(step ** -0.5, step * warmup_step ** -1.5) Note that the maximum lr equals to optimizer.lr in this scheduler. """
def __init__(
self,
optimizer: torch.optim.Optimizer,
warmup_steps: Union[int, float] = 25000,
last_epoch: int = -1,
):
assert check_argument_types()
self.warmup_steps = warmup_steps
# __init__() must be invoked before setting field
# because step() is also invoked in __init__()
super().__init__(optimizer, last_epoch)
def __repr__(self):
return f"{
self.__class__.__name__}(warmup_steps={
self.warmup_steps})"
def get_lr(self):
step_num = self.last_epoch + 1
return [
lr
* self.warmup_steps ** 0.5
* min(step_num ** -0.5, step_num * self.warmup_steps ** -1.5)
for lr in self.base_lrs
]
def set_step(self, step: int):
self.last_epoch = step
warm_up配合lr_scheduler一起使用,先线性增大学习率,再衰减:
if step_counter>10:
scheduler.step()
模型训练
分布式
- nn.DataParallel
通过torch.nn.DataParallel进行分布式训练,要求主机上有不止一块GPU,训练时将数据分摊到多块GPU上进行多进程训练,每块GPU上进行独立的optimize,再将loss进行汇总求平均,将反向传播的梯度广播到每一块GPU上
model= MY_MODEL()
if torch.cuda.device_count() > 1:
model = nn.DataParallel(model)
开分布式要注意不能使用next(model.parameters()),并且在模型中存在 着foriegner model时,不能在模型内部给这些foriegner model指定设备号(cuda:0),可以用torch.device('cuda')代替,让程序自己去分配,数据也是如此,不然会出现模型运算的数据放在了不同的GPU的情况
- 查看模型、数据所在设备
# data type:torch.tensor
print(next(model.parameters()).device)
print(data.device)
判断数据是否在gpu上:
print(data.is_cuda)
数据在gpu、cpu之间的相互转换:
# 最简单的,data type:torch.tensor
data.cpu()
data.cuda()
# 设置设备
device = torch.device('cuda:0' if torch.cuda_is_avaliable() else 'cpu')
data.to(device)
- local rank
在pytorch上进行分布式训练,指定local rank,即进程内的GPU 编号,非显式参数,由 torch.distributed.launch 内部指定。主机GPU的local_rank为0。比方说, rank = 3,local_rank = 0 表示第 3 个进程内的第 1 块 GPU主机进程local rank=0
在每次迭代中,每个进程具有自己的 optimizer ,并独立完成所有的优化步骤,进程内与一般的训练无异。
在各进程梯度计算完成之后,各进程需要将梯度进行汇总平均,然后再由 rank=0 的进程,将其 broadcast 到所有进程。之后,各进程用该梯度来更新参数。
设置第几块GPU工作
torch.cuda.set_device(arg.local_rank)
device = torch.device('cuda',arg.local_rank)
调试提示、数据可视化
- 设置按钮,检查是否下载了数据集,出错则raise RuntimeError
if not self.check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You need to download it from official website.')
# 检查数据集的路径是否存在
def check_integrity(self):
if not os.path.exists(self.root_dir):
return False
else:
return True
- 用enumerate和列表推导式(for表达式)生成word2idx,labels2idx,这样idx不用靠单独设置变量再for循环+=1来设置
self.label2index = {
label: index for index, label in enumerate(sorted(set(labels)))}
- 记录模型的参数总量
# pytorch中的numel函数统计tensor中的总元素量
num_params = sum(p.numel() for p in model.parameters())
print(num_params)
训练日志
- logging模块
用logging模块记录训练日志,在后台在线写入日志文件,也可以在线输出,方便监控训练信息
1 import logging
2
# 文件保存地址,文件名、信息级别
3 logging.basicConfig(filename=os.path.join(self.writer.log_dir, 'training.log'),level=logging.DEBUG, format='%(asctime)s - %(name)s - %(message)s')
4
5 logging.debug('this is debug message') # 这些信息将被记录在training.log中
6 logging.info('this is info message')
7 logging.warning('this is warning message')
8
9 ''''' 10 结果: 11 2017-08-23 14:22:25,713 - root - this is debug message 12 2017-08-23 14:22:25,713 - root - this is info message 13 2017-08-23 14:22:25,714 - root - this is warning message 14 '''
logging.basicConfig 函数各参数:
filename: 指定日志文件名
filemode: 和file函数意义相同,指定日志文件的打开模式,‘w’或’a’
format: 指定输出的格式和内容,format可以输出很多有用信息,如上例所示:
%(levelno)s: 打印日志级别的数值
%(levelname)s: 打印日志级别名称
%(pathname)s: 打印当前执行程序的路径,其实就是sys.argv[0]
%(filename)s: 打印当前执行程序名
%(funcName)s: 打印日志的当前函数
%(lineno)d: 打印日志的当前行号
%(asctime)s: 打印日志的时间
%(thread)d: 打印线程ID
%(threadName)s: 打印线程名称
%(process)d: 打印进程ID
%(message)s: 打印日志信息
datefmt: 指定时间格式,同time.strftime()
level: 设置日志级别,默认为logging.WARNING
stream: 指定将日志的输出流,可以指定输出到sys.stderr,sys.stdout或者文件,默认输出到sys.stderr,当stream和filename同时指定时,stream被忽略
- 将日志同时输出到文件和屏幕:
# 设置日志名称,通常用主文件命名
logger1 = logging.getLogger(__name__)
logging.basicConfig(filename=os.path.join(self.writer.log_dir, 'training.log'),level=logging.DEBUG)
logging.info(f"Start SimCLR training for {
self.args.epochs} epochs.")
logging.info(f"Training with gpu: {
self.args.disable_cuda}.")
....
# 日志输出到文件
fh1 = logging.FileHandler(filename='a1.log', encoding='utf-8') # 文件a1
fh2 = logging.FileHandler(filename='a2.log', encoding='utf-8') # 文件a2
sh = logging.StreamHandler() # 日志输出到终端片
边栏推荐
- MHA high availability coordination and failover
- 券商经理给的开户链接办理股票开户安全吗?我想开个户
- Two point answer, 01 score planning (mean / median conversion), DP
- 备战数学建模32-相关性分析2
- Keil4打开单片机工程一片空白,cpu100%程序卡死的问题解决
- Summary of decimal point of amount and price at work and pit
- Use abp Zero builds a third-party login module (II): server development
- 聊聊几位大厂清华同学的近况
- 信息学奥赛一本通 1400:统计单词数 (字符串匹配)
- Redis事务与watch指令
猜你喜欢

ArcGIS batch render layer script

The engine "node" is inconsistent with this module
![[cloud native] codeless IVX editor programmable by](/img/10/7c56e46df69be6be522a477b00ec05.png)
[cloud native] codeless IVX editor programmable by "everyone"

一篇抄十篇,CVPR Oral被指大量抄袭,大会最后一天曝光!

【雲原生】 ”人人皆可“ 編程的無代碼 iVX 編輯器

710. random numbers in the blacklist

重磅白皮书发布,华为持续引领未来智慧园区建设新模式

【使用yarn运行报错】The engine “node“ is incompatible with this module.

One copy ten, CVPR oral was accused of plagiarizing a lot, and it was exposed on the last day of the conference!

Deploy the flask environment using the pagoda panel
随机推荐
Is it safe to open a stock account with the account manager online??
Naacl2022: (code practice) good visual guidance promotes better feature extraction, multimodal named entity recognition (with source code download)
It's natural for the landlord to take the rent to repay the mortgage
NVIDIA SMI error
Stream常用操作以及原理探索
Is the account opening link given by the broker manager safe? Who can I open an account with?
Datasets dataset class (2)
Flex & bison start
R语言epiDisplay包的dotplot函数通过点图的形式可视化不同区间数据点的频率、使用by参数指定分组参数可视化不同分组的点图分布、使用cex.X.axis参数指定X轴轴刻度数值标签字体的大小
Understand the difference and use between jsonarray and jsonobject
A标签去掉下划线
Excel-vba quick start (II. Condition judgment and circulation)
Question bank and answers of the latest Guizhou construction eight (Mechanics) simulated examination in 2022
JVM 输出 GC 日志导致 JVM 卡住,我 TM 人傻了
Summary of decimal point of amount and price at work and pit
工作上对金额价格类小数点的总结以及坑
Use abp Zero builds a third-party login module (II): server development
GDAL multiband synthesis tool
Leaflet loading ArcGIS for server map layers
扩展-Hooks