OneFlow Source Code Walkthrough: Op, Kernel, and Interpreter
2022-08-01 16:12:00 [InfoQ]
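1 Registering Ops and Kernels
OneFlow registers ops and kernels through three macros: REGISTER_USER_OP, REGISTER_USER_OP_GRAD and REGISTER_USER_KERNEL. They register an op, its gradient, and a kernel, respectively.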
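1.1 Registering ReluOp
The code for ReluOp is spread across three files:
- Class definition: build/oneflow/core/framework/op_generated.h
- Op registration and part of the op implementation: build/oneflow/core/framework/op_generated.cpp
- Main implementation: oneflow/oneflow/user/ops/relu_op.cpp
In op_generated.cpp, the REGISTER_USER_OP macro for relu expands to: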
static UserOpRegisterTrigger<OpRegistry> g_register_trigger715 =
::oneflow::user_op::UserOpRegistryMgr::Get()
.CheckAndGetOpRegistry("relu")
.Input("x")
.Output("y")
.SetGetSbpFn(&ReluOp::GetSbp)
.SetLogicalTensorDescInferFn(&ReluOp::InferLogicalTensorDesc)
.SetPhysicalTensorDescInferFn(&ReluOp::InferPhysicalTensorDesc)
.SetDataTypeInferFn(&ReluOp::InferDataType);

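The chained calls configure an OpRegistry. Its Finish() produces an OpRegistryResult. The UserOpRegisterTrigger constructor then registers that result with UserOpRegistryMgr under the op type name "relu". Because g_register_trigger715 is a static object, all of this happens during static initialization, before main() runs.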
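The static-trigger pattern above can be illustrated with a small, self-contained sketch. It follows the same idea: a fluent builder, a Finish() that yields a result, and a global trigger whose constructor stores that result under the op type name. All names here are hypothetical stand-ins and are not OneFlow's real classes. The sketch also keeps only one inference function instead of the four the real registry accepts.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for OpRegistryResult: what one op registration produces.
struct OpRegistryResultSketch {
  std::string op_type_name;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
  std::function<std::string(const std::string&)> data_type_infer_fn;
};

// Hypothetical stand-in for OpRegistry: a fluent builder.
class OpRegistrySketch {
 public:
  explicit OpRegistrySketch(std::string name) { result_.op_type_name = std::move(name); }
  OpRegistrySketch& Input(const std::string& n) { result_.inputs.push_back(n); return *this; }
  OpRegistrySketch& Output(const std::string& n) { result_.outputs.push_back(n); return *this; }
  OpRegistrySketch& SetDataTypeInferFn(std::function<std::string(const std::string&)> fn) {
    result_.data_type_infer_fn = std::move(fn);
    return *this;
  }
  // Validates the builder state and returns the finished result.
  OpRegistryResultSketch Finish() const {
    if (result_.outputs.empty()) throw std::logic_error("op must declare an output");
    return result_;
  }

 private:
  OpRegistryResultSketch result_;
};

// Hypothetical stand-in for UserOpRegistryMgr: a process-wide map keyed by op type name.
class RegistryMgrSketch {
 public:
  static RegistryMgrSketch& Get() {
    static RegistryMgrSketch mgr;  // function-local static avoids init-order problems
    return mgr;
  }
  void Register(OpRegistryResultSketch result) {
    std::string key = result.op_type_name;
    results_[key] = std::move(result);
  }
  const OpRegistryResultSketch* Find(const std::string& name) const {
    auto it = results_.find(name);
    return it == results_.end() ? nullptr : &it->second;
  }

 private:
  std::map<std::string, OpRegistryResultSketch> results_;
};

// Hypothetical stand-in for UserOpRegisterTrigger: constructing it registers the op.
struct RegisterTriggerSketch {
  explicit RegisterTriggerSketch(const OpRegistrySketch& registry) {
    RegistryMgrSketch::Get().Register(registry.Finish());
  }
};

// Analogous to g_register_trigger715: defining the global is enough to register "relu".
static RegisterTriggerSketch g_relu_trigger(
    OpRegistrySketch("relu").Input("x").Output("y").SetDataTypeInferFn(
        [](const std::string& in_dtype) { return in_dtype; }));  // relu keeps the dtype
```

Any code can then query RegistryMgrSketch::Get().Find("relu") without ever naming the trigger object. This mirrors the way the registration macro makes an op visible to the rest of the framework.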
1.2 Registering ReluKernel
The REGISTER_USER_KERNEL macro for relu expands to:
static UserOpRegisterTrigger<OpKernelRegistry> g_register_trigger0 =
UserOpRegistryMgr::Get()
.CheckAndGetOpKernelRegistry("relu")
.SetCreateFn(...)
.SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kRelu, "y", "x"))
.SetInplaceProposalFn([](const user_op::InferContext&,
const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {
OF_RETURN_IF_ERROR(AddInplaceArgPairFn("y", 0, "x", 0, true));
return Maybe<void>::Ok();
});
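SetCreateFn(...) stores the kernel factory in result_.create_fn. For relu, the factory is the following lambda. It creates a UnaryPrimitiveKernel backed by an ElementwiseUnary primitive for UnaryOp::kRelu: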
[]() {
return user_op::NewOpKernel<UnaryPrimitiveKernel>(
"y", "x", [](user_op::KernelComputeContext* ctx) {
const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("x", 0);
const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("y", 0);
return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(
ctx->device_type(), ep::primitive::UnaryOp::kRelu, src->data_type(),
dst->data_type());
});
}
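The kernel side can be sketched the same way. Each registration pairs a match predicate (the role played by SetIsMatchedHob) with a factory (the role of result_.create_fn). In this sketch, lookup creates a kernel from the first registration whose predicate accepts the current device and data type. All names are hypothetical. The hand-written CPU loop stands in for the ep::primitive ElementwiseUnary primitive that the real kernel uses.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <vector>

// Hypothetical context used to decide whether a kernel matches.
struct KernelMatchCtxSketch {
  std::string device_type;  // e.g. "cpu" or "cuda"
  std::string data_type;    // e.g. "float32"
};

// Hypothetical kernel interface.
class OpKernelSketch {
 public:
  virtual ~OpKernelSketch() = default;
  virtual void Compute(const std::vector<float>& x, std::vector<float>* y) const = 0;
};

// CPU relu kernel: y = max(x, 0) element-wise.
class CpuReluKernelSketch final : public OpKernelSketch {
 public:
  void Compute(const std::vector<float>& x, std::vector<float>* y) const override {
    y->resize(x.size());
    std::transform(x.begin(), x.end(), y->begin(), [](float v) { return std::max(v, 0.0f); });
  }
};

// Hypothetical registration record: match predicate + factory.
struct KernelRegistryResultSketch {
  std::string op_type_name;
  std::function<bool(const KernelMatchCtxSketch&)> is_matched;
  std::function<std::unique_ptr<OpKernelSketch>()> create_fn;
};

// Creates a kernel from the first registration for op_type_name that accepts ctx.
std::unique_ptr<OpKernelSketch> CreateKernelSketch(
    const std::vector<KernelRegistryResultSketch>& registry, const std::string& op_type_name,
    const KernelMatchCtxSketch& ctx) {
  for (const auto& reg : registry) {
    if (reg.op_type_name == op_type_name && reg.is_matched(ctx)) { return reg.create_fn(); }
  }
  return nullptr;  // no kernel registered for this op/device/dtype combination
}

// A registry holding a single CPU float32 relu kernel.
std::vector<KernelRegistryResultSketch> MakeReluRegistrySketch() {
  return {KernelRegistryResultSketch{
      "relu",
      [](const KernelMatchCtxSketch& ctx) {
        return ctx.device_type == "cpu" && ctx.data_type == "float32";
      },
      [] { return std::unique_ptr<OpKernelSketch>(new CpuReluKernelSketch()); }}};
}
```

Keeping the factory separate from the predicate means no kernel object is constructed until a matching one is actually needed.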
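1.3 Class relationships in Op and Kernel registration
[Figure: class diagram of the types involved in op and kernel registration]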

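2 Constructing the UserOpExpr
functional::Relu uses find("Relu") to obtain the pre-registered PackedFunctor<impl::ReluFunctor> and invokes its call method. That call eventually reaches impl::ReluFunctor: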
class ReluFunctor {
 public:
  ReluFunctor() { op_ = CHECK_JUST(one::OpBuilder("relu").Input("x", 1).Output("y", 1).Build()); }
  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x, bool inplace) const {
    // inplace-related logic omitted
    return OpInterpUtil::Dispatch<Tensor>(*op_, {x});
  }

 private:
  std::shared_ptr<OpExpr> op_;
};
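The lookup path (functional::Relu → find("Relu") → packed functor → ReluFunctor) can be sketched as a name-keyed table of type-erased functors. The functor is constructed once when registered, just as ReluFunctor builds op_ once in its constructor. Every call after that only runs it. FunctionLibrarySketch, ReluFunctorSketch and this Relu function are hypothetical. The sketch also assumes a single fixed signature, whereas OneFlow's library supports many.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

using TensorSketch = std::vector<float>;  // hypothetical stand-in for a tensor
using ReluSignature = std::function<TensorSketch(const TensorSketch&, bool)>;

// Hypothetical function library: name -> type-erased functor (like PackedFunctor).
class FunctionLibrarySketch {
 public:
  static FunctionLibrarySketch& Global() {
    static FunctionLibrarySketch lib;
    return lib;
  }
  // Constructs Functor once and stores it type-erased.
  template <typename Functor>
  void Add(const std::string& name) { functors_[name] = Functor(); }
  const ReluSignature& find(const std::string& name) const {
    auto it = functors_.find(name);
    if (it == functors_.end()) throw std::out_of_range("functor not found: " + name);
    return it->second;
  }

 private:
  std::map<std::string, ReluSignature> functors_;
};

// Hypothetical ReluFunctor: the constructor does the one-time setup.
class ReluFunctorSketch {
 public:
  ReluFunctorSketch() { ++constructed; }  // counts one-time setups
  TensorSketch operator()(const TensorSketch& x, bool /*inplace, ignored here*/) const {
    TensorSketch y(x);
    for (float& v : y) v = v > 0.0f ? v : 0.0f;
    return y;
  }
  static inline int constructed = 0;  // C++17 inline static member
};

// Hypothetical functional::Relu: finds the functor once, then just calls it.
TensorSketch Relu(const TensorSketch& x, bool inplace) {
  static const ReluSignature& op = FunctionLibrarySketch::Global().find("Relu");
  return op(x, inplace);
}
```

After FunctionLibrarySketch::Global().Add<ReluFunctorSketch>("Relu"), repeated calls to Relu(...) reuse the same functor instance. The op-building cost is therefore paid only once.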
Here op_ is a user op: OpBuilder's Build() constructs a UserOpExpr. Among other members, UserOpExpr holds:
- base_attrs_
- tensor_desc_infer_fn_
- dtype_infer_fn_
- device_and_stream_infer_fn_
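OpBuilder's Input/Output calls fill in a UserOpConf. Build() then completes that conf (for example, with default attribute values) according to OpRegistryResult::op_def and constructs the UserOpExpr. The UserOpConf is stored in op_proto_, a member of UserOpExpr's base class BuiltinOpExprImpl<UserOpConf>. The UserOpExpr then uses op_proto_ to look up the registered op and initialize the members listed above.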
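The builder-to-expression step can be sketched as follows. Input/Output record argument names and counts in a conf, and Build() checks that conf against a registered op definition. The resulting expression then keeps the conf together with the inference function taken from the registry. Every name here is hypothetical. The validation rules belong to this sketch and do not describe OneFlow's exact checks.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

using ShapeInferFn = std::function<std::vector<int>(const std::vector<int>&)>;

// Hypothetical registered op definition (stand-in for op_def + its infer function).
struct OpDefSketch {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
  ShapeInferFn shape_infer_fn;
};

// Hypothetical stand-in for UserOpConf: op type name + arg name -> count.
struct UserOpConfSketch {
  std::string op_type_name;
  std::map<std::string, int> inputs;
  std::map<std::string, int> outputs;
};

// Hypothetical stand-in for UserOpExpr: the conf (like op_proto_) plus copied infer fn.
struct UserOpExprSketch {
  UserOpConfSketch op_proto;
  ShapeInferFn shape_infer_fn;
};

// Hypothetical registry containing only relu, whose output shape equals its input shape.
const std::map<std::string, OpDefSketch>& OpDefTableSketch() {
  static const std::map<std::string, OpDefSketch> table = {
      {"relu", {{"x"}, {"y"}, [](const std::vector<int>& s) { return s; }}}};
  return table;
}

class OpBuilderSketch {
 public:
  explicit OpBuilderSketch(std::string op_type_name) {
    conf_.op_type_name = std::move(op_type_name);
  }
  OpBuilderSketch& Input(const std::string& name, int count) {
    conf_.inputs[name] = count;
    return *this;
  }
  OpBuilderSketch& Output(const std::string& name, int count) {
    conf_.outputs[name] = count;
    return *this;
  }
  // Checks the conf against the registered definition, then creates the expr.
  UserOpExprSketch Build() const {
    const auto& table = OpDefTableSketch();
    auto it = table.find(conf_.op_type_name);
    if (it == table.end()) throw std::invalid_argument("unregistered op: " + conf_.op_type_name);
    const OpDefSketch& def = it->second;
    for (const auto& name : def.inputs)
      if (!conf_.inputs.count(name)) throw std::invalid_argument("missing input: " + name);
    for (const auto& name : def.outputs)
      if (!conf_.outputs.count(name)) throw std::invalid_argument("missing output: " + name);
    return UserOpExprSketch{conf_, def.shape_infer_fn};
  }

 private:
  UserOpConfSketch conf_;
};
```

Copying the inference function into the expression at build time means later executions can infer output metadata without consulting the registry again.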
3 Executing the Functor
ReluFunctor runs the op by calling OpInterpUtil::Dispatch.

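3.1 Choosing an interpreter based on the environment and inputs
Dispatch first calls GetInterpreter to select an OpExprInterpreter. There are three kinds of interpreter: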
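- LazyInterpreter: for lazy mode, i.e. distributed static-graph execution
- EagerLocalInterpreter: for eager local mode, i.e. local single-device execution (aligned with PyTorch single-GPU or DDP)
- EagerGlobalInterpreter: for eager global mode, i.e. distributed dynamic-graph execution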

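GetInterpreter does not return these interpreters directly. It returns an AutogradInterpreter that wraps the selected one. AutogradInterpreter::Apply handles the autograd-related work and then calls the wrapped interpreter's Apply.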
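The selection among the three interpreters can be sketched as a pure function of the execution mode and the inputs. In this simplified reading, lazy mode always selects the lazy interpreter. In eager mode, global inputs select the global interpreter and local inputs select the local one. Two details are hypothetical assumptions of this sketch: the has_placement flag, used for ops with no inputs, and the rejection of mixed local/global inputs. The enum and struct names are hypothetical as well.

```cpp
#include <cassert>
#include <stdexcept>
#include <vector>

enum class InterpreterKind { kLazy, kEagerLocal, kEagerGlobal };

// Hypothetical per-input information: is the tensor a global tensor?
struct TensorInfoSketch {
  bool is_global;
};

// Chooses an interpreter kind; has_placement only matters when there are no inputs.
InterpreterKind SelectInterpreterSketch(bool lazy_mode, const std::vector<TensorInfoSketch>& inputs,
                                        bool has_placement) {
  if (lazy_mode) return InterpreterKind::kLazy;  // static-graph building
  if (inputs.empty()) {
    return has_placement ? InterpreterKind::kEagerGlobal : InterpreterKind::kEagerLocal;
  }
  bool any_global = false;
  bool any_local = false;
  for (const auto& t : inputs) {
    if (t.is_global) {
      any_global = true;
    } else {
      any_local = true;
    }
  }
  if (any_global && any_local) throw std::invalid_argument("inputs mix local and global tensors");
  return any_global ? InterpreterKind::kEagerGlobal : InterpreterKind::kEagerLocal;
}
```

In the source, the chosen interpreter is additionally wrapped in an AutogradInterpreter. This sketch models only the choice itself.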
3.2 Apply
Maybe<void> EagerInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& inputs,
                                    TensorTuple* outputs, const OpExprInterpContext& ctx) const {
#define APPLY_IF(op_type)                                              \
  if (const auto* op = dynamic_cast<const op_type##Expr*>(&op_expr)) { \
    return ApplyImpl(*op, inputs, outputs, ctx);                       \
  }

  APPLY_IF(UserOp);
  APPLY_IF(VariableOp);
  APPLY_IF(CastToLocalOp);
  APPLY_IF(CastFromLocalOp);
  APPLY_IF(GlobalToGlobalOp);
  APPLY_IF(CastToGlobalOp);
  APPLY_IF(CastFromGlobalOp);
  APPLY_IF(DistributeSplitOp);
  APPLY_IF(DistributeCloneOp);
  APPLY_IF(DistributeConcatOp);
  APPLY_IF(DistributeAddOp);
  APPLY_IF(FunctionOp);
  APPLY_IF(SelectTopNOp);
#undef APPLY_IF

  OF_UNIMPLEMENTED() << "The type " << op_expr.op_type_name()
                     << " has not been supported in EagerInterpreter::Apply.";
}
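For a UserOpExpr, APPLY_IF(UserOp) expands to the following, which forwards to the ApplyImpl overload for UserOpExpr:
if (const auto* op = dynamic_cast<const UserOpExpr*>(&op_expr)) {
  return ApplyImpl(*op, inputs, outputs, ctx);
}
In eager local mode, that ApplyImpl is EagerLocalInterpreter::ApplyImpl, which simply calls NaiveInterpret:
Maybe<void> EagerLocalInterpreter::ApplyImpl(const UserOpExpr& op_expr, const TensorTuple& inputs,
                                             TensorTuple* outputs,
                                             const OpExprInterpContext& ctx) const {
  return NaiveInterpret(op_expr, inputs, outputs, ctx);
}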
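The APPLY_IF technique (a chain of dynamic_cast checks generated by a token-pasting macro, each forwarding to a type-specific ApplyImpl overload) can be reproduced in miniature. The *Stub types and EagerInterpreterSketch are hypothetical, and only two expression types are handled. Unlike the source, where ApplyImpl is virtual and overridden by concrete interpreters, the overloads here are plain private members.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Hypothetical expression hierarchy.
struct OpExprStub {
  virtual ~OpExprStub() = default;
  virtual std::string op_type_name() const = 0;
};
struct UserOpExprStub final : OpExprStub {
  std::string op_type_name() const override { return "user_op"; }
};
struct VariableOpExprStub final : OpExprStub {
  std::string op_type_name() const override { return "variable"; }
};
struct FunctionOpExprStub final : OpExprStub {  // deliberately not handled below
  std::string op_type_name() const override { return "function"; }
};

class EagerInterpreterSketch {
 public:
  // Tries each concrete type in turn; the first successful cast picks the overload.
  std::string Apply(const OpExprStub& op_expr) const {
#define APPLY_IF(op_type)                                                  \
  if (const auto* op = dynamic_cast<const op_type##ExprStub*>(&op_expr)) { \
    return ApplyImpl(*op);                                                 \
  }
    APPLY_IF(UserOp);
    APPLY_IF(VariableOp);
#undef APPLY_IF
    throw std::logic_error("The type " + op_expr.op_type_name() + " has not been supported");
  }

 private:
  std::string ApplyImpl(const UserOpExprStub&) const { return "ApplyImpl(UserOpExpr)"; }
  std::string ApplyImpl(const VariableOpExprStub&) const { return "ApplyImpl(VariableOpExpr)"; }
};
```

The macro keeps the dispatch list short to read. The trade-off is a linear sequence of runtime casts, so the most common expression type (UserOp) is listed first.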
3.3 NaiveInterpret
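NaiveInterpret mainly does the following:
- checks that all input tensors are on the same device
- creates the output tensors
- infers and checks the shape/stride/dtype of the output tensors
- builds the op execution instruction and dispatches it to the VM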
Maybe<void> NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs,
                           const Symbol<Device>& default_device, TensorTuple* outputs,
                           const OpExprInterpContext& ctx) {
  const auto& attrs = ctx.attrs;
  // Check that all input tensors are on the same device
  ...
  // Infer the device of the output tensors
  if (!user_op_expr.has_device_and_stream_infer_fn()) {
    stream = JUST(GetDefaultStreamByDevice(default_device));
    for (int i = 0; i < outputs->size(); i++) {
      auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i)));
      *JUST(tensor_impl->mut_device()) = default_device;
    }
  } else {
    need_check_mem_case = false;
    stream = JUST(user_op_expr.InferDeviceAndStream(attrs, inputs, outputs));
  }

  // Infer the shape and dtype of the output tensors
  const auto& device_tag = stream->device()->type();
  JUST(user_op_expr.InferPhysicalTensorDesc(
      attrs, device_tag,
      [&](int32_t i) -> const TensorMeta* {
        return CHECK_JUST(TensorImpl4Tensor(inputs[i]))->mut_tensor_meta();
      },
      [&](int32_t i) -> TensorMeta* {
        // using thread_local TensorMeta pointer if inplace.
        // using tensor_impl TensorMeta pointer if not inplace.
        return output_tensor_metas->at(i);
      }));

  // Initialize the eager_blob_object of each output tensor
  for (int i = 0; i < output_eager_blob_objects->size(); i++) {
    auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i)));
    if (!output_eager_blob_objects->at(i)) {
      if (!JUST(user_op_expr.SupportNonContiguous())) {
        std::shared_ptr<Stride> stride(new Stride(*tensor_impl->shape()));
        tensor_impl->mut_tensor_meta()->set_stride(stride);
      }
      const auto& dep_object = NewLocalDepObject();
      JUST(tensor_impl->InitEagerBlobObject(dep_object));
      output_eager_blob_objects->at(i) = JUST(tensor_impl->eager_blob_object());
    } else {
      // output i is inplaced.
      // check thread_local TensorMeta and tensor_impl TensorMeta.
      CHECK_OR_RETURN(tensor_impl->tensor_meta()->shape() == output_tensor_metas->at(i)->shape());
      CHECK_OR_RETURN(tensor_impl->tensor_meta()->dtype() == output_tensor_metas->at(i)->dtype());
    }
  }

  // Fetch the kernel from user_op_expr
  const auto& kernel = JUST(user_op_expr.MutKernel4Stream(stream));
  kernel->set_need_check_mem_case(need_check_mem_case);
  for (int64_t index : kernel->output_tuple_indexes4mut2_obns()) {
    output_eager_blob_objects->at(index)->set_is_shape_synced(false);
  }

  // Dispatch the kernel to the VM, where it waits to be scheduled and executed
  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
    return builder->Call(kernel, input_eager_blob_objects, output_eager_blob_objects, ctx, stream);
  }));
  return Maybe<void>::Ok();
}
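builder->Call packages the kernel, together with its input and output eager blob objects, into an OpCall instruction. The VM schedules and executes that instruction later.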
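The four steps listed for NaiveInterpret can be modeled in a toy setting. Output metadata (device, shape, dtype) is produced eagerly. The computation itself is only packaged as an instruction and queued, and the data appears when the queue is drained. TensorSketch, VirtualMachineSketch and NaiveInterpretReluSketch are hypothetical. A std::deque of closures stands in for OneFlow's VM and its OpCall instructions.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <functional>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical tensor: metadata is known eagerly, data is filled in later.
struct TensorSketch {
  std::string device;
  std::vector<int> shape;
  std::string dtype;
  std::vector<float> data;
};

// Hypothetical VM: dispatch only enqueues; RunAll() executes in FIFO order.
class VirtualMachineSketch {
 public:
  void Push(std::function<void()> instruction) { queue_.push_back(std::move(instruction)); }
  void RunAll() {
    while (!queue_.empty()) {
      auto instruction = std::move(queue_.front());
      queue_.pop_front();
      instruction();
    }
  }
  std::size_t pending() const { return queue_.size(); }

 private:
  std::deque<std::function<void()>> queue_;
};

// Hypothetical NaiveInterpret for relu, following the four steps above.
void NaiveInterpretReluSketch(const std::vector<const TensorSketch*>& inputs,
                              TensorSketch* output, VirtualMachineSketch* vm) {
  if (inputs.empty()) throw std::invalid_argument("relu needs an input");
  // Step 1: all inputs must be on the same device.
  for (const TensorSketch* t : inputs) {
    if (t->device != inputs.front()->device) throw std::invalid_argument("inputs on different devices");
  }
  const TensorSketch* x = inputs.front();
  // Steps 2-3: set up the output and infer its metadata (relu keeps device/shape/dtype).
  output->device = x->device;
  output->shape = x->shape;
  output->dtype = x->dtype;
  output->data.clear();
  // Step 4: package the computation as an instruction and hand it to the VM.
  // Both tensors must stay alive until the VM runs the instruction.
  vm->Push([x, output] {
    output->data.resize(x->data.size());
    for (std::size_t i = 0; i < x->data.size(); ++i) {
      output->data[i] = x->data[i] > 0.0f ? x->data[i] : 0.0f;
    }
  });
}
```

Right after the call, the output's shape and dtype are already usable while its data is still empty. This separation lets the interpreter return quickly and leaves the actual computation to the VM.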