当前位置:网站首页>BasicVSR_PlusPlus-master测试视频、图片
BasicVSR_PlusPlus-master测试视频、图片
2022-07-06 15:00:00 【cv-daily】
代码:https://github.com/ckkelvinchan/BasicVSR_PlusPlus
BasicVSR_PlusPlus-master 测试图片和视频时总是报 out of memory(显存不够)。但又需要完成测试,因此修改代码,改为逐帧推理以降低显存占用。
修改restoration_video_demo.py
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import cv2
import mmcv
import numpy as np
import torch
from mmedit.apis import init_model, restoration_video_inference
from mmedit.core import tensor2img
from mmedit.utils import modify_args
import time
# File extensions treated as video containers. Paths are tested with
# `os.path.splitext(path)[1] in VIDEO_EXTENSIONS`, so matching is exact and
# case-sensitive (e.g. '.MP4' is not recognised).
VIDEO_EXTENSIONS = ('.mp4', '.mov')
def parse_args():
    """Parse the command line of the video restoration demo.

    Returns:
        argparse.Namespace: with attributes ``config``, ``checkpoint``,
        ``input_dir``, ``output_dir``, ``start_idx``, ``filename_tmpl``,
        ``window_size``, ``max_seq_len`` and ``device``.
    """
    # mmedit helper; presumably rewrites underscore-style options in sys.argv
    # (e.g. --start_idx -> --start-idx) — see mmedit.utils.modify_args.
    modify_args()
    parser = argparse.ArgumentParser(description='Restoration demo')

    # Positional arguments, registered in this exact order.
    positional_specs = (
        ('config', 'test config file path'),
        ('checkpoint', 'checkpoint file'),
        ('input_dir', 'directory of the input video'),
        ('output_dir', 'directory of the output video'),
    )
    for dest, help_text in positional_specs:
        parser.add_argument(dest, help=help_text)

    # Optional switches.
    parser.add_argument(
        '--start-idx', type=int, default=0,
        help='index corresponds to the first frame of the sequence')
    parser.add_argument(
        '--filename-tmpl', default='{:08d}.png',
        help='template of the file names')
    parser.add_argument(
        '--window-size', type=int, default=0,
        help='window size if sliding-window framework is used')
    parser.add_argument(
        '--max-seq-len', type=int, default=None,
        help='maximum sequence length if recurrent framework is used')
    parser.add_argument(
        '--device', type=int, default=0, help='CUDA device id')

    return parser.parse_args()
def _restore_in_chunks(model, args):
    """Run inference piece by piece and yield ``(first_frame_idx, output)``.

    For an image directory, frames are restored one inference call at a time.
    The first call uses ``args.start_idx``. Iteration stops at the first index
    whose file ``args.filename_tmpl.format(idx)`` inside ``args.input_dir``
    does not exist. A video-file input is restored with a single call, since
    a frame index does not select frames from a video.

    Note: with one frame per call, sliding-window models (``--window-size``
    > 0) cannot be padded; this mode is meant for recurrent models.

    Raises:
        FileNotFoundError: if the very first input frame does not exist.
    """
    if os.path.splitext(args.input_dir)[1] in VIDEO_EXTENSIONS:
        yield args.start_idx, restoration_video_inference(
            model, args.input_dir, args.window_size, args.start_idx,
            args.filename_tmpl, args.max_seq_len)
        return

    def frame_path(idx):
        return os.path.join(args.input_dir, args.filename_tmpl.format(idx))

    if not os.path.isfile(frame_path(args.start_idx)):
        raise FileNotFoundError(
            f'No input frame found at "{frame_path(args.start_idx)}"')

    frame_idx = args.start_idx
    while os.path.isfile(frame_path(frame_idx)):
        output = restoration_video_inference(model, args.input_dir,
                                             args.window_size, frame_idx,
                                             args.filename_tmpl,
                                             args.max_seq_len)
        # Return cached CUDA blocks to the driver between calls so memory use
        # does not creep up over a long sequence. No sleep is needed here.
        torch.cuda.empty_cache()
        yield frame_idx, output
        # Advance past every frame just restored (at least one, so the loop
        # always makes progress).
        frame_idx += max(output.size(1), 1)


def _write_frames(output, first_idx, args, video_writer=None):
    """Write each frame of ``output`` to ``video_writer`` or as image files.

    Frames lie on dimension 1 (``output[:, t, :, :, :]``). Without a writer,
    frame ``t`` is saved as ``args.filename_tmpl.format(first_idx + t)``
    inside ``args.output_dir``.
    """
    for offset in range(output.size(1)):
        img = tensor2img(output[:, offset, :, :, :])
        if video_writer is not None:
            video_writer.write(img.astype(np.uint8))
            continue
        file_name = args.filename_tmpl.format(first_idx + offset)
        print(file_name)
        mmcv.imwrite(img, f'{args.output_dir}/{file_name}')


