OneFlow Source Code Analysis: Op, Kernel, and the Interpreter
2022-08-01 16:12:00 【InfoQ】
1 Registration of Ops and Kernels
Ops and kernels register themselves through the REGISTER_USER_OP, REGISTER_USER_OP_GRAD, and REGISTER_USER_KERNEL macros.

1.1 Registration of ReluOp

The code for ReluOp is spread across three locations:
- class definition: build/oneflow/core/framework/op_generated.h
- op registration and part of the op implementation: build/oneflow/core/framework/op_generated.cpp
- main implementation: oneflow/oneflow/user/ops/relu_op.cpp
REGISTER_USER_OP expands in op_generated.cpp to a static registration trigger:

```cpp
static UserOpRegisterTrigger<OpRegistry> g_register_trigger715 =
    ::oneflow::user_op::UserOpRegistryMgr::Get()
        .CheckAndGetOpRegistry("relu")
        .Input("x")
        .Output("y")
        .SetGetSbpFn(&ReluOp::GetSbp)
        .SetLogicalTensorDescInferFn(&ReluOp::InferLogicalTensorDesc)
        .SetPhysicalTensorDescInferFn(&ReluOp::InferPhysicalTensorDesc)
        .SetDataTypeInferFn(&ReluOp::InferDataType);
```
When the static UserOpRegisterTrigger object is constructed, it calls OpRegistry's Finish and saves the resulting OpRegistryResult into UserOpRegistryMgr under the op name "relu".

1.2 Registration of ReluKernel
Similarly, REGISTER_USER_KERNEL expands to:

```cpp
static UserOpRegisterTrigger<OpKernelRegistry> g_register_trigger0 =
    UserOpRegistryMgr::Get()
        .CheckAndGetOpKernelRegistry("relu")
        .SetCreateFn(...)
        .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kRelu, "y", "x"))
        .SetInplaceProposalFn([](const user_op::InferContext&,
                                 const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {
          OF_RETURN_IF_ERROR(AddInplaceArgPairFn("y", 0, "x", 0, true));
          return Maybe<void>::Ok();
        });
```

The lambda stored in result_.create_fn constructs the kernel:

```cpp
[]() {
  return user_op::NewOpKernel<UnaryPrimitiveKernel>(
      "y", "x", [](user_op::KernelComputeContext* ctx) {
        const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("x", 0);
        const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("y", 0);
        return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(
            ctx->device_type(), ep::primitive::UnaryOp::kRelu, src->data_type(),
            dst->data_type());
      });
}
```

1.3 Class diagram of Op and Kernel registration

2 Construction of UserOpExpr
functional::Relu looks up (find("Relu")) the PackedFunctor<impl::ReluFunctor> and calls it, which ends up invoking impl::ReluFunctor:

```cpp
class ReluFunctor {
 public:
  ReluFunctor() {
    op_ = CHECK_JUST(one::OpBuilder("relu").Input("x", 1).Output("y", 1).Build());
  }
  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x, bool inplace) const {
    // inplace-related logic omitted
    return OpInterpUtil::Dispatch<Tensor>(*op_, {x});
  }

 private:
  std::shared_ptr<OpExpr> op_;
};
```

For a user op, OpBuilder's Build() produces a UserOpExpr, which contains, among other fields:
- base_attrs_
- tensor_desc_infer_fn_
- dtype_infer_fn_
- device_and_stream_infer_fn_
OpBuilder's Input/Output calls record the op's inputs and outputs in a UserOpConf; Build() completes the UserOpConf according to OpRegistryResult::op_def and constructs the UserOpExpr from it. UserOpExpr inherits from BuiltinOpExprImpl<UserOpConf>, whose op_proto_ field stores this UserOpConf, so the UserOpExpr carries the full op configuration.

