当前位置:网站首页>(5) fastai application
(5) fastai application
2022-07-31 00:22:00 【_helen_520】
目前:fastai lesson8~lesson11parts have been refactored
- mnist数据集比较简单、28×28的像素,都是一样的.The background is also relatively clean,It is also a classification task,It can be handled with a simple network.
1. 使用线性模型
# 在这里对mnistThe dataset is classified,实现acc的提升
from exp.nb_09c import *
""" 0.数据准备
I didn't write it myselfDataBunch,ItemList等接口.ImageList的get是要去open的
mnist走的还是pytorch的Dataloader的接口
"""
x_train,y_train,x_valid,y_valid = get_data() # 这个函数在nb_02.py中定义
x_train,x_valid = normalize_to(x_train,x_valid) # nb_05.py中
n,m = x_train.shape
c = y_train.max().item() + 1
bs = 512
# 使用Dataset来管理batch数据: nb_03.py
train_ds,valid_ds = Dataset(x_train, y_train),Dataset(x_valid, y_valid)
# nb_08.py get_dls在nb_03.py,使用的是Dataloader
data = DataBunch(*get_dls(train_ds, valid_ds, bs), c)
loss_func = F.cross_entropy
""" 1. 线性模型(50,10),使用pytorch的nn.Module基类,not refactored
"""
nh = 50
def init_linear_(m, f):
if isinstance(m, nn.Linear):
f(m.weight, a=0.1)
if getattr(m, 'bias', None) is not None: m.bias.data.zero_()
for l in m.children(): init_linear_(l, f)
def init_linear(m, uniform=False):
f = init.kaiming_uniform_ if uniform else init.kaiming_normal_
init_linear_(m, f)
# ① model,Because it is a custom linear model,没有初始化
model = nn.Sequential(nn.Linear(m, nh), nn.ReLU(), nn.Linear(nh, c))
lr = 0.5
# get_runner nb_06.py 由于不是CNN网络,所以不是get_cnn_runner
# 使用get_runner而不是get_learner
# device = torch.device('cuda', 0)
# torch.cuda.set_device(device)
cbfs = [partial(AvgStatsCallback, accuracy), CudaCallback, Recorder, ProgressCallback]
phases = combine_scheds([0.3, 0.7], cos_1cycle_anneal(0.2, 0.6, 0.2))
sched = ParamScheduler('lr', phases)
# Learner在nb_09b.py 线性模型、交叉熵loss、lr、cbfs、opt 在Learner.fit中有opt的初始化函数的.
# ② 优化器 nb_09b.py 简单的sgd梯度下降,weight_decay是l2正则化
learn = Learner(model=model, data=data, loss_func=loss_func, lr=lr, cb_funcs=cbfs)
# 可以在fit的时候添加一个cbs
# sgd: p = p - lr*p.grad
# weight_decay: p = p * ( 1 - lr*wd)
def append_stats(hook, mod, inp, outp):
if not hasattr(hook,'stats'): hook.stats = ([],[],[])
means,stds,hists = hook.stats
means.append(outp.data.mean().cpu()) # The value of the activation element
stds .append(outp.data.std().cpu())
hists.append(outp.data.cpu().histc(40,0,10)) #histc isn't implemented on the GPU
def get_hist(h):
return torch.stack(h.stats[2]).t().float().log1p() # h.stats[2]为直方图
with Hooks(model, append_stats) as hooks:
learn.fit(1) # pytorch_init + sgd
fig, [ax0, ax1] = plt.subplots(1,2, figsize=(10,4))
for h in hooks:
ms, ss, hi = h.stats
ax0.plot(ms), ax0.set_title("act_means", loc='center'), ax0.set_xlabel('batches')
ax0.legend(range(3))
ax1.plot(ss), ax1.set_title("act_stds", loc='center'), ax1.set_xlabel('batches')
ax1.legend(range(3))
fig,axes = plt.subplots(2,2, figsize=(15,6))
for ax,h in zip(axes.flatten(), hooks[:3]):
ax.imshow(get_hist(h), origin='lower'), ax.set_title("acts_hist", loc='center'), ax.set_xlabel('activiations')
ax.axis('off')
plt.tight_layout()
def get_min(h): # Add up the first two numbers of the histogram
h1 = torch.stack(h.stats[2]).t().float()
return h1[:2].sum(0)/h1.sum(0)
fig,axes = plt.subplots(2,2, figsize=(15,6))
for ax,h in zip(axes.flatten(), hooks[:3]):
ax.plot(get_min(h)), ax.set_title("hist[:2] zero ratio", loc='center'), plt.xlabel('batches')
ax.set_ylim(0,1)
plt.tight_layout()


① Linear的模型,需要自己写一个.Learner在nb_09b.py中,opt是在fittime to buildOpt的对象.
② opt如果是sgd,就是默认的.不写就可以了.
③ 如果cuda启动不起来,电脑需要重启.
2. ImagenetteDataset debug logging
Pytorch 调试常用
代码仓库:Dive-into-DL-PyTorch/2.2_tensor.md at master · ShusenTang/Dive-into-DL-PyTorch · GitHub
李沐的《动手学深度学习》原书中MXNet代码实现改为PyTorch实现.本项目面向对深度学习感兴趣,尤其是想使用PyTorch进行深度学习的童鞋.本项目并不要求你有任何深度学习或者机器学习的背景知识,你只需了解基础的数学和编程,如基础的线性代数、微分和概率,以及基础的Python编程.
目录如下所示:

1. Tensor的使用



边栏推荐
- Point Cloud Scene Reconstruction with Depth Estimation
- Jetpack Compose学习(8)——State及remeber
- ES6中 async 函数、await表达式 的基本用法
- 封装、获取系统用户信息、角色及权限控制
- How to install joiplay emulator rtp
- Homework: iptables prevent nmap scan and binlog
- Error ER_NOT_SUPPORTED_AUTH_MODE Client does not support authentication protocol requested by serv
- XSS相关知识
- WebServer process explanation (registration module)
- 从两个易错的笔试题深入理解自增运算符
猜你喜欢

@requestmapping注解的作用及用法

h264和h265解码上的区别

Error ER_NOT_SUPPORTED_AUTH_MODE Client does not support authentication protocol requested by serv
How to ensure the consistency of database and cache data?

WEB安全基础 - - -漏洞扫描器

xss靶机训练【实现弹窗即成功】

MySQL的grant语句

IOT cross-platform component design scheme

How to solve types joiplay simulator does not support this game

Understand from the 11 common examples of judging equality of packaging types in the written test: packaging types, the principle of automatic boxing and unboxing, the timing of boxing and unboxing, a
随机推荐
【深入浅出玩转FPGA学习13-----------测试用例设计1】
binglog日志追踪:数据备份并备份追踪
ELK部署脚本---亲测可用
mysql索引失效的常见9种原因详解
[In-depth and easy-to-follow FPGA learning 14----------Test case design 2]
How to import game archives in joiplay emulator
joiplay模拟器不支持此游戏类型怎么解决
借助深度估计的点云场景重建
SWM32系列教程6-Systick和PWM
DNS resolution process [visit website]
joiplay模拟器如何调中文
数据库的严格模式
PHP图片添加文字水印
Method for deduplication of object collection
45.【list链表的应用】
【唐宇迪 深度学习-3D点云实战系列】学习笔记
asser利用蚁剑登录
Restricted character bypass
论文理解:“Designing and training of a dual CNN for image denoising“
【深度学习】Transformer模型详解