[PyTorchVideo Tutorial 01] Quickly Implementing Video Action Recognition
2022-07-30 19:18:00 [CV-杨帆]
1 Introduction to PyTorchVideo
PyTorchVideo was released by Facebook (now Meta) in April 2021 and targets deep learning applications for video.
Bilibili: https://www.bilibili.com/video/BV1QT411j7M3
1.1 References:
PyTorchVideo website: https://pytorchvideo.org/
PyTorchVideo GitHub: https://github.com/facebookresearch/pytorchvideo
Tutorials:https://pytorchvideo.org/docs/tutorial_torchhub_inference
深入浅出PyTorch: Section 8.3, Introduction to PyTorchVideo
"PyTorchVideo: everything you want for video deep learning" (Zhihu): https://zhuanlan.zhihu.com/p/390909705
PyTorchVideo: A Deep Learning Library for Video Understanding:https://arxiv.org/pdf/2111.09887.pdf
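1.2 Overview
In recent years, as media channels and video platforms have grown, video has been replacing images as the next mainstream medium, and deep learning models for video are receiving more and more attention.
However, deep learning for video still has several pain points:
- It consumes more compute, and there is no high-quality model zoo, so transfer learning and paper reproduction are harder than they are for images.
- Dataset processing is cumbersome, and there is no good video processing tool.
- As multimodal learning becomes more popular, a tool that can handle the other modalities is urgently needed.
Beyond these, there are also problems such as deployment optimization. To address them, Meta released the PyTorchVideo deep learning library (its components are shown in Figure 1). PyTorchVideo is a deep learning library focused on video understanding. It provides reusable, modular, and efficient components for accelerating video understanding research. It is built on PyTorch and supports video-specific deep learning components such as video models, video datasets, and video-specific transforms.
Before the main content, here is a demo: through its deployment optimization module (accelerator), PyTorchVideo took the lead in achieving real-time video action recognition on mobile devices (based on the X3D model), so running video models on mobile is no longer just a dream.
Demo: real-time video action recognition on mobile with PyTorchVideo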
PyTorchVideo A deep learning library for video understanding
3 GPU platform
极链AI: https://cloud.videojj.com/auth/register?inviter=18452&activityChannel=student_invite
Quick environment setup from a prebuilt image
4 Installing pytorchvideo
cd /home
pip install pytorchvideo
wget https://dl.fbaipublicfiles.com/pyslowfast/dataset/class_names/kinetics_classnames.json
wget https://dl.fbaipublicfiles.com/pytorchvideo/projects/archery.mp4
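If archery.mp4 cannot be downloaded with wget, download it elsewhere first and then upload it to the server. I have uploaded the video to Aliyun Drive:
https://www.aliyundrive.com/s/xjzfmH3uoFB
I have also uploaded the video to CSDN: archery.mp4 (PyTorchVideo action recognition demo video)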
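Before moving on to the demo, it can help to confirm that the install and the two downloads above actually succeeded. The following is a minimal sketch and not part of the original tutorial: the helper name `check_setup` is hypothetical, and it assumes the files were downloaded into the current working directory.

```python
import importlib.util
import os


def check_setup(module_name, required_files):
    """Report whether a module is importable and which files are missing.

    check_setup is a hypothetical helper for illustration, not a pytorchvideo API.
    """
    # find_spec returns None when a top-level module cannot be located.
    module_found = importlib.util.find_spec(module_name) is not None
    # Collect every expected file that is not present on disk.
    missing = [path for path in required_files if not os.path.isfile(path)]
    return {"module_found": module_found, "missing_files": missing}


if __name__ == "__main__":
    status = check_setup(
        "pytorchvideo",
        ["kinetics_classnames.json", "archery.mp4"],
    )
    print(status)
```

If `module_found` is False, or `missing_files` is not empty, rerun the corresponding `pip install` or `wget` step (or upload the video manually) before starting the Notebook.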
5 Demo
A video must be prepared in advance.
Start building (a Notebook is used here, mainly so that the intermediate steps can be inspected)
import torch
import json
from torchvision.transforms import Compose, Lambda
from torchvision.transforms._transforms_video import (
    CenterCropVideo,
    NormalizeVideo,
)
from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    ShortSideScale,
    UniformTemporalSubsample,
    UniformCropVideo,
)
from typing import Dict
# Device on which to run the model
# Set to cuda to load on GPU
device = "cpu"
# Pick a pretrained model and load the pretrained weights
model_name = "slowfast_r50"
model = torch.hub.load("facebookresearch/pytorchvideo", model=model_name, pretrained=True)
# Set to eval mode and move to desired device
model = model.to(device)
model = model.eval()
with open("kinetics_classnames.json", "r") as f:
    kinetics_classnames = json.load(f)
# Create an id to label name mapping
kinetics_id_to_classname = {}
for k, v in kinetics_classnames.items():
    kinetics_id_to_classname[v] = str(k).replace('"', "")
####################
# SlowFast transform
####################
side_size = 256
mean = [0.45, 0.45, 0.45]
std = [0.225, 0.225, 0.225]
crop_size = 256
num_frames = 32
sampling_rate = 2
frames_per_second = 30
alpha = 4
class PackPathway(torch.nn.Module):
    """ Transform for converting video frames as a list of tensors. """
    def __init__(self):
        super().__init__()

    def forward(self, frames: torch.Tensor):
        fast_pathway = frames
        # Perform temporal sampling from the fast pathway.
        slow_pathway = torch.index_select(
            frames,
            1,
            torch.linspace(
                0, frames.shape[1] - 1, frames.shape[1] // alpha
            ).long(),
        )
        frame_list = [slow_pathway, fast_pathway]
        return frame_list
transform = ApplyTransformToKey(
    key="video",
    transform=Compose(
        [
            UniformTemporalSubsample(num_frames),
            Lambda(lambda x: x/255.0),
            NormalizeVideo(mean, std),
            ShortSideScale(
                size=side_size
            ),
            CenterCropVideo(crop_size),
            PackPathway()
        ]
    ),
)
# The duration of the input clip is also specific to the model.
clip_duration = (num_frames * sampling_rate)/frames_per_second
# Load the example video
video_path = "archery.mp4"
# Select the duration of the clip to load by specifying the start and end duration
# The start_sec should correspond to where the action occurs in the video
start_sec = 0
end_sec = start_sec + clip_duration
# Initialize an EncodedVideo helper class
video = EncodedVideo.from_path(video_path)
# Load the desired clip
video_data = video.get_clip(start_sec=start_sec, end_sec=end_sec)
# Apply a transform to normalize the video input
video_data = transform(video_data)
# Move the inputs to the desired device
inputs = video_data["video"]
inputs = [i.to(device)[None, ...] for i in inputs]
# Pass the input clip through the model
preds = model(inputs)
# Get the predicted classes
post_act = torch.nn.Softmax(dim=1)
preds = post_act(preds)
pred_classes = preds.topk(k=5).indices
# Map the predicted classes to the label names
pred_class_names = [kinetics_id_to_classname[int(i)] for i in pred_classes[0]]
print("Predicted labels: %s" % ", ".join(pred_class_names))
Output:
Predicted labels: archery, throwing axe, playing paintball, disc golfing, riding or walking with horse
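To see how the SlowFast parameters in the demo fit together, the sketch below reproduces the arithmetic in plain Python, without torch. It is an illustration rather than the library's implementation: the function names `clip_duration_seconds` and `slow_pathway_indices` are hypothetical, and the index formula uses integer floor division as a stand-in for the `torch.linspace(...).long()` call in `PackPathway`. For the values used here the two should agree, but they may differ for other inputs because of floating-point rounding.

```python
def clip_duration_seconds(num_frames, sampling_rate, frames_per_second):
    """Length in seconds of the source-video span that one clip covers."""
    return (num_frames * sampling_rate) / frames_per_second


def slow_pathway_indices(num_fast_frames, alpha):
    """Evenly spaced frame indices the slow pathway keeps from the fast pathway.

    Assumption: integer floor division approximates torch.linspace(...).long().
    """
    num_slow_frames = num_fast_frames // alpha
    if num_slow_frames == 1:
        # A single sample sits at the start of the range.
        return [0]
    last = num_fast_frames - 1
    return [(i * last) // (num_slow_frames - 1) for i in range(num_slow_frames)]


# Values taken from the demo above.
duration = clip_duration_seconds(num_frames=32, sampling_rate=2, frames_per_second=30)
indices = slow_pathway_indices(num_fast_frames=32, alpha=4)
print(f"clip duration: {duration:.3f} s")
print(f"slow pathway keeps {len(indices)} of 32 frames: {indices}")
```

With the demo's values, one clip spans about 2.13 seconds of video, the fast pathway receives all 32 sampled frames, and the slow pathway receives 8 of them.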
