当前位置:网站首页>【深度学习】AI一键换天

【深度学习】AI一键换天

2022-07-07 23:16:00 InfoQ

1.实验目标

1.了解图像分割的基本应用;
2.了解运动估计的基本应用;
3.了解图像混合的基本应用。

2.案例内容介绍

案例链接
OBS - JupyterLab (huaweicloud.com)
无论是拍人拍景或是其他,“天空”都可以说是摄像中的关键元素。比如,一张平平无奇的景色图加上落日余晖的天空色调,是不是有内味了?(随手就可以变换出各种天空效果:晴天、彩虹、晚霞、暮光、夕阳等等)
当然,自然的天空还不是最酷炫的,今天给大家介绍一款基于原生视频的AI处理方法,不仅可以一键切置换天空背景,还可以打造任意“天空之城”。比如,《星际迷航》等科幻电影中经常出现的浩瀚星空、宇宙飞船,也可以利用这项技术融入随手拍的视频中,路人拍摄的公路片也能秒变科幻片,画面毫无违和感。好像只要脑洞够大,利用这项AI技术,可以创作无限种玩法。
基于视觉的视频天空替换和协调方法,该方法可以在具有可控风格的视频中自动生成逼真的天空背景。与以前的天空编辑方法专注于静态照片或需要集成在智能手机中的惯性测量装置拍摄视频不同,该方法完全基于视觉,对捕获设备没有任何要求,并且可以很好地应用于在线或离线处理场景。
算法流程大致可以分为三个步骤:
(1) 天空抠图
这一步主要是通过对蒙版数据集进行训练,将图片中的天空和其它物体进行像素级的划分,将天空部分从图片中分离。
(2) 运动估计
对图片中物体的位移情况进行分析,预估相机的移动方向,使替换后的天空和之前的天空位移一致。
(3) 图像混合
将去掉天空的原视频和要替换后的天空视频进行融合,同时对非天空的部分采用色彩叠加,使天空和其它物体的视觉效果相近,使视频效果更加逼真。

最后,算法使用数据增强的方法模拟出同一张图片在不同光照和天气的情况下的图片,使算法具有更强的适应性。

3.实验步骤

3.1安装和导入依赖包

import os
import moxing as mox
 
file_name = 'SkyAR'
if not os.path.exists(file_name):
 mox.file.copy('obs://modelarts-labs-bj4-v2/case_zoo/SkyAR/SkyAR.zip', 'SkyAR.zip')
 os.system('unzip SkyAR.zip')
 os.system('rm SkyAR.zip')
mox.file.copy_parallel('obs://modelarts-labs-bj4-v2/case_zoo/SkyAR/resnet50-19c8e357.pth', '/home/ma-user/.cache/torch/checkpoints/resnet50-19c8e357.pth')

!pip uninstall opencv-python -y
!pip uninstall opencv-contrib-python -y

!pip install opencv-contrib-python==4.5.3.56

cd SkyAR/
import time
import json
import base64
import numpy as np
import matplotlib.pyplot as plt
import cv2
import argparse
from networks import *
from skyboxengine import *
import utils
import torch
from IPython.display import clear_output, Image, display, HTML
 
%matplotlib inline
 
# 如果存在GPU则在GPU上面运行
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

3.2设定算法参数 

SkyAR算法提供了以下五个参数来调整换天的效果:
skybox_center_crop
: 天空体中心偏移
auto_light_matching
: 是否自动亮度匹配
relighting_factor
: 补光
recoloring_factor
: 重新着色
halo_effect
: 是否开启光环效应
且提供了 
datadir
 和 
skybox
 两个参数来指定待处理的原视频和要替换的天空图片,通过路径进行指定即可,如下所示:
parameter = {
 "net_G": "coord_resnet50",
 "ckptdir": "./checkpoints_G_coord_resnet50",
 
 "input_mode": "video",
 "datadir": "./test_videos/sky.mp4", # 待处理的原视频路径
 "skybox": "sky.jpg", # 要替换的天空图片路径
 
 "in_size_w": 384,
 "in_size_h": 384,
 "out_size_w": 845,
 "out_size_h": 480,
 
 "skybox_center_crop": 0.5,
 "auto_light_matching": False,
 "relighting_factor": 0.8,
 "recoloring_factor": 0.5,
 "halo_effect": True,
 
 "output_dir": "./jpg_output",
 "save_jpgs": False
}
 
str_json = json.dumps(parameter)

3.3预览一下原视频

video_name = parameter['datadir']
 
def arrayShow(img):
 img = cv2.resize(img, (0, 0), fx=0.25, fy=0.25, interpolation=cv2.INTER_NEAREST)
 _,ret = cv2.imencode('.jpg', img)
 return Image(data=ret)
 
# 打开一个视频流
cap = cv2.VideoCapture(video_name)
 