def main():
    """Demo for video restoration models with bounded GPU memory.

    Note that we accept video as input/output, when 'input_dir'/'output_dir'
    is set to the path to the video. But using videos introduces video
    compression, which lowers the visual quality. If you want actual quality,
    please save them as separate images (.png).

    An image directory is restored one frame per inference call, instead of
    the whole sequence at once, to avoid CUDA out-of-memory errors.
    """
    args = parse_args()
    model = init_model(
        args.config, args.checkpoint, device=torch.device('cuda', args.device))

    save_as_video = os.path.splitext(args.output_dir)[1] in VIDEO_EXTENSIONS
    video_writer = None
    try:
        for first_idx, output in _restore_in_chunks(model, args):
            if save_as_video and video_writer is None:
                # Open the writer once, sized from the first restored chunk,
                # so every chunk is appended to the same video file.
                h, w = output.shape[-2:]
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                video_writer = cv2.VideoWriter(args.output_dir, fourcc, 25,
                                               (w, h))
            _write_frames(output, first_idx, args, video_writer)
    finally:
        # Always finalise the container, even if inference fails midway.
        if video_writer is not None:
            video_writer.release()
# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
修改restoration_video_inference.py
# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os.path as osp
import re
from functools import reduce
import mmcv
import numpy as np
import torch
from mmedit.datasets.pipelines import Compose
# Extensions that make `restoration_video_inference` treat `img_dir` as a
# video file rather than an image directory. Checked via
# `osp.splitext(img_dir)[1] in VIDEO_EXTENSIONS`, i.e. exact, case-sensitive.
VIDEO_EXTENSIONS = ('.mp4', '.mov')
def pad_sequence(data, window_size):
    """Pad ``data`` along dimension 1 with mirrored frames for sliding windows.

    ``window_size // 2`` frames are prepended and appended. The prepended
    frames are ``data[:, half + 1:2 * half + 1]`` reversed, and the appended
    frames are ``data[:, -2 * half - 1:-half - 1]`` reversed. If the sequence
    is shorter than ``2 * half + 1``, those slices come out short or empty.

    Args:
        data (Tensor): tensor whose dimension 1 is the frame axis.
        window_size (int): sliding-window length.

    Returns:
        Tensor: the padded tensor, concatenated along dimension 1.
    """
    half = window_size // 2
    head = torch.flip(data[:, half + 1:2 * half + 1], dims=[1])
    tail = torch.flip(data[:, -2 * half - 1:-half - 1], dims=[1])
    return torch.cat((head, data, tail), dim=1)
def restoration_video_inference(model,
                                img_dir,
                                window_size,
                                start_idx,
                                filename_tmpl,
                                max_seq_len=None,
                                ):
    """Inference image with the model.

    Modified for memory-constrained testing: when ``img_dir`` is an image
    directory, only ONE frame (index ``start_idx``) is loaded and restored
    per call, so callers iterate over ``start_idx``. A video file is still
    decoded and restored in full.

    Args:
        model (nn.Module): The loaded model.
        img_dir (str): Directory of the input video, or the path to a video
            file whose extension is in ``VIDEO_EXTENSIONS``.
        window_size (int): The window size used in sliding-window framework.
            This value should be set according to the settings of the
            network. A value <= 0 means using recurrent framework.
        start_idx (int): The index corresponding to the frame to restore
            (the first frame of the loaded sequence).
        filename_tmpl (str): Template for file name.
        max_seq_len (int | None): The maximum sequence length that the model
            processes. If the sequence length is larger than this number, the
            sequence is split into multiple segments. If it is None, the
            entire sequence is processed at once.

    Returns:
        Tensor: The predicted restoration result, moved to the CPU.

    Raises:
        TypeError: If, for an image directory, the first pipeline step is not
            ``GenerateSegmentIndices``.
        ValueError: If a sliding-window model receives fewer than
            ``2 * (window_size // 2) + 1`` frames, too few to pad.
    """
    device = next(model.parameters()).device  # model device

    # build the data pipeline
    if model.cfg.get('demo_pipeline', None):
        test_pipeline = model.cfg.demo_pipeline
    elif model.cfg.get('test_pipeline', None):
        test_pipeline = model.cfg.test_pipeline
    else:
        test_pipeline = model.cfg.val_pipeline

    # check if the input is a video
    file_extension = osp.splitext(img_dir)[1]
    if file_extension in VIDEO_EXTENSIONS:
        video_reader = mmcv.VideoReader(img_dir)
        # load the images; np.flip(axis=2) reverses the channel order
        # (presumably decoder BGR -> model RGB; confirm against the config)
        data = dict(lq=[], lq_path=None, key=img_dir)
        for frame in video_reader:
            data['lq'].append(np.flip(frame, axis=2))

        # remove the data loading pipeline: frames are already in memory
        test_pipeline = [
            step for step in test_pipeline if step['type'] not in
            ['GenerateSegmentIndices', 'LoadImageFromFileList']
        ]
    else:
        # the first element in the pipeline must be 'GenerateSegmentIndices'
        if test_pipeline[0]['type'] != 'GenerateSegmentIndices':
            raise TypeError('The first element in the pipeline must be '
                            f'"GenerateSegmentIndices", but got '
                            f'"{test_pipeline[0]["type"]}".')

        # specify start_idx and filename_tmpl
        # NOTE: test_pipeline aliases the model config, so this edits
        # model.cfg in place (overwritten again on every call).
        test_pipeline[0]['start_idx'] = start_idx
        test_pipeline[0]['filename_tmpl'] = filename_tmpl

        # Load a single frame per call to keep GPU memory bounded (the
        # original code counted every file in img_dir here).
        sequence_length = 1
        # Split "<lq_folder>/<key>" with os.path so absolute paths keep their
        # root and a trailing separator is ignored.
        lq_folder, key = osp.split(osp.normpath(img_dir))
        data = dict(
            lq_path=lq_folder,
            gt_path='',
            key=key,
            sequence_length=sequence_length)

    # compose the pipeline
    test_pipeline = Compose(test_pipeline)
    lq = test_pipeline(data)['lq']  # in cpu
    # The forward code below treats dim 1 as time, presumably on a
    # (n, t, c, h, w) tensor. The pipeline may drop singleton axes when
    # only one frame is loaded, so restore missing leading axes up to
    # rank 4 rather than always adding two. Multi-frame (video) input
    # therefore keeps the right rank, then the batch axis is added.
    while lq.dim() < 4:
        lq = lq.unsqueeze(0)
    data = lq.unsqueeze(0)

    # forward the model
    with torch.no_grad():
        if window_size > 0:  # sliding window framework
            padding = window_size // 2
            min_frames = 2 * padding + 1
            if data.size(1) < min_frames:
                # pad_sequence cannot mirror enough frames; without this the
                # output would silently lose frames or torch.stack would fail
                # on an empty list.
                raise ValueError(
                    f'Sliding-window inference with window_size={window_size} '
                    f'needs at least {min_frames} frames, but got '
                    f'{data.size(1)}.')
            data = pad_sequence(data, window_size)
            result = []
            for i in range(0, data.size(1) - 2 * padding):
                data_i = data[:, i:i + window_size].to(device)
                result.append(model(lq=data_i, test_mode=True)['output'].cpu())
            result = torch.stack(result, dim=1)
        else:  # recurrent framework
            if max_seq_len is None:
                result = model(
                    lq=data.to(device), test_mode=True)['output'].cpu()
            else:
                result = []
                for i in range(0, data.size(1), max_seq_len):
                    result.append(
                        model(
                            lq=data[:, i:i + max_seq_len].to(device),
                            test_mode=True)['output'].cpu())
                result = torch.cat(result, dim=1)

    return result
