当前位置:网站首页>模板代码概述

模板代码概述

2022-06-22 07:42:00 鹿衔草啊

模板代码概述

1. 数据集函数

class MyDataset(Dataset):
    def __init__(self, img_id_list, IMG_SIZE, mode='train', augmentation=False):
        """传参,定义参数 1. 数据集列表, - 本地数据,文件名/图片名 - API,图片ID 2. 图片读取尺寸 3. 训练模式or推理模式 4. 是否做Data augmentation ... """
        pass
    
    def __getitem__(self, idx):
        """读取下一个样本 1. 读取本地图片,或读API接口获取base64格式图片 2. 预处理, 如变换图片尺寸 3. 若训练集,读取Mask图片 4. Data augmentation """
        pass

    def __len__(self):
        """定义样本个数 """
        pass
def prepare_trainset():
    """ 1. 切分数据集,训练集/验证集 2. 定义MyDataset训练集、MyDataset验证集 3. 定义Pytorch的DataLoader train_dl = DataLoader( train_dataset, batch_size=16, shuffle=True, #sampler=sampler, num_workers=8, drop_last=True ) val_dl = DataLoader( val_dataset, batch_size=16, shuffle=False, #sampler=sampler, num_workers=8, drop_last=True ) """
    pass

2. Utils函数

  • 训练日志
  • 训练checkpoint
  • GPU交互

3. 分割的评估函数

在这里插入图片描述

4. 训练脚本

def run_training():
    """training pipline 1. 读取network - 加载预训练模型 - 定义训练全部层的参数/哪几层参数 - 定义学习率/为每一层定义学习率 - 定义优化函数optimizer、学习率变化方案scheduler - 2. 训练N_EPOCH次迭代,每一个迭代内: - 用DataLoader循环读取训练集上每一个batch数据(N个图片、N个mask) - 将N个图片传入network,输出模型最后一层的预测(sigmoid概率) - 计算这个batch上的loss、metric,并存下来 - 反向传播,更新参数(.backward())(是否梯度累加) - 计算所有batch上loss、metric的总体均值,代表这个EPOCH - 用DataLoader循环读取验证集上每一个batch数据,与以上操作相似,计算验证集上的loss、metric,用于决定哪一个EPOCH停止训练 - 更新logging、保存checkpoint """
    pass

5. Unet介绍

在这里插入图片描述

原网站

版权声明
本文为[鹿衔草啊]所创,转载请带上原文链接,感谢
https://blog.csdn.net/weixin_45649258/article/details/125395720