Anomaly-Transformer (ICLR 2022 Spotlight): Reproduction Process and Issues
2022-07-01 22:56:00 【理心炼丹】
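The authors recommend Python 3.6 and PyTorch 1.4.
1. Environment changes
I first installed PyTorch 1.4, but the code hung without raising any error. I traced the hang to this line in Anomaly-Transformer/model/attn.py: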
self.distances = torch.zeros((window_size, window_size)).cuda()
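Why .cuda() hangs: the installed PyTorch 1.4 is built against CUDA 10.x, while my GPU's compute capability is sm_86. CUDA 10.x supports at most sm_75, so CUDA 11.x is required for sm_8x GPUs.
I therefore upgraded my environment to Python 3.7 and PyTorch 1.12 (GPU: 3080 Ti, CUDA version: 11.3):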
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
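A quick way to confirm that an installed build matches the GPU is to compare the device's compute capability with the architectures the build was compiled for. The following is a minimal sketch, not part of the Anomaly-Transformer repository: `arch_tag` and `build_supports_gpu` are hypothetical helper names, and `torch.cuda.get_arch_list()` is assumed to be available, which holds for recent releases such as 1.12 but may not for very old ones.

```python
def arch_tag(capability):
    """Format a (major, minor) compute capability as an 'sm_XY' tag."""
    major, minor = capability
    return f"sm_{major}{minor}"


def build_supports_gpu(capability, arch_list):
    """Return True if the arch list contains native kernels for this GPU."""
    return arch_tag(capability) in arch_list


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        torch = None

    if torch is not None and torch.cuda.is_available():
        capability = torch.cuda.get_device_capability(0)
        arch_list = torch.cuda.get_arch_list()
        print("PyTorch:", torch.__version__, "| CUDA build:", torch.version.cuda)
        print("GPU:", arch_tag(capability), "| build targets:", arch_list)
        print("native support:", build_supports_gpu(capability, arch_list))
    else:
        print("A PyTorch install with a visible CUDA device is needed for the live check")
```

If the GPU's tag is missing from the build targets, as with sm_86 on a CUDA 10.x build, installing a build for a newer CUDA toolkit is the likely fix.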
Running the training script again produced a new error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [512, 25]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
Fix: comment out the first .step() in Anomaly-Transformer/solver.py:
# Minimax strategy
loss1.backward(retain_graph=True)
# self.optimizer.step()
loss2.backward()
self.optimizer.step()
Reference: Why the optimizer.step() write twice? · Issue #8 · thuml/Anomaly-Transformer · GitHub
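The error arises because optimizer.step() modifies the weights in place while the retained graph still needs their earlier values for the second backward pass. Below is a minimal sketch of the working order on a toy model. The network, data, and both losses are placeholders chosen for illustration, not the ones used in solver.py. With a single step, the gradients of both losses are accumulated before the update, which may differ from the update schedule the authors intended.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy model and data, standing in for the real network
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
x = torch.randn(16, 4)

# Snapshot of the weights, used only to confirm that the update happened
before = [p.detach().clone() for p in model.parameters()]

output = model(x)
loss1 = ((output - x) ** 2).mean()  # placeholder for the first loss
loss2 = (output - x).abs().mean()   # placeholder for the second loss

optimizer.zero_grad()
loss1.backward(retain_graph=True)   # keep the graph for the second backward
# No optimizer.step() here: it would change the weights in place while
# the retained graph still refers to their previous versions.
loss2.backward()                    # gradients add to those from loss1
optimizer.step()                    # one update using the accumulated gradients

changed = any(
    not torch.equal(b, p.detach()) for b, p in zip(before, model.parameters())
)
print("parameters updated:", changed)
```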
2. Congratulations! It runs successfully!
python main.py --anormly_ratio 1 --num_epochs 3 --batch_size 128 --mode train --dataset PSM --data_path dataset/PSM --input_c 25 --output_c 25
------------ Options -------------
anormly_ratio: 1.0
batch_size: 128
data_path: dataset/PSM
dataset: PSM
input_c: 25
k: 3
lr: 0.0001
mode: train
model_save_path: checkpoints
num_epochs: 3
output_c: 25
pretrained_model: None
win_size: 100
======================TEST MODE======================
/opt/conda/lib/python3.7/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.
warnings.warn(warning.format(ret))
Threshold : 0.002150955616962149
pred: (87800,)
gt: (87800,)
pred: (87800,)
gt: (87800,)
Accuracy : 0.9848, Precision : 0.9713, Recall : 0.9739, F-score : 0.9726
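Results reported in the paper for the PSM dataset:
P: 96.91, R: 98.9, F1: 97.89
The reproduced recall is slightly lower (97.39 vs. 98.9), while precision is slightly higher (97.13 vs. 96.91). The two metrics trade off against each other, and adjusting the threshold above (Threshold : 0.002150955616962149) shifts the balance between them.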
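To illustrate how the threshold moves precision and recall in opposite directions, here is a small sketch with made-up scores and labels. It assumes a point is flagged as anomalous when its score exceeds the threshold, and it ignores any post-processing the repository may apply before computing its metrics. `precision_recall_f1` is a hypothetical helper name, not a function from the repository.

```python
def precision_recall_f1(scores, labels, threshold):
    """Flag scores above the threshold as anomalies and compare with labels."""
    pred = [1 if s > threshold else 0 for s in scores]
    tp = sum(1 for p, g in zip(pred, labels) if p == 1 and g == 1)
    fp = sum(1 for p, g in zip(pred, labels) if p == 1 and g == 0)
    fn = sum(1 for p, g in zip(pred, labels) if p == 0 and g == 1)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return precision, recall, f1


# Made-up anomaly scores and ground-truth labels, for illustration only
scores = [0.001, 0.004, 0.0015, 0.003, 0.0005, 0.0025]
labels = [0, 1, 0, 1, 0, 0]

for threshold in (0.001, 0.002, 0.0035):
    p, r, f = precision_recall_f1(scores, labels, threshold)
    print(f"threshold={threshold}: P={p:.2f} R={r:.2f} F1={f:.2f}")
```

On this toy data, raising the threshold increases precision and lowers recall, which is the same trade-off seen between the reproduced numbers and the paper's.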