Data Augmentation for Semi-Supervised Learning
2022-07-01 21:54:00 【跨考上浙大】
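The selected unlabeled images and their pseudo-labels are generally not fed directly into the network for retraining. Before that, the unlabeled images and their pseudo-labels are augmented, so that the network can learn additional features and is less likely to overfit the noise in the pseudo-labels.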
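A minimal NumPy sketch of how an image and its pseudo-label are augmented together: geometric transforms reuse one random draw for both arrays so they stay pixel-aligned, while photometric transforms change only the image. It assumes an H x W x C uint8 image and an H x W array of pseudo-label class ids; the function name `augment_pair` and the brightness jitter are illustrative choices, not part of the original code.

```python
import numpy as np


def augment_pair(img, mask, rng, p_flip=0.5, max_brightness=0.3):
    """Augment an unlabeled image together with its pseudo-label.

    img:  H x W x C uint8 array; mask: H x W array of pseudo-label class ids.
    Geometric ops use ONE random draw for both arrays so they stay aligned;
    photometric ops touch only the image, since they move no pixels.
    """
    # Geometric: a single coin flip decides for image and pseudo-label alike.
    if rng.random() < p_flip:
        img = img[:, ::-1]
        mask = mask[:, ::-1]
    # Photometric: scale brightness by a random factor; the label is unchanged.
    factor = 1.0 + rng.uniform(-max_brightness, max_brightness)
    img = np.clip(img.astype(np.float32) * factor, 0, 255).astype(np.uint8)
    return img, mask


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    img = np.random.default_rng(1).integers(0, 256, (4, 6, 3), dtype=np.uint8)
    mask = np.arange(6)[None, :].repeat(4, axis=0)
    out_img, out_mask = augment_pair(img, mask, rng)
    print(out_img.shape, out_mask[0])
```

Because the flip decision is drawn once, a pixel and its pseudo-label always move to the same location, whichever branch is taken.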
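Common augmentation operations:
1: Cropping
Randomly crop to a specified size, padding first if the image is smaller than that size.
2: Flipping
Vertical flip, horizontal flip, or rotation by a given angle.
3: Normalization
4: Filtering (e.g. Gaussian blur)
5: Resize (new pixel values computed by interpolation)
6: Cutout
Cutout is a data augmentation method proposed in 2017. The idea is simple: during training, a randomly chosen region of the image is masked out. It can also be viewed as a dropout-like regularization method.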
Improved Regularization of Convolutional Neural Networks with Cutout
paper: https://arxiv.org/pdf/1708.04552.pdf
code: https://github.com/uoguelph-mlrg/Cutout
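For reference, the original Cutout masks a square of fixed side length whose centre is sampled uniformly over the image, clips the square at the image border (so it may be only partly visible), and sets it to zero; in the reference code it is applied after normalization, so zero corresponds to the per-channel mean. The NumPy sketch below follows that description; the name `cutout_square` and its parameters are our own, not the repository's API.

```python
import numpy as np


def cutout_square(img, length, rng):
    """Zero out one square patch, following the Cutout paper's description.

    img: H x W or H x W x C array (ideally already normalized).
    length: side of the square. The centre may lie near the border, in which
    case the square is clipped and only partly covers the image.
    """
    h, w = img.shape[:2]
    cy, cx = rng.integers(0, h), rng.integers(0, w)
    # Clip the square's bounds to the image frame.
    y1, y2 = np.clip([cy - length // 2, cy + length // 2], 0, h)
    x1, x2 = np.clip([cx - length // 2, cx + length // 2], 0, w)
    out = img.copy()  # leave the caller's array untouched
    out[y1:y2, x1:x2] = 0
    return out


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = np.ones((32, 32, 3), dtype=np.float32)
    y = cutout_square(x, length=16, rng=rng)
    print(int((y[..., 0] == 0).sum()), "pixels masked")
```

Note that the `cutout` method in the class below is a different, random-erasing-style variant: the rectangle's area and aspect ratio are random, it is filled with random values, and the matching region of the pseudo-label is set to 255.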
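The operations above (except vertical flipping and rotation) are implemented mainly with PIL and wrapped in a class for easy reuse: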
```python
import random

import numpy as np
import torch
from PIL import Image, ImageFilter, ImageOps
from torchvision import transforms


class DataAugmentations:
    """Joint augmentations for an image and its (pseudo-)label mask.

    Geometric ops are applied identically to image and mask. In the mask,
    255 labels padded or erased pixels (commonly used as an ignore index).
    """

    def crop(self, img, mask, size):
        # Pad height or width (right/bottom) if smaller than the crop size.
        w, h = img.size
        padw = size - w if w < size else 0
        padh = size - h if h < size else 0
        img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
        mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=255)
        # Random size x size crop at the same location in image and mask.
        w, h = img.size
        x = random.randint(0, w - size)
        y = random.randint(0, h - size)
        img = img.crop((x, y, x + size, y + size))
        mask = mask.crop((x, y, x + size, y + size))
        return img, mask

    def hflip(self, img, mask, p=0.5):
        # With probability p, flip image and mask left-right together.
        # Other Image.transpose options: FLIP_TOP_BOTTOM (vertical flip),
        # ROTATE_90, ROTATE_180, ROTATE_270, TRANSPOSE, TRANSVERSE.
        if random.random() < p:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        return img, mask

    def normalize(self, img, mask=None):
        """
        :param img: PIL image
        :param mask: PIL image, corresponding mask
        :return: normalized torch tensor of image (and mask as a long tensor)
        """
        img = transforms.Compose([
            transforms.ToTensor(),
            # ImageNet per-channel mean and standard deviation
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])(img)
        if mask is not None:
            mask = torch.from_numpy(np.array(mask)).long()
            return img, mask
        return img

    def resize(self, img, mask, base_size, ratio_range):
        # Rescale so the long side is a random length in
        # [base_size * ratio_range[0], base_size * ratio_range[1]], keeping aspect ratio.
        w, h = img.size
        long_side = random.randint(int(base_size * ratio_range[0]), int(base_size * ratio_range[1]))
        if h > w:
            oh = long_side
            ow = int(1.0 * w * long_side / h + 0.5)
        else:
            ow = long_side
            oh = int(1.0 * h * long_side / w + 0.5)
        # Bilinear for the image; nearest for the mask so no new class ids appear.
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        return img, mask

    def blur(self, img, p=0.5):
        # Gaussian blur with a random radius; photometric, so no mask is needed.
        if random.random() < p:
            sigma = np.random.uniform(0.1, 2.0)
            img = img.filter(ImageFilter.GaussianBlur(radius=sigma))
        return img

    def cutout(self, img, mask, p=0.5, size_min=0.02, size_max=0.4, ratio_1=0.3,
               ratio_2=1/0.3, value_min=0, value_max=255, pixel_level=True):
        # Erase a random rectangle (area fraction in [size_min, size_max],
        # aspect ratio in [ratio_1, ratio_2]) with random values and set the
        # same region of the mask to 255. Expects an H x W x C image.
        if random.random() < p:
            img = np.array(img)
            mask = np.array(mask)
            img_h, img_w, img_c = img.shape
            # Resample until the rectangle fits inside the image.
            while True:
                size = np.random.uniform(size_min, size_max) * img_h * img_w
                ratio = np.random.uniform(ratio_1, ratio_2)
                erase_w = int(np.sqrt(size / ratio))
                erase_h = int(np.sqrt(size * ratio))
                x = np.random.randint(0, img_w)
                y = np.random.randint(0, img_h)
                if x + erase_w <= img_w and y + erase_h <= img_h:
                    break
            if pixel_level:
                value = np.random.uniform(value_min, value_max, (erase_h, erase_w, img_c))
            else:
                value = np.random.uniform(value_min, value_max)
            img[y:y + erase_h, x:x + erase_w] = value
            mask[y:y + erase_h, x:x + erase_w] = 255
            img = Image.fromarray(img.astype(np.uint8))
            mask = Image.fromarray(mask.astype(np.uint8))
        return img, mask


if __name__ == "__main__":
    SDA = DataAugmentations()
    img = Image.open("your path")   # replace with an image path
    mask = Image.open("your path")  # replace with the matching mask path
    image1, mask1 = SDA.hflip(img, mask)
    image1.show(title="Liu Yifei")
```
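Item 2 in the list also mentions rotation, which `DataAugmentations` does not implement. A minimal sketch of a paired rotation, assuming PIL images as in the class: `rotate_pair` and `random_rotate` are hypothetical helpers, and filling the mask with 255 follows the convention the class already uses for padded and erased pixels.

```python
import random

import numpy as np
from PIL import Image


def rotate_pair(img, mask, angle):
    """Rotate image and mask by the same angle (degrees, counter-clockwise).

    The image uses bilinear interpolation; the mask uses nearest-neighbour so
    no in-between class ids are created. Pixels rotated in from outside the
    frame are filled with 0 in the image and 255 in the mask.
    """
    img = img.rotate(angle, resample=Image.BILINEAR, fillcolor=0)
    mask = mask.rotate(angle, resample=Image.NEAREST, fillcolor=255)
    return img, mask


def random_rotate(img, mask, max_angle=30, p=0.5):
    # With probability p, rotate both by an angle drawn from [-max_angle, max_angle].
    if random.random() < p:
        angle = random.uniform(-max_angle, max_angle)
        img, mask = rotate_pair(img, mask, angle)
    return img, mask


if __name__ == "__main__":
    m = np.zeros((64, 64), np.uint8)
    m[:, 20:40] = 1
    m[:, 40:] = 2
    im = Image.fromarray(np.full((64, 64, 3), 128, np.uint8))
    _, rm = random_rotate(im, Image.fromarray(m), p=1.0)
    print(sorted(np.unique(np.array(rm)).tolist()))
```

The output keeps the input size (`expand` is left at its default), so it can be chained with `crop` or `hflip` without changing shapes.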