3 Execution of the Functor
ReluFunctor hands execution over to OpInterpUtil::Dispatch.
3.1 Choosing an interpreter based on environment and inputs
GetInterpreter selects an OpExprInterpreter. There are three interpreters:
- LazyInterpreter: for lazy mode, the distributed static-graph execution mode
- EagerLocalInterpreter: for eager local mode, local single-device execution (aligned with PyTorch single-device or DDP)
- EagerGlobalInterpreter: for eager global mode, the distributed dynamic-graph execution mode

GetInterpreter actually returns an AutogradInterpreter, whose Apply wraps the underlying interpreter's Apply.

3.2 Apply
EagerInterpreter::Apply dispatches on the concrete OpExpr type:

```cpp
Maybe<void> EagerInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& inputs,
                                    TensorTuple* outputs, const OpExprInterpContext& ctx) const {
#define APPLY_IF(op_type)                                              \
  if (const auto* op = dynamic_cast<const op_type##Expr*>(&op_expr)) { \
    return ApplyImpl(*op, inputs, outputs, ctx);                       \
  }

  APPLY_IF(UserOp);
  APPLY_IF(VariableOp);
  APPLY_IF(CastToLocalOp);
  APPLY_IF(CastFromLocalOp);
  APPLY_IF(GlobalToGlobalOp);
  APPLY_IF(CastToGlobalOp);
  APPLY_IF(CastFromGlobalOp);
  APPLY_IF(DistributeSplitOp);
  APPLY_IF(DistributeCloneOp);
  APPLY_IF(DistributeConcatOp);
  APPLY_IF(DistributeAddOp);
  APPLY_IF(FunctionOp);
  APPLY_IF(SelectTopNOp)
#undef APPLY_IF

  OF_UNIMPLEMENTED() << "The type " << op_expr.op_type_name()
                     << " has not been supported in EagerInterpreter::Apply.";
}
```

For a UserOpExpr, APPLY_IF(UserOp) expands to:

```cpp
if (const auto* op = dynamic_cast<const UserOpExpr*>(&op_expr)) {
  return ApplyImpl(*op, inputs, outputs, ctx);
}
```

EagerLocalInterpreter's ApplyImpl simply forwards to NaiveInterpret:

```cpp
Maybe<void> EagerLocalInterpreter::ApplyImpl(const UserOpExpr& op_expr, const TensorTuple& inputs,
                                             TensorTuple* outputs,
                                             const OpExprInterpContext& ctx) const {
  return NaiveInterpret(op_expr, inputs, outputs, ctx);
}
```

