当前位置:网站首页>【深度学习】AI一键换天
【深度学习】AI一键换天
2022-07-07 23:16:00 【InfoQ】
1.实验目标
2.案例内容介绍

3.实验步骤
3.1安装和导入依赖包
import os
import moxing as mox
import subprocess

# Fetch and unpack the SkyAR project once; re-running the cell is then a no-op.
file_name = 'SkyAR'
if not os.path.exists(file_name):
    mox.file.copy('obs://modelarts-labs-bj4-v2/case_zoo/SkyAR/SkyAR.zip', 'SkyAR.zip')
    # check=True: stop with CalledProcessError instead of silently continuing with
    # a missing or partial project directory (os.system discarded the exit status).
    subprocess.run(['unzip', 'SkyAR.zip'], check=True)
    os.remove('SkyAR.zip')

# Pretrained ResNet-50 weights, placed in the torch checkpoint cache directory.
# Copied only when not already present so the cell stays idempotent.
resnet_ckpt = '/home/ma-user/.cache/torch/checkpoints/resnet50-19c8e357.pth'
if not os.path.exists(resnet_ckpt):
    mox.file.copy_parallel('obs://modelarts-labs-bj4-v2/case_zoo/SkyAR/resnet50-19c8e357.pth', resnet_ckpt)
# IPython shell escapes (only valid inside Jupyter/IPython): remove any preinstalled
# OpenCV wheels and install the pinned contrib build. Presumably both packages are
# uninstalled first because they provide the same `cv2` module and would conflict.
!pip uninstall opencv-python -y
!pip uninstall opencv-contrib-python -y
!pip install opencv-contrib-python==4.5.3.56
# IPython automagic: move into the unpacked project. The relative paths used below
# ('./checkpoints_G_coord_resnet50', './test_videos/sky.mp4', './skybox') resolve from here,
# and presumably the project modules (networks, skyboxengine, utils) are importable from it.
cd SkyAR/
import time
import json
import base64
import numpy as np
import matplotlib.pyplot as plt
import cv2
import argparse
from networks import *
from skyboxengine import *
import utils
import torch
from IPython.display import clear_output, Image, display, HTML
%matplotlib inline
# Run on the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
3.2设定算法参数
# SkyAR inference settings. The dict is serialised to ``str_json`` and later decoded
# back into an attribute-style config by ``parse_config``; key order is preserved.
parameter = dict(
    net_G="coord_resnet50",                    # backbone name handed to define_G
    ckptdir="./checkpoints_G_coord_resnet50",  # directory that contains best_ckpt.pt
    input_mode="video",
    datadir="./test_videos/sky.mp4",           # source video to process
    skybox="sky.jpg",                          # replacement sky image file name
    in_size_w=384,                             # network input width
    in_size_h=384,                             # network input height
    out_size_w=845,                            # output frame width
    out_size_h=480,                            # output frame height
    skybox_center_crop=0.5,                    # the options from here to halo_effect are
    auto_light_matching=False,                 # presumably consumed by SkyBox, whose code
    relighting_factor=0.8,                     # is not shown in this notebook
    recoloring_factor=0.5,
    halo_effect=True,
    output_dir="./jpg_output",                 # created by SkyFilter if it does not exist
    save_jpgs=False,
)

str_json = json.dumps(parameter)
3.3预览一下原视频
video_name = parameter['datadir']

# Upper bound on the number of frames the inline preview plays.
max_preview_frames = 200


def arrayShow(img):
    """Shrink a BGR frame to a quarter of its size and wrap it as an inline JPEG.

    Args:
        img: an OpenCV image array in BGR channel order.

    Returns:
        An ``IPython.display.Image`` suitable for ``display()``.

    Raises:
        ValueError: if OpenCV fails to JPEG-encode the frame.
    """
    img = cv2.resize(img, (0, 0), fx=0.25, fy=0.25, interpolation=cv2.INTER_NEAREST)
    ok, buf = cv2.imencode('.jpg', img)
    if not ok:
        raise ValueError('cv2.imencode failed to encode the preview frame')
    # Hand IPython raw bytes with an explicit format; given an ndarray it cannot
    # sniff the JPEG header and labels the data as PNG.
    return Image(data=buf.tobytes(), format='jpeg')


# Play the start of the source video inline. Interrupting the cell stops the preview early.
cap = cv2.VideoCapture(video_name)
if not cap.isOpened():
    raise FileNotFoundError(f'cannot open video: {video_name}')
