当前位置:网站首页>【深度学习】AI一键换天
【深度学习】AI一键换天
2022-07-07 23:16:00 【InfoQ】
1.实验目标
2.案例内容介绍
3.实验步骤
3.1安装和导入依赖包
import os
import moxing as mox
# Fetch the SkyAR code/data bundle from OBS, but only when the unpacked
# directory is not already present in the working directory.
file_name = 'SkyAR'
if not os.path.exists(file_name):
    mox.file.copy('obs://modelarts-labs-bj4-v2/case_zoo/SkyAR/SkyAR.zip', 'SkyAR.zip')
    # Stop on a failed extraction instead of silently deleting the archive
    # and carrying on with a missing or partial SkyAR directory.
    if os.system('unzip SkyAR.zip') != 0:
        raise RuntimeError('failed to unzip SkyAR.zip')
    os.remove('SkyAR.zip')

# Pre-seed the torch checkpoint cache with the ResNet-50 weights.
# NOTE(review): the scraped source lost its indentation, so it is unclear
# whether this copy was originally nested under the check above; guarding it
# on its own target path yields the same end state under either reading.
resnet_ckpt = '/home/ma-user/.cache/torch/checkpoints/resnet50-19c8e357.pth'
if not os.path.exists(resnet_ckpt):
    mox.file.copy_parallel('obs://modelarts-labs-bj4-v2/case_zoo/SkyAR/resnet50-19c8e357.pth', resnet_ckpt)
# Remove both OpenCV packages, then install a pinned opencv-contrib-python release.
!pip uninstall opencv-python -y
!pip uninstall opencv-contrib-python -y
!pip install opencv-contrib-python==4.5.3.56
# Enter the SkyAR directory so the relative paths used below resolve.
# NOTE(review): the bare `cd` presumably relies on IPython automagic — confirm it is enabled.
cd SkyAR/
import time
import json
import base64
import numpy as np
import matplotlib.pyplot as plt
import cv2
import argparse
from networks import *
from skyboxengine import *
import utils
import torch
from IPython.display import clear_output, Image, display, HTML
%matplotlib inline
# Run on the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
3.2设定算法参数
# Settings for the sky-replacement run. The dict is serialized to JSON below;
# parse_config() later loads that JSON back into an attribute-style object.
parameter = {
    "net_G": "coord_resnet50",  # forwarded to define_G(netG=...) when the network is built
    "ckptdir": "./checkpoints_G_coord_resnet50",  # directory that holds best_ckpt.pt
    "input_mode": "video",
    "datadir": "./test_videos/sky.mp4",  # path of the source video to process
    "skybox": "sky.jpg",  # path of the replacement sky image
    "in_size_w": 384,  # frames are resized to in_size_w x in_size_h before entering the network
    "in_size_h": 384,
    "out_size_w": 845,  # frames are resized to out_size_w x out_size_h for the output videos
    "out_size_h": 480,
    # The next five options are not read in this file; presumably they are
    # consumed by SkyBox(args) — confirm against skyboxengine.
    "skybox_center_crop": 0.5,
    "auto_light_matching": False,
    "relighting_factor": 0.8,
    "recoloring_factor": 0.5,
    "halo_effect": True,
    "output_dir": "./jpg_output",  # created by SkyFilter if it does not exist
    "save_jpgs": False
}
# JSON form of the settings, read back by parse_config().
str_json = json.dumps(parameter)
3.3预览一下原视频
video_name = parameter['datadir']
def arrayShow(img):
    """Return *img* as an inline-displayable JPEG at one quarter of its size.

    Args:
        img: image array accepted by ``cv2.resize`` / ``cv2.imencode``
            (the preview loop passes frames read from ``cv2.VideoCapture``).

    Returns:
        An ``IPython.display.Image`` wrapping the JPEG-encoded bytes.

    Raises:
        ValueError: if OpenCV reports that the JPEG encoding failed.
    """
    # Nearest-neighbour downscale keeps the per-frame preview cheap.
    img = cv2.resize(img, (0, 0), fx=0.25, fy=0.25, interpolation=cv2.INTER_NEAREST)
    ok, buf = cv2.imencode('.jpg', img)
    if not ok:
        raise ValueError('cv2.imencode failed to encode the frame as JPEG')
    # Hand IPython real bytes so it can sniff the JPEG header; a raw ndarray
    # skips that detection and the image is labelled with the default format.
    return Image(data=buf.tobytes())