frame_id = 0
while True:
 try:
 clear_output(wait=True) # 清除之前的显示
 ret, frame = cap.read() # 读取一帧图片
 if ret:
 frame_id += 1
 if frame_id > 200:
 break
 cv2.putText(frame, str(frame_id), (5, 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1) # 画frame_id
 tmp = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # 转换色彩模式
 img = arrayShow(frame)
 display(img) # 显示图片
 time.sleep(0.05) # 线程睡眠一段时间再处理下一帧图片
 else:
 break
 except KeyboardInterrupt:
 cap.release()
cap.release()

预览一下要替换的天空图片

img= cv2.imread(os.path.join('./skybox', parameter['skybox']))
img2 = img[:, :, ::-1]
plt.imshow(img2)

3.4定义SkyFilter类 

class Struct:
 def __init__(self, **entries):
 self.__dict__.update(entries)
 
def parse_config():
 data = json.loads(str_json)
 args = Struct(**data)
 
 return args
 
args = parse_config()

class SkyFilter():
 
 def __init__(self, args):
 
 self.ckptdir = args.ckptdir
 self.datadir = args.datadir
 self.input_mode = args.input_mode
 
 self.in_size_w, self.in_size_h = args.in_size_w, args.in_size_h
 self.out_size_w, self.out_size_h = args.out_size_w, args.out_size_h
 
 self.skyboxengine = SkyBox(args)
 
 self.net_G = define_G(input_nc=3, output_nc=1, ngf=64, netG=args.net_G).to(device)
 self.load_model()
 
 self.video_writer = cv2.VideoWriter('out.avi',
 cv2.VideoWriter_fourcc(*'MJPG'),
 20.0,
 (args.out_size_w, args.out_size_h))
 self.video_writer_cat = cv2.VideoWriter('compare.avi',
 cv2.VideoWriter_fourcc(*'MJPG'),
 20.0,
 (2*args.out_size_w, args.out_size_h))
 
 if os.path.exists(args.output_dir) is False:
 os.mkdir(args.output_dir)
 
 self.output_img_list = []
 
 self.save_jpgs = args.save_jpgs
 
 
 def load_model(self):
 # 加载预训练的天空抠图模型
 print('loading the best checkpoint...')
 checkpoint = torch.load(os.path.join(self.ckptdir, 'best_ckpt.pt'),
 map_location=device)
 self.net_G.load_state_dict(checkpoint['model_G_state_dict'])
 self.net_G.to(device)
 self.net_G.eval()
 
 
 def write_video(self, img_HD, syneth):
 
 frame = np.array(255.0 * syneth[:, :, ::-1], dtype=np.uint8)
 self.video_writer.write(frame)
 
 frame_cat = np.concatenate([img_HD, syneth], axis=1)
 frame_cat = np.array(255.0 * frame_cat[:, :, ::-1], dtype=np.uint8)
 self.video_writer_cat.write(frame_cat)
 
 # 定义结果缓冲区
 self.output_img_list.append(frame_cat)
 
 
 def synthesize(self, img_HD, img_HD_prev):
 
 h, w, c = img_HD.shape
 
 img = cv2.resize(img_HD, (self.in_size_w, self.in_size_h))
 
 img = np.array(img, dtype=np.float32)
 img = torch.tensor(img).permute([2, 0, 1]).unsqueeze(0)
 
 with torch.no_grad():
 G_pred = self.net_G(img.to(device))
 G_pred = torch.nn.functional.interpolate(G_pred,
 (h, w),
 mode='bicubic',
 align_corners=False)
 G_pred = G_pred[0, :].permute([1, 2, 0])
 G_pred = torch.cat([G_pred, G_pred, G_pred], dim=-1)
 G_pred = np.array(G_pred.detach().cpu())
 G_pred = np.clip(G_pred, a_max=1.0, a_min=0.0)
 
 skymask = self.skyboxengine.skymask_refinement(G_pred, img_HD)
 
 syneth = self.skyboxengine.skyblend(img_HD, img_HD_prev, skymask)
 
 return syneth, G_pred, skymask
 
 
 def cvtcolor_and_resize(self, img_HD):
 
 img_HD = cv2.cvtColor(img_HD, cv2.COLOR_BGR2RGB)
 img_HD = np.array(img_HD / 255., dtype=np.float32)
 img_HD = cv2.resize(img_HD, (self.out_size_w, self.out_size_h))
 
 return img_HD
 
 
 def process_video(self):
 # 逐帧处理视频
 cap = cv2.VideoCapture(self.datadir)
 m_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
 img_HD_prev = None
 
 for idx in range(m_frames):
 ret, frame = cap.read()
 if ret:
 img_HD = self.cvtcolor_and_resize(frame)
 
 if img_HD_prev is None:
 img_HD_prev = img_HD
 
 syneth, G_pred, skymask = self.synthesize(img_HD, img_HD_prev)
 
 self.write_video(img_HD, syneth)
 
 img_HD_prev = img_HD
 
 if (idx + 1) % 50 == 0:
 print(f'processing video, frame {idx + 1} / {m_frames} ... ')
 
 else: # 如果到达最后一帧
 break

3.5开始处理视频 

sf = SkyFilter(args)
sf.process_video()

3.6对比原视频和处理后的视频

video_name = "compare.avi"
 
def arrayShow(img):
 _,ret = cv2.imencode('.jpg', img)
 return Image(data=ret)
 
# 打开一个视频流
cap = cv2.VideoCapture(video_name)
 
frame_id = 0
while True:
 try:
 clear_output(wait=True) # 清除之前的显示
 ret, frame = cap.read() # 读取一帧图片
 if ret:
 frame_id += 1
 cv2.putText(frame, str(frame_id), (5, 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1) # 画frame_id
 tmp = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # 转换色彩模式
 img = arrayShow(frame)
 display(img) # 显示图片
 time.sleep(0.05) # 线程睡眠一段时间再处理下一帧图片
 else:
 break
 except KeyboardInterrupt:
 cap.release()
cap.release()

3.7 生成你自己的换天视频

三个步骤实现自定义视频的换天效果:
(1)在自己本地电脑上准备好一个待处理的mp4视频文件和一张天空图片;
(2)参考
此文档
,将视频文件和图片文件分别上传到ModelArts JupyterLab的SkyAR/test_videos目录和SkyAR/skybox目录下;
(3)修改步骤2 “设定算法参数” 中
datadir
 和 
skybox
 两个参数的路径为你刚上传的视频和图片路径;
(4)重新运行步骤2~6。
原网站

版权声明
本文为[InfoQ]所创,转载请带上原文链接,感谢
https://xie.infoq.cn/article/885d15be6dabede8b65b6ea5e