3.3 NaiveInterpret
- check that the input tensors are on the same device
- create the output tensors
- infer and check shape/stride/dtype for the output tensors
- build the op execution instruction and dispatch it to the VM
```cpp
Maybe<void> NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs,
                           const Symbol<Device>& default_device, TensorTuple* outputs,
                           const OpExprInterpContext& ctx) {
  const auto& attrs = ctx.attrs;
  // Check that the input tensors are on the same device
  ...
  // Infer the output tensors' device types
  // Infer devices
  if (!user_op_expr.has_device_and_stream_infer_fn()) {
    stream = JUST(GetDefaultStreamByDevice(default_device));
    for (int i = 0; i < outputs->size(); i++) {
      auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i)));
      *JUST(tensor_impl->mut_device()) = default_device;
    }
  } else {
    need_check_mem_case = false;
    stream = JUST(user_op_expr.InferDeviceAndStream(attrs, inputs, outputs));
  }

  // Infer the output tensors' shapes and data types
  // Infer shapes and dtypes
  const auto& device_tag = stream->device()->type();
  JUST(user_op_expr.InferPhysicalTensorDesc(
      attrs, device_tag,
      [&](int32_t i) -> const TensorMeta* {
        return CHECK_JUST(TensorImpl4Tensor(inputs[i]))->mut_tensor_meta();
      },
      [&](int32_t i) -> TensorMeta* {
        // using thread_local TensorMeta pointer if inplace.
        // using tensor_impl TensorMeta pointer if not inplace.
        return output_tensor_metas->at(i);
      }));

  // Initialize eager_blob_object for the output tensors
  for (int i = 0; i < output_eager_blob_objects->size(); i++) {
    auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i)));
    if (!output_eager_blob_objects->at(i)) {
      if (!JUST(user_op_expr.SupportNonContiguous())) {
        std::shared_ptr<Stride> stride(new Stride(*tensor_impl->shape()));
        tensor_impl->mut_tensor_meta()->set_stride(stride);
      }
      const auto& dep_object = NewLocalDepObject();
      JUST(tensor_impl->InitEagerBlobObject(dep_object));
      output_eager_blob_objects->at(i) = JUST(tensor_impl->eager_blob_object());
    } else {
      // output i is inplaced.
      // check thread_local TensorMeta and tensor_impl TensorMeta.
      CHECK_OR_RETURN(tensor_impl->tensor_meta()->shape() == output_tensor_metas->at(i)->shape());
      CHECK_OR_RETURN(tensor_impl->tensor_meta()->dtype() == output_tensor_metas->at(i)->dtype());
    }
  }

  // Fetch the kernel from user_op_expr
  const auto& kernel = JUST(user_op_expr.MutKernel4Stream(stream));
  kernel->set_need_check_mem_case(need_check_mem_case);

  for (int64_t index : kernel->output_tuple_indexes4mut2_obns()) {
    output_eager_blob_objects->at(index)->set_is_shape_synced(false);
  }

  // Dispatch the kernel to the VM; the actual scheduling and execution happen later
  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
    return builder->Call(kernel, input_eager_blob_objects, output_eager_blob_objects, ctx, stream);
  }));
  return Maybe<void>::Ok();
}
```

builder->Call is where the OpCall instruction is built for the VM.

References
- OneFlow study notes: Op registration (https://mp.weixin.qq.com/s/eF-c2irraxnH4iAesURy0Q)
- From Functor to OpExprInterpreter
- https://github.com/Oneflow-Inc/oneflow/tree/v0.8.1
- https://zhuanlan.zhihu.com/p/523884650
- OneFlow v0.8.0正式发布
- 解读Pathways:向前一步是OneFlow
- 18张图,直观理解神经网络、流形和拓扑
- OneFlow源码解析:算子签名的自动推断
- 分布式深度学习编程新范式:Global Tensor
- LLVM之父:为什么我们要重建AI基础设施软件
- 大模型训练难?效率超群、易用的“李白”模型库来了
边栏推荐
猜你喜欢
随机推荐
设计专业第一台笔记本 华硕灵耀Pro16 2022 新品首发超值入手
在网站页脚增加几只戏水的小鱼
Ranking of itineraries (summer vacation daily question 12)
Kubernetes 进阶训练营 控制器
月薪12K,蝶变向新勇往直前,我通过转行软件测试实现月薪翻倍...
DOM series of touch screen events
Go unit tests
南京科技大学、中国电子科技第28研究所等联合|MLRIP: Pre-training a military language representation model with informative factual knowledge and professional knowledge base(预训练具有丰富事实知识和专业知识库的军事语言表示模型)
2.8K 120Hz触控双屏加持 灵耀X 双屏Pro 2022让办公无惧想象
mysql 面试题
Spark: Cluster Computing with Working Sets
DOM系列之触屏事件
kubelet节点压力驱逐
珠海市生物安全P3实验室主体结构封顶
Ant discloses the open source layout of core basic software technology for the first time
ESP8266-Arduino编程实例-GA1A12S202对数刻度模拟光传感器
moxa串口服务器配置说明(moxa串口驱动)
请问下怎么取数据库中上一个小时的数据到odps进行实时节点的同步呢
2022年7月最热的10篇AI论文
80篇国产数据库实操文档汇总(含TiDB、达梦、openGauss等)









