当前位置:网站首页>【深度学习】AI一键换天
【深度学习】AI一键换天
2022-07-07 23:16:00 【InfoQ】
1.实验目标
2.案例内容介绍
3.实验步骤
3.1安装和导入依赖包
import os
import moxing as mox
# Fetch the SkyAR project archive and the ResNet-50 weights file from OBS.
# The published snippet lost its indentation; the nesting below is
# reconstructed so the cell is valid Python again.
import zipfile

file_name = 'SkyAR'
if not os.path.exists(file_name):
    # Download and unpack the project only when it is not already present.
    mox.file.copy('obs://modelarts-labs-bj4-v2/case_zoo/SkyAR/SkyAR.zip', 'SkyAR.zip')
    # Extract with the standard library so a failed extraction raises instead
    # of being silently ignored (os.system discards the exit status).
    with zipfile.ZipFile('SkyAR.zip') as archive:
        archive.extractall('.')
    os.remove('SkyAR.zip')

# The flattened source does not show whether this copy was nested under the
# check above, so it is guarded by its own existence test: the weights are
# fetched whenever they are missing and are not downloaded again otherwise.
resnet_ckpt = '/home/ma-user/.cache/torch/checkpoints/resnet50-19c8e357.pth'
if not os.path.exists(resnet_ckpt):
    mox.file.copy_parallel('obs://modelarts-labs-bj4-v2/case_zoo/SkyAR/resnet50-19c8e357.pth', resnet_ckpt)
# Remove any installed OpenCV wheels, then install the pinned contrib build.
!pip uninstall opencv-python -y
!pip uninstall opencv-contrib-python -y
!pip install opencv-contrib-python==4.5.3.56
# Change into the unpacked project directory (IPython automagic `cd`).
cd SkyAR/
import time
import json
import base64
import numpy as np
import matplotlib.pyplot as plt
import cv2
import argparse
from networks import *
from skyboxengine import *
import utils
import torch
from IPython.display import clear_output, Image, display, HTML
%matplotlib inline
# Use the first CUDA device when one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
3.2设定算法参数
# Settings for the sky-replacement run; serialised to JSON right below.
parameter = dict(
    net_G="coord_resnet50",
    ckptdir="./checkpoints_G_coord_resnet50",
    input_mode="video",
    datadir="./test_videos/sky.mp4",  # path of the source video to process
    skybox="sky.jpg",  # sky image to swap in
    in_size_w=384,
    in_size_h=384,
    out_size_w=845,
    out_size_h=480,
    skybox_center_crop=0.5,
    auto_light_matching=False,
    relighting_factor=0.8,
    recoloring_factor=0.5,
    halo_effect=True,
    output_dir="./jpg_output",
    save_jpgs=False,
)
str_json = json.dumps(parameter)
3.3预览一下原视频
# Preview the first 200 frames of the source video inline in the notebook.
video_name = parameter['datadir']


def arrayShow(img):
    """Return a quarter-size JPEG ``Image`` of a frame for notebook display.

    The frame is encoded as-is: ``cv2.imencode`` expects OpenCV's BGR channel
    order, which is what ``cv2.VideoCapture`` delivers.
    """
    img = cv2.resize(img, (0, 0), fx=0.25, fy=0.25, interpolation=cv2.INTER_NEAREST)
    _, encoded = cv2.imencode('.jpg', img)
    # Hand IPython real bytes so it can recognise the JPEG header.
    return Image(data=encoded.tobytes())


