Data Augmentation for Semi-Supervised Learning
2022-07-01 21:54:00 【跨考上浙大】
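In general, the selected unlabeled images and their pseudo-labels are not fed directly into the network for retraining. Before that, we apply augmentation to the unlabeled images and their pseudo-labels, so that the network can learn additional features and overfits less to noise.
Common augmentation operations:
1: Cropping
Crop directly, or pad to the specified size
2: Flipping
Vertical flip, horizontal flip, or rotation by a given angle
3: Normalization
4: Filtering
5: Resize (pixels filled in by interpolation)
6: Cutout
Cutout is a data augmentation method proposed in 2017. The idea is simple: during training, a randomly chosen region of the image is masked out. It can also be viewed as a dropout-like regularization method.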
Improved Regularization of Convolutional Neural Networks with Cutout
paper: https://arxiv.org/pdf/1708.04552.pdf
code: https://github.com/uoguelph-mlrg/Cutout
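To make the idea concrete, here is a minimal NumPy sketch of the basic form of Cutout as I understand it from the paper: one fixed-size square, centred at a random pixel and clipped at the image border, is set to zero. The function name `cutout_square`, the `length` parameter and the `rng` argument are hypothetical names chosen for this illustration, not taken from the repository above.

```python
import numpy as np


def cutout_square(img, length, rng=None):
    """Return a copy of `img` (H x W or H x W x C) with one square zeroed out.

    Sketch only: the square is centred at a random pixel, so part of it
    may fall outside the image and is clipped at the border.
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w = img.shape[:2]
    cy = int(rng.integers(0, h))  # random centre row
    cx = int(rng.integers(0, w))  # random centre column
    y1, y2 = max(cy - length // 2, 0), min(cy + length // 2, h)
    x1, x2 = max(cx - length // 2, 0), min(cx + length // 2, w)
    out = img.copy()              # leave the input untouched
    out[y1:y2, x1:x2] = 0
    return out


image = np.full((32, 32, 3), 200, dtype=np.uint8)
masked = cutout_square(image, length=16, rng=np.random.default_rng(0))
erased = int((masked[..., 0] == 0).sum())  # number of erased pixels
```

Because the square is clipped at the border, the erased area varies between a quarter of the square and the full square, depending on where the centre lands.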
The operations above are implemented with the PIL library and wrapped in a class for later reuse.
import random

import numpy as np
import torch
from PIL import Image, ImageOps, ImageFilter
from torchvision import transforms


class DataAugmentations:
    """Augmentations applied jointly to a PIL image and its mask."""

    def crop(self, img, mask, size):
        # Pad the right and bottom edges if height or width is smaller
        # than the crop size. Padded mask pixels are set to 255.
        w, h = img.size
        padw = size - w if w < size else 0
        padh = size - h if h < size else 0
        img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
        mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=255)

        # Random square crop, using the same window for image and mask.
        w, h = img.size
        x = random.randint(0, w - size)
        y = random.randint(0, h - size)
        img = img.crop((x, y, x + size, y + size))
        mask = mask.crop((x, y, x + size, y + size))
        return img, mask

    def hflip(self, img, mask, p=0.5):
        # Horizontal flip with probability p.
        if random.random() < p:
            # Values accepted by Image.transpose:
            #   FLIP_LEFT_RIGHT = 0  (horizontal flip)
            #   FLIP_TOP_BOTTOM = 1  (vertical flip)
            #   ROTATE_90 = 2, ROTATE_180 = 3, ROTATE_270 = 4  (rotation)
            #   TRANSPOSE = 5, TRANSVERSE = 6
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        return img, mask

    def normalize(self, img, mask=None):
        """
        :param img: PIL image
        :param mask: PIL image, corresponding mask
        :return: normalized torch tensor of image and mask
        """
        img = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])(img)
        if mask is not None:
            mask = torch.from_numpy(np.array(mask)).long()
            return img, mask
        return img

    def resize(self, img, mask, base_size, ratio_range):
        # Rescale so that the long side has a random length in
        # [base_size * ratio_range[0], base_size * ratio_range[1]],
        # keeping the aspect ratio.
        w, h = img.size
        long_side = random.randint(int(base_size * ratio_range[0]), int(base_size * ratio_range[1]))
        if h > w:
            oh = long_side
            ow = int(1.0 * w * long_side / h + 0.5)
        else:
            ow = long_side
            oh = int(1.0 * h * long_side / w + 0.5)
        # Bilinear interpolation for the image; nearest neighbour for the
        # mask so that no new label values are created.
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        return img, mask

    def blur(self, img, p=0.5):
        # Gaussian blur with probability p; the mask is left unchanged.
        if random.random() < p:
            sigma = np.random.uniform(0.1, 2.0)
            img = img.filter(ImageFilter.GaussianBlur(radius=sigma))
        return img

    def cutout(self, img, mask, p=0.5, size_min=0.02, size_max=0.4, ratio_1=0.3,
               ratio_2=1/0.3, value_min=0, value_max=255, pixel_level=True):
        # Erase a random rectangle with probability p.
        # The image is expected to have a channel axis (H x W x C).
        if random.random() < p:
            img = np.array(img)
            mask = np.array(mask)
            img_h, img_w, img_c = img.shape

            # Sample area and aspect ratio until the rectangle fits in the image.
            while True:
                size = np.random.uniform(size_min, size_max) * img_h * img_w
                ratio = np.random.uniform(ratio_1, ratio_2)
                erase_w = int(np.sqrt(size / ratio))
                erase_h = int(np.sqrt(size * ratio))
                x = np.random.randint(0, img_w)
                y = np.random.randint(0, img_h)
                if x + erase_w <= img_w and y + erase_h <= img_h:
                    break

            if pixel_level:
                # Independent random value for every erased pixel.
                value = np.random.uniform(value_min, value_max, (erase_h, erase_w, img_c))
            else:
                # One random value for the whole rectangle.
                value = np.random.uniform(value_min, value_max)

            img[y:y + erase_h, x:x + erase_w] = value
            mask[y:y + erase_h, x:x + erase_w] = 255  # erased mask pixels are set to 255

            img = Image.fromarray(img.astype(np.uint8))
            mask = Image.fromarray(mask.astype(np.uint8))
        return img, mask


if __name__ == "__main__":
    SDA = DataAugmentations()
    img = Image.open("your path")
    mask = Image.open("your path")
    # p=1.0 forces the flip so that the effect is always visible.
    image1, mask1 = SDA.hflip(img, mask, p=1.0)
    image1.show(title="刘亦菲")
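As a quick check of the padding behaviour in `crop`, the sketch below runs the same pad-then-crop logic on a small synthetic image that is shorter than the crop size. The standalone function `pad_then_crop` and its `ignore_index` parameter are hypothetical names used only for this illustration. Reading 255 as an "ignore" label is an assumption: the source only shows that padded and erased mask pixels are set to 255.

```python
import random

from PIL import Image, ImageOps


def pad_then_crop(img, mask, size, ignore_index=255):
    """Pad right/bottom up to `size`, then take the same random crop from both."""
    w, h = img.size
    padw, padh = max(size - w, 0), max(size - h, 0)
    img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
    # Assumption: ignore_index marks mask pixels that carry no label.
    mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=ignore_index)
    w, h = img.size
    x = random.randint(0, w - size)
    y = random.randint(0, h - size)
    box = (x, y, x + size, y + size)
    return img.crop(box), mask.crop(box)


# Synthetic 40 x 24 image (width x height) with a constant mask label of 1.
img = Image.new("RGB", (40, 24), color=(10, 20, 30))
mask = Image.new("L", (40, 24), color=1)
img_c, mask_c = pad_then_crop(img, mask, size=32)

hist = mask_c.histogram()
print(img_c.size, "labelled:", hist[1], "padded:", hist[255])
```

The image is only 24 pixels high, so 8 rows are added at the bottom before cropping. In the cropped mask those rows hold 255, while the cropped image holds black pixels in the same place.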