边栏推荐
- 硬件開發筆記(十): 硬件開發基本流程,制作一個USB轉RS232的模塊(九):創建CH340G/MAX232封裝庫sop-16並關聯原理圖元器件
- (十八)LCD1602实验
- Assembly and interface technology experiment 5-8259 interrupt experiment
- 0 basic learning C language - interrupt
- Classic sql50 questions
- 3DMax指定面贴图
- Assembly and Interface Technology Experiment 6 - ADDA conversion experiment, AD acquisition system in interrupt mode
- 2022-07-05 use TPCC to conduct sub query test on stonedb
- MySQL教程的天花板,收藏好,慢慢看
- Export MySQL table data in pure mode
猜你喜欢
Web APIs DOM 时间对象
基於 QEMUv8 搭建 OP-TEE 開發環境
LeetCode刷题(十一)——顺序刷题51至55
pytorch_ Yolox pruning [with code]
【编译原理】做了一半的LR(0)分析器
Pit encountered by handwritten ABA
That's why you can't understand recursion
Export MySQL table data in pure mode
CCNA-思科网络 EIGRP协议
Unity3d minigame-unity-webgl-transform插件转换微信小游戏报错To use dlopen, you need to use Emscripten‘s...问题
随机推荐
中国VOCs催化剂行业研究与投资战略报告(2022版)
Installation and use of labelimg
return 关键字
Leetcode exercise - Sword finger offer 26 Substructure of tree
Should novice programmers memorize code?
case 关键字后面的值有什么要求吗?
Attack and defense world ditf Misc
UE4蓝图学习篇(四)--流程控制ForLoop和WhileLoop
在IPv6中 链路本地地址的优势
自定义 swap 函数
Data processing skills (7): MATLAB reads the data in the text file TXT with mixed digital strings
重磅新闻 | Softing FG-200获得中国3C防爆认证 为客户现场测试提供安全保障
pytorch_ Yolox pruning [with code]
ResNet-RS:谷歌领衔调优ResNet,性能全面超越EfficientNet系列 | 2021 arxiv
图像的spatial domain 和 frequency domain 图像压缩
【雅思口语】安娜口语学习记录part1
做接口测试都测什么?有哪些通用测试点?
Unity3d minigame-unity-webgl-transform插件转换微信小游戏报错To use dlopen, you need to use Emscripten‘s...问题
const关键字
Aardio - 通过变量名将变量值整合到一串文本中