Deep Learning 7: Transformer-Based Instance Segmentation with Mask2Former
2022-07-04 13:09:00 【狂奔的CD】
Preface
Main text
Source repository
https://github.com/facebookresearch/Mask2Former
Installation
See https://github.com/facebookresearch/Mask2Former/blob/main/INSTALL.md
conda create --name mask2former python=3.8 -y
conda activate mask2former
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
# Install cudatoolkit from a different channel (anaconda)
conda install cudatoolkit -c anaconda
pip install opencv-python
### Install the detectron2 API
git clone git@github.com:facebookresearch/detectron2.git
cd detectron2
pip install -e .
pip install git+https://github.com/cocodataset/panopticapi.git
pip install git+https://github.com/mcordts/cityscapesScripts.git
cd ..
git clone git@github.com:facebookresearch/Mask2Former.git
cd Mask2Former
pip install -r requirements.txt
cd mask2former/modeling/pixel_decoder/ops
sh make.sh
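Before moving on, it can help to confirm that the packages installed above are visible to the active conda environment. The following is a minimal sketch using only the standard library; the helper name `find_missing` is hypothetical, and the import names in `REQUIRED_MODULES` are assumptions about how the packages above expose themselves (for example, `opencv-python` as `cv2`).

```python
import importlib.util

# Assumed import names for the packages installed in the steps above.
REQUIRED_MODULES = [
    "torch",
    "torchvision",
    "cv2",
    "detectron2",
    "panopticapi",
    "cityscapesscripts",
]


def find_missing(modules):
    """Return the top-level module names that cannot be located in this environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


# An empty list means every module was found; it does not prove CUDA works.
print("missing modules:", find_missing(REQUIRED_MODULES))
```

This only checks that the modules can be located. It does not verify that the CUDA ops compiled by `make.sh` load correctly; running the demo in the next section is the real test.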
Verification (download the corresponding model weights first)
conda activate mask2former
cd Mask2Former/demo
python demo.py --config-file ../configs/coco/panoptic-segmentation/maskformer2_R50_bs16_50ep.yaml --input 1.jpg --output ./output
python demo.py --config-file ../configs/coco/instance-segmentation/swin/maskformer2_swin_tiny_bs16_50ep.yaml --input 2.jpg --output ./tiny --opts MODEL.WEIGHTS "../weights/swin_tiny_patch4_window7_224.pkl"
python demo.py --config-file ../configs/coco/instance-segmentation/swin/maskformer2_swin_large_IN21k_384_bs16_100ep.yaml --input 2.jpg --output ./large --opts MODEL.WEIGHTS "../weights/swin_large_patch4_window12_384_22k.pkl"
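Training
Both training and inference in Mask2Former are built on the detectron2 API. Before training, build your own dataset and register it with detectron2.
Registering a custom dataset
Details: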
https://detectron2.readthedocs.io/en/latest/tutorials/datasets.html
Registration example:
https://colab.research.google.com/drive/16jcaJoc6bCFAQ96jDe2HwtXj7BMD_-m5#scrollTo=PIbAM2pv-urF
import json
import os

import cv2
import numpy as np
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.structures import BoxMode


def get_balloon_dicts(img_dir):
    # Load the VIA annotation file stored next to the images.
    json_file = os.path.join(img_dir, "via_region_data.json")
    with open(json_file) as f:
        imgs_anns = json.load(f)

    dataset_dicts = []
    for idx, v in enumerate(imgs_anns.values()):
        record = {}

        # Read the image once to obtain its size.
        filename = os.path.join(img_dir, v["filename"])
        height, width = cv2.imread(filename).shape[:2]

        record["file_name"] = filename
        record["image_id"] = idx
        record["height"] = height
        record["width"] = width

        annos = v["regions"]
        objs = []
        for _, anno in annos.items():
            assert not anno["region_attributes"]
            anno = anno["shape_attributes"]
            px = anno["all_points_x"]
            py = anno["all_points_y"]
            # Shift to pixel centers, then flatten to [x0, y0, x1, y1, ...].
            poly = [(x + 0.5, y + 0.5) for x, y in zip(px, py)]
            poly = [p for x in poly for p in x]

            obj = {
                "bbox": [np.min(px), np.min(py), np.max(px), np.max(py)],
                "bbox_mode": BoxMode.XYXY_ABS,
                "segmentation": [poly],
                "category_id": 0,
            }
            objs.append(obj)
        record["annotations"] = objs
        dataset_dicts.append(record)
    return dataset_dicts


# Register both splits; d=d binds the current split name inside the lambda.
for d in ["train", "val"]:
    DatasetCatalog.register("balloon_" + d, lambda d=d: get_balloon_dicts("balloon/" + d))
    MetadataCatalog.get("balloon_" + d).set(thing_classes=["balloon"])
balloon_metadata = MetadataCatalog.get("balloon_train")
For a dataset already in COCO format, register it directly through the API:
from detectron2.data.datasets import register_coco_instances
register_coco_instances("my_dataset_train", {}, "json_annotation_train.json", "path/to/image/dir")
register_coco_instances("my_dataset_val", {}, "json_annotation_val.json", "path/to/image/dir")
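Registration itself does not validate the annotation file, so a quick sanity check before training can save a failed run. The sketch below uses only the standard library and assumes the usual COCO layout with top-level `images`, `annotations`, and `categories` lists; the function name `summarize_coco` and the keys of the returned dictionary are hypothetical.

```python
import json


def summarize_coco(json_path):
    """Return basic counts for a COCO-style file plus ids of annotations with dangling references."""
    with open(json_path, encoding="utf-8") as f:
        coco = json.load(f)

    image_ids = {img["id"] for img in coco.get("images", [])}
    category_ids = {cat["id"] for cat in coco.get("categories", [])}
    annotations = coco.get("annotations", [])

    # Annotations that point at an image or category missing from the file.
    unknown_image = [a["id"] for a in annotations if a["image_id"] not in image_ids]
    unknown_category = [a["id"] for a in annotations if a["category_id"] not in category_ids]

    return {
        "num_images": len(image_ids),
        "num_categories": len(category_ids),
        "num_annotations": len(annotations),
        "annotations_with_unknown_image": unknown_image,
        "annotations_with_unknown_category": unknown_category,
    }
```

Calling `summarize_coco("json_annotation_train.json")` on the file registered above should report empty lists for both dangling-reference fields. The `num_categories` value is also the number the model's class count needs to match.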
Specifying the training dataset
_BASE_: ../maskformer2_R50_bs16_50ep.yaml
DATASETS:
  TRAIN: ("my_dataset_train",)
  TEST: ("my_dataset_val",)
MODEL:
  BACKBONE:
    NAME: "D2SwinTransformer"
  SWIN:
    EMBED_DIM: 192
    DEPTHS: [2, 2, 18, 2]
    NUM_HEADS: [6, 12, 24, 48]
    WINDOW_SIZE: 12
    APE: False
    DROP_PATH_RATE: 0.3
    PATCH_NORM: True
    PRETRAIN_IMG_SIZE: 384
  WEIGHTS: "swin_large_patch4_window12_384_22k.pkl"
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.120, 57.375]
  MASK_FORMER:
    NUM_OBJECT_QUERIES: 200
SOLVER:
  STEPS: (655556, 710184)
  MAX_ITER: 737500
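The SOLVER values above appear to be sized for COCO: about 118k training images, 100 epochs, and a batch size of 16 give 118000 × 100 / 16 = 737500 iterations. A custom dataset is usually much smaller, so the schedule likely needs rescaling. The sketch below is one way to do that under stated assumptions: the helper name `scale_schedule` is hypothetical, and the step fractions 8/9 and 26/27 are inferred from the ratios STEPS / MAX_ITER in the config above rather than taken from any documentation.

```python
import math


def scale_schedule(num_images, epochs, ims_per_batch, step_fractions=(8 / 9, 26 / 27)):
    """Derive MAX_ITER and STEPS for a dataset of a given size.

    step_fractions are assumed positions of the learning-rate drops,
    expressed as fractions of the total number of iterations.
    """
    max_iter = math.ceil(num_images * epochs / ims_per_batch)
    steps = tuple(round(max_iter * frac) for frac in step_fractions)
    return max_iter, steps


# Example: 1000 images, 50 epochs, batch size 2.
max_iter, steps = scale_schedule(1000, 50, 2)
print(f"MAX_ITER: {max_iter}")
print(f"STEPS: {steps}")
```

The resulting numbers go into the `SOLVER` section of the custom config. Treat them as a starting point and adjust after looking at the loss curve.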
Start training
cd Mask2Former
python train_net.py --num-gpus 1 --config-file configs/coco/instance-segmentation/swin/maskformer2_swin_large_IN21k_384_bs16_100ep.yaml MODEL.WEIGHTS "weights/swin_large_patch4_window12_384_22k.pkl"
Troubleshooting
1) Not enough GPU memory
RuntimeError: CUDA out of memory. Tried to allocate 410.00 MiB (GPU 0; 10.91 GiB total capacity; 4.24 GiB already allocated; 151.44 MiB free; 4.62 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
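[Solution] Use a smaller model and a smaller batch size by editing the config file. The config files inherit from one another in layers, so check which parameters each level sets.
SOLVER:
  IMS_PER_BATCH: 1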
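The layered inheritance mentioned above can be pictured as a recursive dictionary merge in which the child file wins wherever both define the same key. The following sketch only illustrates that idea; it is not detectron2's actual config loader, `merge_config` is a hypothetical name, and the dictionaries are stand-ins for parsed YAML files.

```python
import copy


def merge_config(base, override):
    """Return a new dict: values in `override` replace those in `base`, recursing into nested dicts."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are sections: merge them key by key.
            merged[key] = merge_config(merged[key], value)
        else:
            # Leaf value (or type mismatch): the child replaces the base.
            merged[key] = copy.deepcopy(value)
    return merged


# Illustrative stand-ins for a base config and a child config.
base_cfg = {"SOLVER": {"IMS_PER_BATCH": 16, "MAX_ITER": 737500}}
child_cfg = {"SOLVER": {"IMS_PER_BATCH": 1}}

effective = merge_config(base_cfg, child_cfg)
print(effective)  # {'SOLVER': {'IMS_PER_BATCH': 1, 'MAX_ITER': 737500}}
```

Only `IMS_PER_BATCH` changes here; `MAX_ITER` is silently inherited from the base. That is why it is worth opening each file in the `_BASE_` chain when a setting does not seem to take effect.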
2) TorchScript runtime error in the matcher
File "/dataset/projects/Mask2Former/mask2former/modeling/matcher.py", line 141, in memory_efficient_forward
cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: Global alloc not supported yet
[Solution] See https://github.com/facebookresearch/Mask2Former/issues/4
Replace batch_dice_loss_jit with batch_dice_loss:
# cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)
cost_dice = batch_dice_loss(out_mask, tgt_mask)
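As far as I understand, the `_jit` variant is just a TorchScript-compiled version of the same function, so switching to the plain one should change how the code runs but not what it computes. For readers curious what that cost is, here is a pure-Python illustration of a pairwise Dice cost between N predicted masks and M target masks. It is a sketch with hypothetical names operating on plain lists, not the repository's tensor code.

```python
import math


def pairwise_dice_cost(pred_logits, target_masks):
    """Compute an N x M matrix of Dice costs.

    pred_logits: N flattened masks of raw logits.
    target_masks: M flattened binary masks (0 or 1).
    cost[i][j] = 1 - (2 * overlap + 1) / (sum(pred_i) + sum(target_j) + 1)
    """
    # Turn logits into probabilities with a sigmoid.
    probs = [[1.0 / (1.0 + math.exp(-x)) for x in row] for row in pred_logits]
    costs = []
    for p in probs:
        row = []
        for t in target_masks:
            numerator = 2.0 * sum(pi * ti for pi, ti in zip(p, t))
            denominator = sum(p) + sum(t)
            # The +1 terms smooth the ratio and avoid division by zero.
            row.append(1.0 - (numerator + 1.0) / (denominator + 1.0))
        costs.append(row)
    return costs
```

A prediction that matches a target exactly gets a cost near 0, and one with no overlap gets a cost close to 1. The matcher uses such a matrix to decide which prediction is paired with which ground-truth mask.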
3) Number of classes in the dataset does not match the model
Adjust the class counts in the config file:
_BASE_: ../maskformer2_R50_bs16_50ep.yaml
MODEL:
  RETINANET:
    NUM_CLASSES: 2
  ROI_HEADS:
    NUM_CLASSES: 2
  SEM_SEG_HEAD:
    NUM_CLASSES: 2
  BACKBONE:
    NAME: "D2SwinTransformer"
  SWIN:
    EMBED_DIM: 96
    DEPTHS: [2, 2, 18, 2]
    NUM_HEADS: [3, 6, 12, 24]
    WINDOW_SIZE: 7
    APE: False
    DROP_PATH_RATE: 0.3
    PATCH_NORM: True
  WEIGHTS: "swin_small_patch4_window7_224.pkl"
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.120, 57.375]
DATASETS:
  TRAIN: ("my_dataset_train",)
  TEST: ("my_dataset_val",)
SOLVER:
  IMS_PER_BATCH: 1
DATALOADER:
  NUM_WORKERS: 1
OUTPUT_DIR: ./output/small_wf_alarm