Deep Learning 7: Transformer Series, Instance Segmentation with Mask2Former
2022-07-04 14:39:00 【Racing CD】
Source code
https://github.com/facebookresearch/Mask2Former
Installation
Reference: https://github.com/facebookresearch/Mask2Former/blob/main/INSTALL.md
conda create --name mask2former python=3.8 -y
conda activate mask2former
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
# Install cudatoolkit from a different channel (anaconda) here
conda install cudatoolkit -c anaconda
pip install opencv-python
### Install the detectron2 API
git clone git@github.com:facebookresearch/detectron2.git
cd detectron2
pip install -e .
pip install git+https://github.com/cocodataset/panopticapi.git
pip install git+https://github.com/mcordts/cityscapesScripts.git
cd ..
git clone git@github.com:facebookresearch/Mask2Former.git
cd Mask2Former
pip install -r requirements.txt
cd mask2former/modeling/pixel_decoder/ops
sh make.sh
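Before running the demo, it can help to confirm that the packages installed above are visible to the active conda environment. The following is a minimal sketch using only the standard library; the helper name check_packages is hypothetical, and the list of package names is an assumption based on the install commands above.

```python
from importlib.util import find_spec


def check_packages(names):
    """Return {name: bool} telling whether each top-level package can be found."""
    return {name: find_spec(name) is not None for name in names}


# Top-level import names assumed from the install steps above.
required = ["torch", "torchvision", "cv2", "detectron2"]
for name, found in check_packages(required).items():
    print(f"{name:12s} {'ok' if found else 'MISSING'}")
```

This only checks that the packages can be located. It does not verify that the CUDA build matches the driver or that the ops compiled by make.sh load correctly.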
Verification (download the corresponding model weights first)
conda activate mask2former
cd Mask2Former/demo
python demo.py --config-file ../configs/coco/panoptic-segmentation/maskformer2_R50_bs16_50ep.yaml --input 1.jpg --output ./output
python demo.py --config-file ../configs/coco/instance-segmentation/swin/maskformer2_swin_tiny_bs16_50ep.yaml --input 2.jpg --output ./tiny --opts MODEL.WEIGHTS "../weights/swin_tiny_patch4_window7_224.pkl"
python demo.py --config-file ../configs/coco/instance-segmentation/swin/maskformer2_swin_large_IN21k_384_bs16_100ep.yaml --input 2.jpg --output ./large --opts MODEL.WEIGHTS "../weights/swin_large_patch4_window12_384_22k.pkl"
Training
All Mask2Former training and inference is built on the detectron2 API. Before training, you need to build your own dataset and register it with detectron2.
Register a custom dataset
Documentation:
https://detectron2.readthedocs.io/en/latest/tutorials/datasets.html
Registration example:
https://colab.research.google.com/drive/16jcaJoc6bCFAQ96jDe2HwtXj7BMD_-m5#scrollTo=PIbAM2pv-urF
import json
import os

import cv2
import numpy as np
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.structures import BoxMode


def get_balloon_dicts(img_dir):
    # Load the VIA annotation file stored next to the images.
    json_file = os.path.join(img_dir, "via_region_data.json")
    with open(json_file) as f:
        imgs_anns = json.load(f)

    dataset_dicts = []
    for idx, v in enumerate(imgs_anns.values()):
        record = {}

        # Read the image once to obtain its height and width.
        filename = os.path.join(img_dir, v["filename"])
        height, width = cv2.imread(filename).shape[:2]

        record["file_name"] = filename
        record["image_id"] = idx
        record["height"] = height
        record["width"] = width

        annos = v["regions"]
        objs = []
        for _, anno in annos.items():
            assert not anno["region_attributes"]
            anno = anno["shape_attributes"]
            px = anno["all_points_x"]
            py = anno["all_points_y"]
            # Shift vertices to pixel centers, then flatten to [x1, y1, x2, y2, ...].
            poly = [(x + 0.5, y + 0.5) for x, y in zip(px, py)]
            poly = [p for x in poly for p in x]

            obj = {
                "bbox": [np.min(px), np.min(py), np.max(px), np.max(py)],
                "bbox_mode": BoxMode.XYXY_ABS,
                "segmentation": [poly],
                "category_id": 0,
            }
            objs.append(obj)
        record["annotations"] = objs
        dataset_dicts.append(record)
    return dataset_dicts


# Register the train and val splits; d=d binds the split name at definition time.
for d in ["train", "val"]:
    DatasetCatalog.register("balloon_" + d, lambda d=d: get_balloon_dicts("balloon/" + d))
    MetadataCatalog.get("balloon_" + d).set(thing_classes=["balloon"])
balloon_metadata = MetadataCatalog.get("balloon_train")
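The core of the registration function is the conversion of one polygon into a bounding box and a flat coordinate list. The following is a dependency-free sketch of that step alone, so it can be checked without images or detectron2. The helper name polygon_to_annotation is hypothetical, and bbox_mode is left out because it requires detectron2's BoxMode.

```python
def polygon_to_annotation(px, py, category_id=0):
    """Convert polygon vertex lists into a bbox and a flat segmentation list."""
    if len(px) != len(py) or len(px) < 3:
        raise ValueError("a polygon needs at least three (x, y) vertices")
    # Shift by 0.5 so that integer pixel indices refer to pixel centers.
    points = [(x + 0.5, y + 0.5) for x, y in zip(px, py)]
    # Flatten [(x1, y1), (x2, y2), ...] into [x1, y1, x2, y2, ...].
    flat = [coord for point in points for coord in point]
    return {
        "bbox": [min(px), min(py), max(px), max(py)],  # XYXY, absolute pixels
        "segmentation": [flat],
        "category_id": category_id,
    }


ann = polygon_to_annotation([1, 3, 2], [4, 6, 5])
print(ann["bbox"])
print(ann["segmentation"][0])
```

Note that the bounding box is taken from the unshifted vertices while the polygon uses the shifted ones, which mirrors the function above.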
For datasets already in COCO format, call the registration API directly:
from detectron2.data.datasets import register_coco_instances
register_coco_instances("my_dataset_train", {}, "json_annotation_train.json", "path/to/image/dir")
register_coco_instances("my_dataset_val", {}, "json_annotation_val.json", "path/to/image/dir")
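For reference, the JSON file passed to the registration call follows the COCO layout with three top-level lists: images, annotations and categories. The following is a minimal sketch that builds such a structure with the standard library. The helper names, file names and the rectangular placeholder polygon are assumptions for illustration, not part of the original setup.

```python
import json


def make_coco_skeleton(image_sizes, class_names):
    """image_sizes maps file_name -> (width, height); returns a COCO-style dict."""
    images = [
        {"id": i, "file_name": name, "width": w, "height": h}
        for i, (name, (w, h)) in enumerate(sorted(image_sizes.items()), start=1)
    ]
    categories = [{"id": i, "name": n} for i, n in enumerate(class_names, start=1)]
    return {"images": images, "annotations": [], "categories": categories}


def add_box(coco, image_id, category_id, x, y, w, h):
    """Append one annotation; COCO boxes are [x, y, width, height]."""
    coco["annotations"].append({
        "id": len(coco["annotations"]) + 1,
        "image_id": image_id,
        "category_id": category_id,
        "bbox": [x, y, w, h],
        "area": w * h,
        "iscrowd": 0,
        # Placeholder: the box outline as a polygon; real data carries the object mask.
        "segmentation": [[x, y, x + w, y, x + w, y + h, x, y + h]],
    })


coco = make_coco_skeleton({"0001.jpg": (640, 480)}, ["balloon"])
add_box(coco, image_id=1, category_id=1, x=10, y=20, w=100, h=50)
print(json.dumps(coco, indent=2))
```

For instance segmentation the segmentation field needs real object outlines; the rectangle here only keeps the sketch self-contained.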
Specify the training dataset
_BASE_: ../maskformer2_R50_bs16_50ep.yaml
DATASETS:
  TRAIN: ("my_dataset_train",)
  TEST: ("my_dataset_val",)
MODEL:
  BACKBONE:
    NAME: "D2SwinTransformer"
  SWIN:
    EMBED_DIM: 192
    DEPTHS: [2, 2, 18, 2]
    NUM_HEADS: [6, 12, 24, 48]
    WINDOW_SIZE: 12
    APE: False
    DROP_PATH_RATE: 0.3
    PATCH_NORM: True
    PRETRAIN_IMG_SIZE: 384
  WEIGHTS: "swin_large_patch4_window12_384_22k.pkl"
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.120, 57.375]
  MASK_FORMER:
    NUM_OBJECT_QUERIES: 200
SOLVER:
  STEPS: (655556, 710184)
  MAX_ITER: 737500
Start training
cd Mask2Former
python train_net.py --num-gpus 1 --config-file configs/coco/instance-segmentation/swin/maskformer2_swin_large_IN21k_384_bs16_100ep.yaml MODEL.WEIGHTS "weights/swin_large_patch4_window12_384_22k.pkl"
Troubleshooting
1) Out of GPU memory
RuntimeError: CUDA out of memory. Tried to allocate 410.00 MiB (GPU 0; 10.91 GiB total capacity; 4.24 GiB already allocated; 151.44 MiB free; 4.62 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
[Solution] Use a smaller model and a smaller batch size, both set in the configuration file. The configuration files inherit from one another in layers through _BASE_, so check which value is set at each level.
SOLVER:
  IMS_PER_BATCH: 1
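To see why the layering of configuration files matters, the following sketch shows the general idea of a child config overriding a base config key by key. It is a simplified illustration with plain dictionaries, not detectron2's actual config loader; the function name merge_config and the example values are assumptions.

```python
import copy


def merge_config(base, override):
    """Recursively merge override into a copy of base; override wins on conflicts."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are sections: merge them instead of replacing the section.
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Hypothetical base values, for illustration only.
base = {"SOLVER": {"IMS_PER_BATCH": 16, "MAX_ITER": 368750}}
child = {"SOLVER": {"IMS_PER_BATCH": 1}}
print(merge_config(base, child))
```

A key that the child does not mention keeps the value from the base file, which is why a batch size set only in a parent config can still apply after the child file is edited.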
2) TorchScript error in the matcher
File "/dataset/projects/Mask2Former/mask2former/modeling/matcher.py", line 141, in memory_efficient_forward
cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: Global alloc not supported yet
[Solution] See https://github.com/facebookresearch/Mask2Former/issues/4
Replace batch_dice_loss_jit with batch_dice_loss:
# cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)
cost_dice = batch_dice_loss(out_mask, tgt_mask)
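The swap only changes whether the function runs through TorchScript, so the computed cost should be the same. For orientation, here is a pure-Python sketch of a pairwise Dice cost between N predicted masks and M target masks. It assumes the common smoothed Dice form with a +1 term in numerator and denominator and works on nested lists, whereas the repository's function operates on torch tensors; the names below are hypothetical.

```python
import math


def _sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def pairwise_dice_cost(logits, targets):
    """Return an N x M matrix of Dice costs.

    logits:  N rows of per-pixel mask logits.
    targets: M rows of per-pixel 0/1 ground-truth values.
    """
    probs = [[_sigmoid(v) for v in row] for row in logits]
    cost = []
    for p in probs:
        row = []
        for t in targets:
            numerator = 2.0 * sum(pi * ti for pi, ti in zip(p, t))
            denominator = sum(p) + sum(t)
            # Smoothed Dice: cost approaches 0 when prediction and target overlap fully.
            row.append(1.0 - (numerator + 1.0) / (denominator + 1.0))
        cost.append(row)
    return cost


cost = pairwise_dice_cost([[100.0, -100.0]], [[1, 0], [0, 1]])
print(cost)
```

The matcher uses such a cost matrix to pair each prediction with the target it overlaps best, so a low value means a good match.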
3) The number of classes in the dataset does not match the model
Set the number of classes in the configuration file:
_BASE_: ../maskformer2_R50_bs16_50ep.yaml
MODEL:
  RETINANET:
    NUM_CLASSES: 2
  ROI_HEADS:
    NUM_CLASSES: 2
  SEM_SEG_HEAD:
    NUM_CLASSES: 2
  BACKBONE:
    NAME: "D2SwinTransformer"
  SWIN:
    EMBED_DIM: 96
    DEPTHS: [2, 2, 18, 2]
    NUM_HEADS: [3, 6, 12, 24]
    WINDOW_SIZE: 7
    APE: False
    DROP_PATH_RATE: 0.3
    PATCH_NORM: True
  WEIGHTS: "swin_small_patch4_window7_224.pkl"
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.120, 57.375]
DATASETS:
  TRAIN: ("my_dataset_train",)
  TEST: ("my_dataset_val",)
SOLVER:
  IMS_PER_BATCH: 1
DATALOADER:
  NUM_WORKERS: 1
OUTPUT_DIR: ./output/small_wf_alarm
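The class count for the configuration can be read from the annotation data instead of being typed by hand. The following sketch counts the categories declared in a COCO-style dictionary and checks that every annotation refers to a declared one. The helper name count_classes and the sample data are hypothetical.

```python
def count_classes(coco):
    """Return the number of declared categories in a COCO-style dict."""
    declared = {c["id"] for c in coco.get("categories", [])}
    used = {a["category_id"] for a in coco.get("annotations", [])}
    unknown = used - declared
    if unknown:
        raise ValueError(
            f"annotations reference undeclared category ids: {sorted(unknown)}"
        )
    return len(declared)


# Sample data for illustration; in practice, load the annotation file with json.load.
sample = {
    "categories": [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}],
    "annotations": [{"category_id": 1}, {"category_id": 2}, {"category_id": 1}],
}
print(count_classes(sample))
```

If the returned number differs from NUM_CLASSES in the configuration, training is likely to hit the mismatch described in this item.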