frame_id = 0
try:
    while frame_id < max_preview_frames:
        ret, frame = cap.read()  # grab the next frame
        if not ret:  # end of the video (or a read failure)
            break
        frame_id += 1
        # Stamp the 1-based frame number in the top-left corner, in green (BGR order).
        cv2.putText(frame, str(frame_id), (5, 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
        clear_output(wait=True)  # replace the previous frame once the new one is ready
        display(arrayShow(frame))
        time.sleep(0.05)  # pause ~50 ms between frames
except KeyboardInterrupt:
    pass  # deliberate: interrupting the cell is the way to stop the preview early
finally:
    cap.release()  # released exactly once, whichever way the loop ends
预览一下要替换的天空图片
# Preview the replacement sky image named by parameter['skybox'].
skybox_path = os.path.join('./skybox', parameter['skybox'])
img = cv2.imread(skybox_path)
if img is None:
    # cv2.imread reports a missing/unreadable file by returning None rather than raising.
    raise FileNotFoundError(f'cannot read skybox image: {skybox_path}')
img2 = img[:, :, ::-1]  # OpenCV loads BGR; matplotlib expects RGB
plt.imshow(img2)
3.4定义SkyFilter类
class Struct:
    """Lightweight record that exposes keyword arguments as attributes (``obj.key``)."""

    def __init__(self, **entries):
        for key, value in entries.items():
            setattr(self, key, value)
def parse_config():
    """Decode the module-level ``str_json`` string into an attribute-style ``Struct``.

    Returns:
        A ``Struct`` whose attributes mirror the keys of the decoded JSON object.
    """
    return Struct(**json.loads(str_json))


args = parse_config()  # global config passed to SkyFilter further down
class SkyFilter():
    """Replace the sky in a video with a sky-matting network and a ``SkyBox`` engine.

    ``args`` is the attribute-style config built by ``parse_config``. It must provide
    ``ckptdir``, ``datadir``, ``input_mode``, ``in_size_w``/``in_size_h``,
    ``out_size_w``/``out_size_h``, ``net_G``, ``output_dir`` and ``save_jpgs``, plus
    whatever ``SkyBox(args)`` reads (not visible here).

    Output files: ``out.avi`` holds the composited frames and ``compare.avi`` holds the
    input and the result side by side. Both are MJPG at 20 fps. When ``save_jpgs`` is true,
    each composited frame is also written as a JPEG into ``output_dir``.

    Each instance processes one video: ``process_video`` releases the writers when it ends.
    """

    def __init__(self, args):
        self.ckptdir = args.ckptdir
        self.datadir = args.datadir
        self.input_mode = args.input_mode
        self.in_size_w, self.in_size_h = args.in_size_w, args.in_size_h
        self.out_size_w, self.out_size_h = args.out_size_w, args.out_size_h

        self.skyboxengine = SkyBox(args)

        # Matting network: 3 input channels (RGB frame), 1 output channel that is
        # later refined into the sky mask.
        self.net_G = define_G(input_nc=3, output_nc=1, ngf=64, netG=args.net_G).to(device)
        self.load_model()

        # VideoWriter takes (width, height); the comparison video is twice as wide.
        self.video_writer = cv2.VideoWriter('out.avi',
                                            cv2.VideoWriter_fourcc(*'MJPG'),
                                            20.0,
                                            (args.out_size_w, args.out_size_h))
        self.video_writer_cat = cv2.VideoWriter('compare.avi',
                                                cv2.VideoWriter_fourcc(*'MJPG'),
                                                20.0,
                                                (2 * args.out_size_w, args.out_size_h))

        # makedirs(exist_ok=True) also creates missing parents and tolerates re-runs.
        os.makedirs(args.output_dir, exist_ok=True)
        self.output_dir = args.output_dir

        # Buffer of every side-by-side result frame (kept in memory; grows with video length).
        self.output_img_list = []
        self.save_jpgs = args.save_jpgs

    def load_model(self):
        """Load the pretrained matting weights from ``<ckptdir>/best_ckpt.pt``.

        The checkpoint must store the generator weights under ``'model_G_state_dict'``.
        The network is left on ``device`` in eval mode.
        """
        print('loading the best checkpoint...')
        checkpoint = torch.load(os.path.join(self.ckptdir, 'best_ckpt.pt'),
                                map_location=device)
        self.net_G.load_state_dict(checkpoint['model_G_state_dict'])
        self.net_G.to(device)
        self.net_G.eval()  # inference behaviour for layers such as dropout / batch-norm

    def write_video(self, img_HD, syneth):
        """Append one processed frame to both output videos (and optionally to disk).

        Args:
            img_HD: resized input frame, RGB float32 in [0, 1] (see ``cvtcolor_and_resize``).
            syneth: composited frame of the same size; RGB floats, presumably ~[0, 1].
        """
        # Clip before the uint8 cast: out-of-range floats have no well-defined uint8 value.
        syneth = np.clip(syneth, 0.0, 1.0)
        # RGB -> BGR for OpenCV, then scale to 0..255.
        frame = np.array(255.0 * syneth[:, :, ::-1], dtype=np.uint8)
        self.video_writer.write(frame)

        # Input on the left, result on the right.
        frame_cat = np.concatenate([img_HD, syneth], axis=1)
        frame_cat = np.array(255.0 * frame_cat[:, :, ::-1], dtype=np.uint8)
        self.video_writer_cat.write(frame_cat)

        self.output_img_list.append(frame_cat)

        if self.save_jpgs:
            # 1-based, zero-padded index so the files sort in frame order.
            jpg_path = os.path.join(self.output_dir, f'{len(self.output_img_list):05d}.jpg')
            cv2.imwrite(jpg_path, frame)

    def synthesize(self, img_HD, img_HD_prev):
        """Predict the sky matte for ``img_HD`` and blend the skybox into the frame.

        Args:
            img_HD: current frame, RGB float32 in [0, 1], shape (H, W, 3).
            img_HD_prev: previous frame in the same format (the current frame for the
                first one); passed through to ``SkyBox.skyblend``.

        Returns:
            ``(syneth, G_pred, skymask)``:
            - ``syneth``: the composited frame.
            - ``G_pred``: the network output, upsampled to (H, W), replicated to 3 channels
              and clipped to [0, 1].
            - ``skymask``: the mask returned by ``SkyBox.skymask_refinement``.
        """
        h, w = img_HD.shape[:2]
        # The network runs at a fixed resolution: resize, then HWC -> NCHW with batch of 1.
        img = cv2.resize(img_HD, (self.in_size_w, self.in_size_h))
        img = np.array(img, dtype=np.float32)
        img = torch.tensor(img).permute([2, 0, 1]).unsqueeze(0)

        with torch.no_grad():
            G_pred = self.net_G(img.to(device))
            # Upsample the prediction back to the frame resolution.
            G_pred = torch.nn.functional.interpolate(G_pred,
                                                     (h, w),
                                                     mode='bicubic',
                                                     align_corners=False)
            G_pred = G_pred[0, :].permute([1, 2, 0])  # (1, H, W) -> (H, W, 1)
            # Replicate the single channel, presumably so it matches 3-channel images.
            G_pred = torch.cat([G_pred, G_pred, G_pred], dim=-1)
            G_pred = G_pred.cpu().numpy()
            # Bicubic interpolation can overshoot, so clamp back to [0, 1].
            G_pred = np.clip(G_pred, a_max=1.0, a_min=0.0)

        skymask = self.skyboxengine.skymask_refinement(G_pred, img_HD)
        syneth = self.skyboxengine.skyblend(img_HD, img_HD_prev, skymask)

        return syneth, G_pred, skymask

    def cvtcolor_and_resize(self, img_HD):
        """Convert a BGR frame from ``VideoCapture.read`` to RGB float32 in [0, 1] at the output size."""
        img_HD = cv2.cvtColor(img_HD, cv2.COLOR_BGR2RGB)
        img_HD = np.array(img_HD / 255., dtype=np.float32)
        img_HD = cv2.resize(img_HD, (self.out_size_w, self.out_size_h))
        return img_HD

    def process_video(self):
        """Run sky replacement on every frame of ``self.datadir`` and finalize the outputs.

        Frames are read until the capture reports no more data. The capture and both
        video writers are always released at the end, so ``out.avi`` and ``compare.avi``
        are complete and readable as soon as this method returns.

        Raises:
            OSError: if the input video cannot be opened.
        """
        cap = cv2.VideoCapture(self.datadir)
        if not cap.isOpened():
            raise OSError(f'cannot open video: {self.datadir}')
        # CAP_PROP_FRAME_COUNT is only an estimate for some containers, so it is used for
        # progress messages only; the loop itself runs until read() fails.
        m_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        img_HD_prev = None
        idx = 0
        try:
            while True:
                ret, frame = cap.read()
                if not ret:  # end of the stream (or a read failure)
                    break
                img_HD = self.cvtcolor_and_resize(frame)
                if img_HD_prev is None:  # first frame: use it as its own predecessor
                    img_HD_prev = img_HD
                syneth, _, _ = self.synthesize(img_HD, img_HD_prev)
                self.write_video(img_HD, syneth)
                img_HD_prev = img_HD
                idx += 1
                if idx % 50 == 0:
                    print(f'processing video, frame {idx} / {m_frames} ... ')
        finally:
            # Without release() the AVI files may not be flushed/finalized while the
            # SkyFilter object is still alive (e.g. in a notebook session).
            cap.release()
            self.video_writer.release()
            self.video_writer_cat.release()
3.5开始处理视频
# Build the sky filter from the parsed config (this loads the checkpoint and opens the
# out.avi / compare.avi writers), then process the input video frame by frame.
sf = SkyFilter(args)
sf.process_video()
3.6对比原视频和处理后的视频
# Side-by-side video written by SkyFilter: input frame on the left, result on the right.
video_name = "compare.avi"


def arrayShow(img):
    """Wrap a full-size BGR frame as an inline JPEG for ``display()``.

    Args:
        img: an OpenCV image array in BGR channel order.

    Returns:
        An ``IPython.display.Image``.

    Raises:
        ValueError: if OpenCV fails to JPEG-encode the frame.
    """
    ok, buf = cv2.imencode('.jpg', img)
    if not ok:
        raise ValueError('cv2.imencode failed to encode the comparison frame')
    # Raw bytes plus an explicit format, so IPython does not mislabel the data as PNG.
    return Image(data=buf.tobytes(), format='jpeg')


# Play the comparison video inline. Interrupting the cell stops playback early.
cap = cv2.VideoCapture(video_name)
if not cap.isOpened():
    raise FileNotFoundError(f'cannot open video: {video_name}')
frame_id = 0
try:
    while True:
        ret, frame = cap.read()  # grab the next frame
        if not ret:  # end of the video (or a read failure)
            break
        frame_id += 1
        # Stamp the 1-based frame number in the top-left corner, in green (BGR order).
        cv2.putText(frame, str(frame_id), (5, 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
        clear_output(wait=True)  # replace the previous frame once the new one is ready
        display(arrayShow(frame))
        time.sleep(0.05)  # pause ~50 ms between frames
except KeyboardInterrupt:
    pass  # deliberate: interrupting the cell is the way to stop playback early
finally:
    cap.release()  # released exactly once, whichever way the loop ends
3.7 生成你自己的换天视频
边栏推荐
- 基于人脸识别实现课堂抬头率检测
- 《因果性Causality》教程,哥本哈根大学Jonas Peters讲授
- After going to ByteDance, I learned that there are so many test engineers with an annual salary of 40W?
- 股票开户免费办理佣金最低的券商,手机上开户安全吗
- 2022-07-07: the original array is a monotonic array with numbers greater than 0 and less than or equal to K. there may be equal numbers in it, and the overall trend is increasing. However, the number
- My best game based on wechat applet development
- 接口测试进阶接口脚本使用—apipost(预/后执行脚本)
- Letcode43: string multiplication
- How does starfish OS enable the value of SFO in the fourth phase of SFO destruction?
- 10.CNN应用于手写数字识别
猜你喜欢
Lecture 1: the entry node of the link in the linked list
Huawei switch s5735s-l24t4s-qa2 cannot be remotely accessed by telnet
利用GPU训练网络模型
v-for遍历元素样式失效
From starfish OS' continued deflationary consumption of SFO, the value of SFO in the long run
12. RNN is applied to handwritten digit recognition
New library launched | cnopendata China Time-honored enterprise directory
RPA cloud computer, let RPA out of the box with unlimited computing power?
Kubernetes static pod (static POD)
接口测试要测试什么?
随机推荐
11.递归神经网络RNN
攻防演练中沙盘推演的4个阶段
New library launched | cnopendata China Time-honored enterprise directory
Prediction of the victory or defeat of the League of heroes -- simple KFC Colonel
6.Dropout应用
Hotel
What is load balancing? How does DNS achieve load balancing?
Langchao Yunxi distributed database tracing (II) -- source code analysis
Introduction to ML regression analysis of AI zhetianchuan
STL -- common function replication of string class
5g NR system messages
Cause analysis and solution of too laggy page of [test interview questions]
9.卷积神经网络介绍
基于卷积神经网络的恶意软件检测方法
What does interface testing test?
Reentrantlock fair lock source code Chapter 0
手写一个模拟的ReentrantLock
新库上线 | CnOpenData中华老字号企业名录
Four stages of sand table deduction in attack and defense drill
Marubeni official website applet configuration tutorial is coming (with detailed steps)