# Open the source video and preview at most its first 200 frames inline.
cap = cv2.VideoCapture(video_name)
frame_id = 0
try:
    while True:
        clear_output(wait=True)  # drop the previous frame once the next one is shown
        ret, frame = cap.read()  # read one frame
        if not ret:  # end of stream or read failure
            break
        frame_id += 1
        if frame_id > 200:
            break
        # Stamp the frame number in the top-left corner.
        cv2.putText(frame, str(frame_id), (5, 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
        img = arrayShow(frame)
        display(img)  # show the frame
        time.sleep(0.05)  # pause briefly before handling the next frame
except KeyboardInterrupt:
    # Interrupting the cell simply ends the preview.
    pass
finally:
    # Always free the capture, whichever way the loop ended.
    cap.release()
预览一下要替换的天空图片
# Preview the sky image that will replace the original sky.
sky_path = os.path.join('./skybox', parameter['skybox'])
img = cv2.imread(sky_path)
if img is None:
    # cv2.imread signals an unreadable or missing file by returning None;
    # fail here with a clear message instead of a TypeError on the slice below.
    raise FileNotFoundError(f'cannot read sky image: {sky_path}')
img2 = img[:, :, ::-1]  # reverse the channel axis (OpenCV loads BGR, matplotlib expects RGB)
plt.imshow(img2)
3.4定义SkyFilter类
class Struct:
    """Plain attribute container: every keyword argument becomes an attribute."""

    def __init__(self, **entries):
        # Write the keyword arguments straight into the instance namespace.
        vars(self).update(entries)
def parse_config():
    """Decode the module-level ``str_json`` and expose its keys as attributes.

    Returns:
        A ``Struct`` whose attributes are the top-level keys of the JSON object.
    """
    return Struct(**json.loads(str_json))
args = parse_config()
class SkyFilter():
    """Replace the sky of a video frame by frame and write the results as AVI files.

    Two videos are produced in the working directory: ``out.avi`` (the
    synthesized frames) and ``compare.avi`` (input and synthesized frames side
    by side). ``process_video`` releases both writers when it finishes, so an
    instance is meant for a single run; build a new ``SkyFilter`` to process
    another video.
    """

    def __init__(self, args):
        """Build the sky engine and matting network and open the output writers.

        Args:
            args: settings object exposing ``ckptdir``, ``datadir``,
                ``input_mode``, ``in_size_w``/``in_size_h``,
                ``out_size_w``/``out_size_h``, ``net_G``, ``output_dir`` and
                ``save_jpgs``; it is also handed unchanged to ``SkyBox``.
        """
        self.ckptdir = args.ckptdir
        self.datadir = args.datadir
        self.input_mode = args.input_mode

        self.in_size_w, self.in_size_h = args.in_size_w, args.in_size_h
        self.out_size_w, self.out_size_h = args.out_size_w, args.out_size_h

        self.skyboxengine = SkyBox(args)
        self.net_G = define_G(input_nc=3, output_nc=1, ngf=64, netG=args.net_G).to(device)
        self.load_model()

        fourcc = cv2.VideoWriter_fourcc(*'MJPG')
        self.video_writer = cv2.VideoWriter('out.avi',
                                            fourcc,
                                            20.0,
                                            (args.out_size_w, args.out_size_h))
        # The comparison video is twice as wide: input frame | synthesized frame.
        self.video_writer_cat = cv2.VideoWriter('compare.avi',
                                                fourcc,
                                                20.0,
                                                (2 * args.out_size_w, args.out_size_h))

        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(args.output_dir, exist_ok=True)

        self.output_img_list = []
        self.save_jpgs = args.save_jpgs

    def load_model(self):
        """Load the pretrained sky-matting weights from ``<ckptdir>/best_ckpt.pt``."""
        print('loading the best checkpoint...')
        checkpoint = torch.load(os.path.join(self.ckptdir, 'best_ckpt.pt'),
                                map_location=device)
        self.net_G.load_state_dict(checkpoint['model_G_state_dict'])
        self.net_G.to(device)
        self.net_G.eval()  # inference mode

    def write_video(self, img_HD, syneth):
        """Append one frame to both output videos and to the in-memory buffer.

        Args:
            img_HD: input frame as a float array of shape (H, W, 3); assumed
                to hold RGB values in [0, 1] as produced by
                ``cvtcolor_and_resize``.
            syneth: synthesized frame with the same layout as ``img_HD``.
        """
        # Reverse the channel axis and scale to 8-bit before writing.
        frame = np.array(255.0 * syneth[:, :, ::-1], dtype=np.uint8)
        self.video_writer.write(frame)

        frame_cat = np.concatenate([img_HD, syneth], axis=1)
        frame_cat = np.array(255.0 * frame_cat[:, :, ::-1], dtype=np.uint8)
        self.video_writer_cat.write(frame_cat)

        # Keep the side-by-side frame in the result buffer.
        self.output_img_list.append(frame_cat)

    def synthesize(self, img_HD, img_HD_prev):
        """Predict the sky matte for one frame and blend in the new sky.

        Args:
            img_HD: current frame, an (H, W, 3) array.
            img_HD_prev: previous frame, forwarded to ``skyblend``.

        Returns:
            Tuple ``(syneth, G_pred, skymask)``: the blended frame returned by
            ``skyblend``, the network prediction resized to (H, W), replicated
            to three channels and clipped to [0, 1], and the refined mask
            returned by ``skymask_refinement``.
        """
        h, w = img_HD.shape[:2]

        # Network input: resized frame as a (1, C, in_size_h, in_size_w) float tensor.
        img = cv2.resize(img_HD, (self.in_size_w, self.in_size_h))
        img = np.array(img, dtype=np.float32)
        img = torch.tensor(img).permute([2, 0, 1]).unsqueeze(0)

        with torch.no_grad():
            G_pred = self.net_G(img.to(device))
            # Bring the prediction back to the frame resolution.
            G_pred = torch.nn.functional.interpolate(G_pred,
                                                     (h, w),
                                                     mode='bicubic',
                                                     align_corners=False)
            G_pred = G_pred[0, :].permute([1, 2, 0])
            # Replicate along the channel axis so it lines up with the frame.
            G_pred = torch.cat([G_pred, G_pred, G_pred], dim=-1)
            G_pred = np.array(G_pred.detach().cpu())
            G_pred = np.clip(G_pred, a_max=1.0, a_min=0.0)

        skymask = self.skyboxengine.skymask_refinement(G_pred, img_HD)
        syneth = self.skyboxengine.skyblend(img_HD, img_HD_prev, skymask)
        return syneth, G_pred, skymask

    def cvtcolor_and_resize(self, img_HD):
        """Convert a BGR frame to float32 RGB in [0, 1] at the output resolution."""
        img_HD = cv2.cvtColor(img_HD, cv2.COLOR_BGR2RGB)
        img_HD = np.array(img_HD / 255., dtype=np.float32)
        img_HD = cv2.resize(img_HD, (self.out_size_w, self.out_size_h))
        return img_HD

    def close(self):
        """Release both video writers so the AVI files are finalized on disk."""
        self.video_writer.release()
        self.video_writer_cat.release()

    def process_video(self):
        """Process ``self.datadir`` frame by frame and write both output videos.

        The capture and both writers are released when processing ends, also
        on error, so ``out.avi`` and ``compare.avi`` can be read right away.
        """
        cap = cv2.VideoCapture(self.datadir)
        try:
            m_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            img_HD_prev = None

            for idx in range(m_frames):
                ret, frame = cap.read()
                if not ret:  # no more frames could be read
                    break

                img_HD = self.cvtcolor_and_resize(frame)
                # The first frame has no predecessor, so it stands in for one.
                if img_HD_prev is None:
                    img_HD_prev = img_HD

                syneth, _, _ = self.synthesize(img_HD, img_HD_prev)
                self.write_video(img_HD, syneth)
                img_HD_prev = img_HD

                if (idx + 1) % 50 == 0:
                    print(f'processing video, frame {idx + 1} / {m_frames} ... ')
        finally:
            cap.release()
            self.close()
3.5开始处理视频
# Build the filter (this loads the checkpoint and opens out.avi / compare.avi),
# then process the configured video frame by frame.
sf = SkyFilter(args)
sf.process_video()
3.6对比原视频和处理后的视频
video_name = "compare.avi"
def arrayShow(img):
    """Return *img* as an inline-displayable JPEG at its original size.

    Args:
        img: image array accepted by ``cv2.imencode`` (the comparison loop
            passes frames read from ``cv2.VideoCapture``).

    Returns:
        An ``IPython.display.Image`` wrapping the JPEG-encoded bytes.

    Raises:
        ValueError: if OpenCV reports that the JPEG encoding failed.
    """
    ok, buf = cv2.imencode('.jpg', img)
    if not ok:
        raise ValueError('cv2.imencode failed to encode the frame as JPEG')
    # Hand IPython real bytes so it can sniff the JPEG header; a raw ndarray
    # skips that detection and the image is labelled with the default format.
    return Image(data=buf.tobytes())
# Open the comparison video and play it back inline, frame by frame.
cap = cv2.VideoCapture(video_name)
frame_id = 0
try:
    while True:
        clear_output(wait=True)  # drop the previous frame once the next one is shown
        ret, frame = cap.read()  # read one frame
        if not ret:  # end of stream or read failure
            break
        frame_id += 1
        # Stamp the frame number in the top-left corner.
        cv2.putText(frame, str(frame_id), (5, 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
        img = arrayShow(frame)
        display(img)  # show the frame
        time.sleep(0.05)  # pause briefly before handling the next frame
except KeyboardInterrupt:
    # Interrupting the cell simply ends the playback.
    pass
finally:
    # Always free the capture, whichever way the loop ended.
    cap.release()
3.7 生成你自己的换天视频
边栏推荐
- Letcode43: string multiplication
- What does interface testing test?
- Deep dive kotlin synergy (XXII): flow treatment
- Introduction to paddle - using lenet to realize image classification method I in MNIST
- After going to ByteDance, I learned that there are so many test engineers with an annual salary of 40W?
- 22年秋招心得
- Codeforces Round #804 (Div. 2)(A~D)
- 赞!idea 如何单窗口打开多个项目?
- Implementation of adjacency table of SQLite database storage directory structure 2-construction of directory tree
- Deep dive kotlin collaboration (the end of 23): sharedflow and stateflow
猜你喜欢
CVE-2022-28346:Django SQL注入漏洞
A network composed of three convolution layers completes the image classification task of cifar10 data set
9.卷积神经网络介绍
Binder core API
接口测试要测试什么?
Malware detection method based on convolutional neural network
What does interface testing test?
New library online | cnopendata China Star Hotel data
Course of causality, taught by Jonas Peters, University of Copenhagen
v-for遍历元素样式失效
随机推荐
AI zhetianchuan ml novice decision tree
11.递归神经网络RNN
3.MNIST数据集分类
QT establish signal slots between different classes and transfer parameters
My best game based on wechat applet development
Thinkphp内核工单系统源码商业开源版 多用户+多客服+短信+邮件通知
A brief history of information by James Gleick
第四期SFO销毁,Starfish OS如何对SFO价值赋能?
LeetCode刷题
NTT template for Tourism
3 years of experience, can't you get 20K for the interview and test post? Such a hole?
Four stages of sand table deduction in attack and defense drill
基于卷积神经网络的恶意软件检测方法
Qt不同类之间建立信号槽,并传递参数
新库上线 | 中国记者信息数据
12. RNN is applied to handwritten digit recognition
Introduction to paddle - using lenet to realize image classification method I in MNIST
Codeforces Round #804 (Div. 2)(A~D)
German prime minister says Ukraine will not receive "NATO style" security guarantee
【愚公系列】2022年7月 Go教学课程 006-自动推导类型和输入输出