# Open the video stream.
cap = cv2.VideoCapture(video_name)
frame_id = 0
try:
    while True:
        clear_output(wait=True)  # clear the previously displayed frame
        ret, frame = cap.read()  # read one frame
        if not ret:
            break
        frame_id += 1
        if frame_id > 200:
            break
        cv2.putText(frame, str(frame_id), (5, 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)  # draw frame_id
        img = arrayShow(frame)
        display(img)  # show the frame
        time.sleep(0.05)  # pause briefly before handling the next frame
except KeyboardInterrupt:
    # Interrupting the cell simply stops the preview.
    pass
finally:
    # Release the capture exactly once, however the loop ended.
    cap.release()
预览一下要替换的天空图片
# Preview the replacement sky image.
skybox_path = os.path.join('./skybox', parameter['skybox'])
img = cv2.imread(skybox_path)
if img is None:
    # cv2.imread reports a missing or unreadable file by returning None.
    raise FileNotFoundError(f'cannot read skybox image: {skybox_path}')
img2 = img[:, :, ::-1]  # reverse the channel axis (OpenCV BGR -> RGB) for matplotlib
plt.imshow(img2)
3.4定义SkyFilter类
class Struct:
    """Expose keyword arguments as instance attributes.

    ``Struct(a=1).a`` evaluates to ``1``; no validation or copying is done.
    """

    def __init__(self, **entries):
        self.__dict__.update(entries)

    def __repr__(self):
        fields = ', '.join(f'{key}={value!r}' for key, value in self.__dict__.items())
        return f'{type(self).__name__}({fields})'
def parse_config():
    """Build the run configuration from the module-level ``str_json`` string.

    Returns:
        Struct: object whose attributes are the top-level keys of the JSON
        document.
    """
    data = json.loads(str_json)
    return Struct(**data)
# Parse the JSON settings once; `args` is what SkyFilter is constructed with.
args = parse_config()
class SkyFilter():
    """Replace the sky in a video frame by frame and write the result to disk.

    A generator network predicts a sky matte for each frame, ``SkyBox``
    refines the matte and blends in the new sky, and two videos are written
    to the working directory: ``out.avi`` (result only) and ``compare.avi``
    (input and result side by side).
    """

    def __init__(self, args):
        """Create the network, the skybox engine and the video writers.

        Args:
            args: configuration object providing ``ckptdir``, ``datadir``,
                ``input_mode``, ``in_size_w``/``in_size_h``,
                ``out_size_w``/``out_size_h``, ``net_G``, ``output_dir`` and
                ``save_jpgs``. It is also passed unchanged to ``SkyBox``.
        """
        self.ckptdir = args.ckptdir
        self.datadir = args.datadir
        self.input_mode = args.input_mode

        self.in_size_w, self.in_size_h = args.in_size_w, args.in_size_h
        self.out_size_w, self.out_size_h = args.out_size_w, args.out_size_h

        self.skyboxengine = SkyBox(args)

        self.net_G = define_G(input_nc=3, output_nc=1, ngf=64, netG=args.net_G).to(device)
        self.load_model()

        self.video_writer = cv2.VideoWriter('out.avi',
                                            cv2.VideoWriter_fourcc(*'MJPG'),
                                            20.0,
                                            (args.out_size_w, args.out_size_h))
        # The comparison video is twice as wide: input and result side by side.
        self.video_writer_cat = cv2.VideoWriter('compare.avi',
                                                cv2.VideoWriter_fourcc(*'MJPG'),
                                                20.0,
                                                (2*args.out_size_w, args.out_size_h))

        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(args.output_dir, exist_ok=True)

        self.output_img_list = []

        # NOTE(review): stored but not read anywhere in this class.
        self.save_jpgs = args.save_jpgs

    def load_model(self):
        """Load ``best_ckpt.pt`` from ``ckptdir`` into the network and switch to eval mode."""
        # Load the pretrained sky matting model.
        print('loading the best checkpoint...')
        checkpoint = torch.load(os.path.join(self.ckptdir, 'best_ckpt.pt'),
                                map_location=device)
        self.net_G.load_state_dict(checkpoint['model_G_state_dict'])
        self.net_G.to(device)
        self.net_G.eval()

    def write_video(self, img_HD, syneth):
        """Append one frame to both output videos and to ``output_img_list``.

        Args:
            img_HD: input frame; concatenated with ``syneth`` along the width,
                so both must have the same height and channel count.
            syneth: synthesised frame. Like ``img_HD`` it is multiplied by 255
                and cast to ``uint8``, so values are assumed to lie in [0, 1].
        """
        # Reverse the channel axis (RGB -> BGR) because OpenCV writers expect BGR.
        frame = np.array(255.0 * syneth[:, :, ::-1], dtype=np.uint8)
        self.video_writer.write(frame)

        frame_cat = np.concatenate([img_HD, syneth], axis=1)
        frame_cat = np.array(255.0 * frame_cat[:, :, ::-1], dtype=np.uint8)
        self.video_writer_cat.write(frame_cat)

        # Keep the side-by-side frame in the result buffer.
        self.output_img_list.append(frame_cat)

    def synthesize(self, img_HD, img_HD_prev):
        """Predict the sky matte for ``img_HD`` and blend in the new sky.

        Args:
            img_HD: current frame as produced by ``cvtcolor_and_resize``.
            img_HD_prev: previous frame, forwarded to ``SkyBox.skyblend``.

        Returns:
            tuple: ``(syneth, G_pred, skymask)`` - the blended frame, the raw
            network prediction resized to the frame size, repeated to three
            channels and clipped to [0, 1], and the refined sky mask.
        """
        h, w = img_HD.shape[:2]

        # The network runs on a fixed-size copy of the frame.
        img = cv2.resize(img_HD, (self.in_size_w, self.in_size_h))
        img = np.array(img, dtype=np.float32)
        img = torch.tensor(img).permute([2, 0, 1]).unsqueeze(0)  # HWC -> 1xCxHxW

        with torch.no_grad():
            G_pred = self.net_G(img.to(device))
            # Resize the prediction back to the full frame size.
            G_pred = torch.nn.functional.interpolate(G_pred,
                                                     (h, w),
                                                     mode='bicubic',
                                                     align_corners=False)
            G_pred = G_pred[0, :].permute([1, 2, 0])
            # Repeat the single-channel prediction to three channels.
            G_pred = torch.cat([G_pred, G_pred, G_pred], dim=-1)
            G_pred = np.array(G_pred.detach().cpu())
            G_pred = np.clip(G_pred, a_max=1.0, a_min=0.0)

        skymask = self.skyboxengine.skymask_refinement(G_pred, img_HD)
        syneth = self.skyboxengine.skyblend(img_HD, img_HD_prev, skymask)

        return syneth, G_pred, skymask

    def cvtcolor_and_resize(self, img_HD):
        """Convert a BGR frame to float32 RGB in [0, 1] at the output size."""
        img_HD = cv2.cvtColor(img_HD, cv2.COLOR_BGR2RGB)
        img_HD = np.array(img_HD / 255., dtype=np.float32)
        img_HD = cv2.resize(img_HD, (self.out_size_w, self.out_size_h))
        return img_HD

    def process_video(self):
        """Process ``datadir`` frame by frame and finalise the output videos.

        The capture and both writers are released when processing ends, even
        on error, so ``out.avi`` and ``compare.avi`` are complete on return.
        Because the writers are closed, a second call on the same instance
        will not append further frames.
        """
        # Process the video frame by frame.
        cap = cv2.VideoCapture(self.datadir)
        try:
            m_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            img_HD_prev = None

            for idx in range(m_frames):
                ret, frame = cap.read()
                if not ret:  # reached the last frame
                    break

                img_HD = self.cvtcolor_and_resize(frame)

                # The first frame has no predecessor; use the frame itself.
                if img_HD_prev is None:
                    img_HD_prev = img_HD

                syneth, G_pred, skymask = self.synthesize(img_HD, img_HD_prev)

                self.write_video(img_HD, syneth)

                img_HD_prev = img_HD

                if (idx + 1) % 50 == 0:
                    print(f'processing video, frame {idx + 1} / {m_frames} ... ')
        finally:
            cap.release()
            self.video_writer.release()
            self.video_writer_cat.release()
3.5开始处理视频
# Build the filter (this loads the checkpoint and opens the output video
# writers), then run it over every frame of the configured video.
sf = SkyFilter(args)
sf.process_video()
3.6对比原视频和处理后的视频
# Play the side-by-side comparison video inline in the notebook.
video_name = "compare.avi"


def arrayShow(img):
    """Return a JPEG ``Image`` of a frame for notebook display.

    The frame is encoded as-is: ``cv2.imencode`` expects OpenCV's BGR channel
    order, which is what ``cv2.VideoCapture`` delivers.
    """
    _, encoded = cv2.imencode('.jpg', img)
    # Hand IPython real bytes so it can recognise the JPEG header.
    return Image(data=encoded.tobytes())


# Open the video stream.
cap = cv2.VideoCapture(video_name)
frame_id = 0
try:
    while True:
        clear_output(wait=True)  # clear the previously displayed frame
        ret, frame = cap.read()  # read one frame
        if not ret:
            break
        frame_id += 1
        cv2.putText(frame, str(frame_id), (5, 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)  # draw frame_id
        img = arrayShow(frame)
        display(img)  # show the frame
        time.sleep(0.05)  # pause briefly before handling the next frame
except KeyboardInterrupt:
    # Interrupting the cell simply stops the playback.
    pass
finally:
    # Release the capture exactly once, however the loop ended.
    cap.release()
3.7 生成你自己的换天视频
边栏推荐
- 第一讲:链表中环的入口结点
- "An excellent programmer is worth five ordinary programmers", and the gap lies in these seven key points
- Introduction to paddle - using lenet to realize image classification method II in MNIST
- Summary of the third course of weidongshan
- From starfish OS' continued deflationary consumption of SFO, the value of SFO in the long run
- How does starfish OS enable the value of SFO in the fourth phase of SFO destruction?
- CVE-2022-28346: Django SQL injection vulnerability
- Which securities company has a low, safe and reliable account opening commission
- 【obs】Impossible to find entrance point CreateDirect3D11DeviceFromDXGIDevice
- RPA cloud computer, let RPA out of the box with unlimited computing power?
猜你喜欢
C # generics and performance comparison
1293_FreeRTOS中xTaskResumeAll()接口的实现分析
Huawei switch s5735s-l24t4s-qa2 cannot be remotely accessed by telnet
13.模型的保存和載入
13.模型的保存和载入
完整的模型训练套路
SDNU_ACM_ICPC_2022_Summer_Practice(1~2)
12. RNN is applied to handwritten digit recognition
Deep dive kotlin synergy (XXII): flow treatment
Redis, do you understand the list
随机推荐
5G NR 系统消息
Four stages of sand table deduction in attack and defense drill
赞!idea 如何单窗口打开多个项目?
Cause analysis and solution of too laggy page of [test interview questions]
LeetCode 43: string multiplication
【GO记录】从零开始GO语言——用GO语言做一个示波器(一)GO语言基础
[note] common combined filter circuit
10.CNN应用于手写数字识别
Deep dive kotlin collaboration (the end of 23): sharedflow and stateflow
22年秋招心得
Detailed introduction to segmentation_models_pytorch, a semantic segmentation model library
Interface test advanced interface script use - apipost (pre / post execution script)
SDNU_ACM_ICPC_2022_Summer_Practice(1~2)
5.过拟合,dropout,正则化
完整的模型验证(测试,demo)套路
ReentrantLock 公平锁源码 第0篇
基于卷积神经网络的恶意软件检测方法
Reentrantlock fair lock source code Chapter 0
STL -- common function replication of string class
Summary of weidongshan phase II course content