[PyTorchVideo Tutorial 01] Quick Video Action Recognition
2022-07-30 19:18:00 [CV-杨帆]
1 Introduction to PyTorchVideo
PyTorchVideo was released by Facebook in April 2021 and is aimed mainly at deep learning applications on video.
Bilibili video: https://www.bilibili.com/video/BV1QT411j7M3
1.1 References:
PyTorchVideo website: https://pytorchvideo.org/
PyTorchVideo GitHub: https://github.com/facebookresearch/pytorchvideo
Tutorials: https://pytorchvideo.org/docs/tutorial_torchhub_inference
深入浅出PyTorch, Section 8.3: Introduction to PyTorchVideo
PyTorchVideo: everything you want for video deep learning (Zhihu): https://zhuanlan.zhihu.com/p/390909705
PyTorchVideo: A Deep Learning Library for Video Understanding (arXiv): https://arxiv.org/pdf/2111.09887.pdf
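1.2 Overview
In recent years, with the growth of media channels and video platforms, video has been replacing images as the mainstream medium of the next generation, and deep learning models for video are receiving more and more attention.
However, deep learning for video still has many shortcomings:
- It consumes more computing resources, and there is no high-quality model zoo, so transfer learning and reproducing papers are not as easy as with images.
- Dataset processing is cumbersome, yet there is no good video processing tool.
- As multimodal learning becomes more popular, a tool that can also handle other modalities is urgently needed.
There are also deployment and optimization issues. To address these problems, Meta released the PyTorchVideo deep learning library (its components are shown in Figure 1). PyTorchVideo is a deep learning library focused on video understanding. It provides the reusable, modular, and efficient components needed to accelerate video understanding research. It is built on PyTorch and supports a range of video components, such as video models, video datasets, and video-specific transforms.
Before the main text, here is a demo: through its deployment optimization module (Accelerator), PyTorchVideo was the first to achieve real-time video action recognition on mobile devices (based on the X3D model), so running video models on mobile devices is no longer just a dream.
PyTorchVideo: real-time video action recognition on mobile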
PyTorchVideo A deep learning library for video understanding
3 GPU platform
极链AI: https://cloud.videojj.com/auth/register?inviter=18452&activityChannel=student_invite
Set up the environment quickly from a prebuilt image.
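4 Install pytorchvideo
```shell
cd /home
pip install pytorchvideo
# Kinetics-400 class names, used to map predicted ids to labels
wget https://dl.fbaipublicfiles.com/pyslowfast/dataset/class_names/kinetics_classnames.json
# Demo video
wget https://dl.fbaipublicfiles.com/pytorchvideo/projects/archery.mp4
```
If archery.mp4 cannot be downloaded this way, download it elsewhere first and then upload it. I have uploaded the video to Aliyun Drive:
https://www.aliyundrive.com/s/xjzfmH3uoFB
I have also uploaded the video on CSDN: archery.mp4 action recognition pytorchvideo demo video (action recognition)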
5 Demo
Prepare a video in advance.
Build the pipeline step by step (a Notebook is used, mainly so the intermediate steps can be inspected):
```python
import torch
import json
from torchvision.transforms import Compose, Lambda
# Note: _transforms_video is a private torchvision module; newer torchvision
# releases may warn about it or no longer ship it.
from torchvision.transforms._transforms_video import (
    CenterCropVideo,
    NormalizeVideo,
)
from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    ShortSideScale,
    UniformTemporalSubsample,
    UniformCropVideo,
)
from typing import Dict
```
```python
# Device on which to run the model
# Set to "cuda" to load on GPU
device = "cpu"

# Pick a pretrained model and load the pretrained weights
model_name = "slowfast_r50"
model = torch.hub.load("facebookresearch/pytorchvideo", model=model_name, pretrained=True)

# Set to eval mode and move to desired device
model = model.to(device)
model = model.eval()
```
```python
with open("kinetics_classnames.json", "r") as f:
    kinetics_classnames = json.load(f)

# Create an id to label name mapping
kinetics_id_to_classname = {}
for k, v in kinetics_classnames.items():
    kinetics_id_to_classname[v] = str(k).replace('"', "")
```
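The JSON file maps each class name to its integer id, so the loop above simply inverts that dictionary. A minimal stdlib sketch of the same inversion, run on a tiny hand-made dictionary (the entries and ids are made up, not the real Kinetics ids; `invert_classnames` is a hypothetical helper), shows what the `replace('"', "")` call does: it strips literal double quotes that may be embedded in some keys.

```python
from typing import Dict


def invert_classnames(name_to_id: Dict[str, int]) -> Dict[int, str]:
    """Turn {class name: id} into {id: class name}, stripping stray double quotes."""
    return {v: str(k).replace('"', "") for k, v in name_to_id.items()}


# Illustrative stand-in for kinetics_classnames.json (made-up ids, not the real file)
sample = {
    "archery": 5,
    '"passing American football (in game)"': 7,
    "riding or walking with horse": 9,
}

id_to_name = invert_classnames(sample)
print(id_to_name[7])  # passing American football (in game)
```

The dict comprehension behaves the same as the explicit loop in the demo; the loop form is just easier to step through in a Notebook.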
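```python
####################
# SlowFast transform
####################

side_size = 256               # resize the shorter spatial side to this many pixels
mean = [0.45, 0.45, 0.45]     # per-channel normalization mean
std = [0.225, 0.225, 0.225]   # per-channel normalization std
crop_size = 256               # final square crop (256 x 256)
num_frames = 32               # frames fed to the fast pathway
sampling_rate = 2             # temporal stride, in source frames, between sampled frames
frames_per_second = 30        # assumed frame rate of the source video
alpha = 4                     # fast/slow frame-rate ratio: slow pathway gets num_frames // alpha frames
```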
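These numbers fix how much video one inference call sees. A quick stdlib check of the arithmetic (the helper name `slowfast_clip_stats` is hypothetical): the fast pathway takes `num_frames = 32` frames spaced `sampling_rate = 2` source frames apart, so at 30 fps the clip spans 32 × 2 / 30 ≈ 2.13 s, and the slow pathway keeps `32 // alpha = 8` of those frames.

```python
def slowfast_clip_stats(num_frames: int, sampling_rate: int, fps: float, alpha: int):
    """Return (clip duration in seconds, source frames covered, slow-pathway frames)."""
    source_frames = num_frames * sampling_rate  # raw frames spanned by the clip
    clip_duration = source_frames / fps         # same formula as clip_duration in the demo
    slow_frames = num_frames // alpha           # frames PackPathway keeps for the slow pathway
    return clip_duration, source_frames, slow_frames


duration, covered, slow = slowfast_clip_stats(num_frames=32, sampling_rate=2, fps=30, alpha=4)
print(f"{duration:.3f} s, {covered} source frames, slow={slow}")  # 2.133 s, 64 source frames, slow=8
```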
```python
class PackPathway(torch.nn.Module):
    """
    Transform that splits video frames into a list of tensors [slow_pathway, fast_pathway].
    """
    def __init__(self):
        super().__init__()

    def forward(self, frames: torch.Tensor):
        # frames has shape (C, T, H, W); dim 1 is time
        fast_pathway = frames
        # Perform temporal sampling from the fast pathway.
        slow_pathway = torch.index_select(
            frames,
            1,
            torch.linspace(
                0, frames.shape[1] - 1, frames.shape[1] // alpha
            ).long(),
        )
        frame_list = [slow_pathway, fast_pathway]
        return frame_list
```
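`PackPathway` keeps every frame for the fast pathway and picks `T // alpha` evenly spaced frames for the slow one via `torch.linspace(...).long()`. The stdlib sketch below lists which indices get selected. It uses exact integer arithmetic instead of torch's floating-point `linspace`, which for the demo's sizes should give the same truncated values; `slow_pathway_indices` is a hypothetical name.

```python
from typing import List


def slow_pathway_indices(num_frames: int, alpha: int) -> List[int]:
    """Indices of the frames PackPathway keeps for the slow pathway.

    Mirrors torch.linspace(0, T - 1, T // alpha).long(): evenly spaced points
    from the first to the last frame, truncated to integers.
    """
    steps = num_frames // alpha
    if steps == 1:
        return [0]  # linspace with a single step returns only the start value
    last = num_frames - 1
    # floor(i * last / (steps - 1)), computed exactly with integers
    return [i * last // (steps - 1) for i in range(steps)]


print(slow_pathway_indices(32, 4))  # [0, 4, 8, 13, 17, 22, 26, 31]
```

The first and last frames are always included, so the slow pathway still spans the whole clip, just at 1/alpha of the fast pathway's frame rate.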
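```python
transform = ApplyTransformToKey(
    key="video",
    transform=Compose(
        [
            UniformTemporalSubsample(num_frames),  # pick num_frames evenly spaced frames
            Lambda(lambda x: x / 255.0),           # scale pixel values from [0, 255] to [0, 1]
            NormalizeVideo(mean, std),             # per-channel (x - mean) / std
            ShortSideScale(size=side_size),        # shorter side -> side_size, aspect ratio kept
            CenterCropVideo(crop_size),            # central crop_size x crop_size crop
            PackPathway(),                         # -> [slow_pathway, fast_pathway]
        ]
    ),
)
```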
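To see what each stage of `transform` does to the clip tensor, here is a stdlib sketch that only tracks shapes in the `(C, T, H, W)` layout. The input size `(3, 64, 720, 1280)` is an assumption for illustration (64 decoded frames of 720p RGB), not the real size of archery.mp4, and every helper is a hypothetical re-statement of the corresponding transform: `UniformTemporalSubsample` keeps `num_frames` frames, `ShortSideScale` resizes the shorter side to `side_size` while keeping the aspect ratio (flooring the other side), and `CenterCropVideo` cuts a square from the middle.

```python
import math
from typing import List, Tuple

Shape = Tuple[int, int, int, int]  # (C, T, H, W)


def uniform_subsample(shape: Shape, num_samples: int) -> Shape:
    c, _, h, w = shape
    return (c, num_samples, h, w)


def short_side_scale(shape: Shape, size: int) -> Shape:
    c, t, h, w = shape
    if w < h:
        return (c, t, int(math.floor(h / w * size)), size)
    return (c, t, size, int(math.floor(w / h * size)))


def center_crop(shape: Shape, crop: int) -> Shape:
    c, t, _, _ = shape
    return (c, t, crop, crop)


def pack_pathway(shape: Shape, alpha: int) -> List[Shape]:
    c, t, h, w = shape
    return [(c, t // alpha, h, w), shape]


def normalize_pixel(value: int, mean: float, std: float) -> float:
    """Lambda(x / 255) followed by NormalizeVideo, for a single pixel value."""
    return (value / 255.0 - mean) / std


shape: Shape = (3, 64, 720, 1280)     # assumed decoded clip
shape = uniform_subsample(shape, 32)  # -> (3, 32, 720, 1280)
shape = short_side_scale(shape, 256)  # -> (3, 32, 256, 455)
shape = center_crop(shape, 256)       # -> (3, 32, 256, 256)
slow, fast = pack_pathway(shape, 4)   # -> (3, 8, 256, 256), (3, 32, 256, 256)
print(slow, fast)
print(normalize_pixel(128, 0.45, 0.225))
```

Normalization and scaling to [0, 1] only change values, not shapes, which is why they do not appear in the shape trace.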
```python
# The duration of the input clip is also specific to the model.
clip_duration = (num_frames * sampling_rate) / frames_per_second

# Load the example video
video_path = "archery.mp4"

# Select the duration of the clip to load by specifying the start and end duration
# The start_sec should correspond to where the action occurs in the video
start_sec = 0
end_sec = start_sec + clip_duration

# Initialize an EncodedVideo helper class
video = EncodedVideo.from_path(video_path)

# Load the desired clip
video_data = video.get_clip(start_sec=start_sec, end_sec=end_sec)

# Apply a transform to normalize the video input
video_data = transform(video_data)
```
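Here `start_sec` is left at 0, but the comment suggests placing the window where the action happens. One possible way to do that is to center a `clip_duration`-long window on a chosen timestamp and clamp it to the video length. The stdlib sketch below assumes a hypothetical helper `clip_window` and a made-up 10 s video; in practice the length would come from the loaded video itself (for example `video.duration`).

```python
from typing import Tuple


def clip_window(center_sec: float, clip_duration: float, video_duration: float) -> Tuple[float, float]:
    """Return (start_sec, end_sec) of a clip centered on center_sec, kept inside the video."""
    if clip_duration >= video_duration:
        return 0.0, video_duration  # video shorter than one clip: take all of it
    start = center_sec - clip_duration / 2
    start = min(max(start, 0.0), video_duration - clip_duration)  # clamp to [0, end - clip]
    return start, start + clip_duration


clip_duration = (32 * 2) / 30  # same value as in the demo, about 2.133 s
print(clip_window(5.0, clip_duration, 10.0))  # window centered on 5 s
print(clip_window(0.5, clip_duration, 10.0))  # clamped so that it starts at 0 s
```

The returned pair can be passed as `start_sec` and `end_sec` to `get_clip`.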
```python
# Move the inputs to the desired device and add a batch dimension
inputs = video_data["video"]
inputs = [i.to(device)[None, ...] for i in inputs]

# Pass the input clip through the model (no gradients are needed for inference)
with torch.no_grad():
    preds = model(inputs)

# Get the predicted classes
post_act = torch.nn.Softmax(dim=1)
preds = post_act(preds)
pred_classes = preds.topk(k=5).indices

# Map the predicted classes to the label names
pred_class_names = [kinetics_id_to_classname[int(i)] for i in pred_classes[0]]
print("Predicted labels: %s" % ", ".join(pred_class_names))
```
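The last step turns the model's scores into probabilities with a softmax and keeps the five largest. The same post-processing in plain Python, on a made-up five-class score vector and label map (both purely illustrative; `softmax` and `top_k` are hypothetical helpers), makes the logic explicit. Subtracting the maximum score before exponentiating avoids overflow without changing the result.

```python
import math
from typing import Dict, List


def softmax(scores: List[float]) -> List[float]:
    """Numerically stable softmax: subtract the max score before exponentiating."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probs: List[float], k: int) -> List[int]:
    """Indices of the k largest values, highest first (like tensor.topk(k).indices)."""
    return sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]


# Made-up scores and labels for illustration only (the real model has 400 Kinetics classes)
scores = [3.5, 0.5, 2.0, -1.0, 1.2]
id_to_name: Dict[int, str] = {0: "archery", 1: "bowling", 2: "throwing axe", 3: "juggling", 4: "disc golfing"}

probs = softmax(scores)
best = top_k(probs, k=3)
print(", ".join(id_to_name[i] for i in best))  # archery, throwing axe, disc golfing
```

Because softmax is monotonic, the top-k indices are the same with or without it; the softmax only matters when you want to report probabilities alongside the labels.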
Output:
```
Predicted labels: archery, throwing axe, playing paintball, disc golfing, riding or walking with horse
```