OneFlow Source Code Walkthrough: Op, Kernel, and Interpreter
2022-08-01 16:12:00 [InfoQ]
1 Op and Kernel Registration
OneFlow registers user ops and kernels through three macros: REGISTER_USER_OP, REGISTER_USER_OP_GRAD, and REGISTER_USER_KERNEL, which register the op itself, its gradient, and its kernels respectively.
1.1 Registration of ReluOp
For relu, the relevant code lives in three places:
- Class definition: build/oneflow/core/framework/op_generated.h
- Op registration and part of the op's implementation: build/oneflow/core/framework/op_generated.cpp
- Main implementation: oneflow/oneflow/user/ops/relu_op.cpp
In op_generated.cpp, the REGISTER_USER_OP macro expands into a static registration statement:
static UserOpRegisterTrigger<OpRegistry> g_register_trigger715 =
    ::oneflow::user_op::UserOpRegistryMgr::Get()
        .CheckAndGetOpRegistry("relu")
        .Input("x")
        .Output("y")
        .SetGetSbpFn(&ReluOp::GetSbp)
        .SetLogicalTensorDescInferFn(&ReluOp::InferLogicalTensorDesc)
        .SetPhysicalTensorDescInferFn(&ReluOp::InferPhysicalTensorDesc)
        .SetDataTypeInferFn(&ReluOp::InferDataType);

The chained calls above configure an OpRegistry. Assigning it to the static UserOpRegisterTrigger variable invokes the trigger's constructor, which calls Finish on the OpRegistry and stores the resulting OpRegistryResult into UserOpRegistryMgr, keyed by the op type name (here, relu). All of this happens during static initialization, before main runs; the sketch below illustrates the idiom.
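Here is a minimal, self-contained sketch of this static-registration idiom. All names in it (Registry, RegisterTrigger, OpResult) are illustrative stand-ins, not OneFlow's actual types.

#include <iostream>
#include <map>
#include <string>

struct OpResult {  // stand-in for OpRegistryResult
  std::string name;
  int num_inputs = 0;
};

class Registry {  // stand-in for UserOpRegistryMgr + OpRegistry
 public:
  static Registry& Get() {
    static Registry instance;
    return instance;
  }
  Registry& Op(const std::string& name) { current_.name = name; return *this; }
  Registry& Input() { ++current_.num_inputs; return *this; }
  OpResult Finish() { return current_; }
  void Store(const OpResult& r) { results_[r.name] = r; }
  const OpResult* Find(const std::string& name) const {
    auto it = results_.find(name);
    return it == results_.end() ? nullptr : &it->second;
  }

 private:
  OpResult current_;
  std::map<std::string, OpResult> results_;
};

// Copy-initializing a trigger from the builder chain runs this constructor
// during static initialization, i.e. before main() starts.
struct RegisterTrigger {
  RegisterTrigger(Registry& reg) { reg.Store(reg.Finish()); }
};

static RegisterTrigger g_register_relu = Registry::Get().Op("relu").Input();

int main() {
  const OpResult* r = Registry::Get().Find("relu");
  if (r) { std::cout << r->name << " has " << r->num_inputs << " input(s)\n"; }
  return 0;
}

The same trick is what lets a framework collect every op into a global registry merely by linking the translation units that define the static trigger objects.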
1.2 Registration of ReluKernel
REGISTER_USER_KERNEL expands in a similar way:
static UserOpRegisterTrigger<OpKernelRegistry> g_register_trigger0 =
    UserOpRegistryMgr::Get()
        .CheckAndGetOpKernelRegistry("relu")
        .SetCreateFn(...)
        .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kRelu, "y", "x"))
        .SetInplaceProposalFn([](const user_op::InferContext&,
                                 const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {
          OF_RETURN_IF_ERROR(AddInplaceArgPairFn("y", 0, "x", 0, true));
          return Maybe<void>::Ok();
        });
SetCreateFn stores the kernel-creation function into result_.create_fn; the ... in SetCreateFn above stands for the following lambda:
[]() {
  return user_op::NewOpKernel<UnaryPrimitiveKernel>(
      "y", "x", [](user_op::KernelComputeContext* ctx) {
        const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("x", 0);
        const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("y", 0);
        return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(
            ctx->device_type(), ep::primitive::UnaryOp::kRelu, src->data_type(),
            dst->data_type());
      });
}
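To see how create_fn and the match predicate set by SetIsMatchedHob fit together, here is a hedged sketch of a kernel registry: each entry stores a factory plus a predicate, and lookup picks the first entry whose predicate accepts the runtime context. The names (KernelRegEntry, CreateKernel, and so on) are hypothetical, not OneFlow's API.

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct KernelContext {      // stand-in for the kernel registration context
  std::string device_type;  // e.g. "cpu" or "cuda"
};

struct OpKernel {           // stand-in for user_op::OpKernel
  virtual ~OpKernel() = default;
  virtual void Compute() = 0;
};

struct CpuReluKernel : OpKernel {
  void Compute() override { std::cout << "relu on cpu\n"; }
};

struct KernelRegEntry {
  std::function<std::unique_ptr<OpKernel>()> create_fn;  // what SetCreateFn stores
  std::function<bool(const KernelContext&)> is_matched;  // what SetIsMatchedHob stores
};

// Registry of kernels for one op; real code keys this by op type name.
std::vector<KernelRegEntry>& ReluKernels() {
  static std::vector<KernelRegEntry> entries;
  if (entries.empty()) {
    KernelRegEntry cpu_entry;
    cpu_entry.create_fn = [] { return std::unique_ptr<OpKernel>(new CpuReluKernel()); };
    cpu_entry.is_matched = [](const KernelContext& ctx) { return ctx.device_type == "cpu"; };
    entries.push_back(std::move(cpu_entry));
  }
  return entries;
}

// Lookup: the first entry whose predicate matches the context wins.
std::unique_ptr<OpKernel> CreateKernel(const KernelContext& ctx) {
  for (const auto& entry : ReluKernels()) {
    if (entry.is_matched(ctx)) { return entry.create_fn(); }
  }
  return nullptr;  // no kernel registered for this device
}

int main() {
  KernelContext ctx{"cpu"};
  if (auto kernel = CreateKernel(ctx)) { kernel->Compute(); }
  return 0;
}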
1.3 Class relationships of op and kernel registration

2 Constructing UserOpExpr
When relu is called from Python, execution enters functional::Relu. The functional layer uses find("Relu") to look up the registered PackedFunctor<impl::ReluFunctor>; its call method then invokes impl::ReluFunctor:
class ReluFunctor {
 public:
  ReluFunctor() { op_ = CHECK_JUST(one::OpBuilder("relu").Input("x", 1).Output("y", 1).Build()); }
  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x, bool inplace) const {
    // inplace-related logic omitted
    return OpInterpUtil::Dispatch<Tensor>(*op_, {x});
  }

 private:
  std::shared_ptr<OpExpr> op_;
};
relu is a user op, so the Build() call on OpBuilder ultimately constructs a UserOpExpr.
Besides the op configuration, a UserOpExpr caches several fields derived from the registration result, including:
- base_attrs_
- tensor_desc_infer_fn_
- dtype_infer_fn_
- device_and_stream_infer_fn_
The Input/Output calls on OpBuilder fill in a UserOpConf. During Build(), the UserOpConf is completed with default attributes according to OpRegistryResult::op_def, and the finished UserOpConf is used to construct the UserOpExpr. UserOpExpr inherits from BuiltinOpExprImpl<UserOpConf>, which keeps the UserOpConf in its op_proto_ field, so op_proto_ carries the op configuration for the lifetime of the UserOpExpr. A simplified sketch of this builder flow follows.
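The sketch below assumes a single attribute with a default value; OpBuilderLite, OpExprLite, and OpDef are hypothetical stand-ins for OpBuilder, UserOpExpr, and OpRegistryResult::op_def.

#include <iostream>
#include <map>
#include <string>
#include <utility>

struct UserOpConf {  // stand-in for the UserOpConf proto
  std::string op_type_name;
  std::map<std::string, int> inputs;   // arg name -> count
  std::map<std::string, int> outputs;  // arg name -> count
  double alpha = 0.0;                  // an attribute with a registered default
};

struct OpDef { double default_alpha = 1.0; };  // stand-in for OpRegistryResult::op_def

struct OpExprLite {  // stand-in for UserOpExpr; keeps the conf as op_proto_
  explicit OpExprLite(UserOpConf conf) : op_proto_(std::move(conf)) {}
  UserOpConf op_proto_;
};

class OpBuilderLite {
 public:
  explicit OpBuilderLite(std::string op_type) { conf_.op_type_name = std::move(op_type); }
  OpBuilderLite& Input(const std::string& name, int n) { conf_.inputs[name] = n; return *this; }
  OpBuilderLite& Output(const std::string& name, int n) { conf_.outputs[name] = n; return *this; }
  OpExprLite Build(const OpDef& op_def) {
    // Complete unset attributes from the registered op definition.
    if (conf_.alpha == 0.0) { conf_.alpha = op_def.default_alpha; }
    return OpExprLite(conf_);
  }

 private:
  UserOpConf conf_;
};

int main() {
  OpDef relu_def;
  OpExprLite op = OpBuilderLite("relu").Input("x", 1).Output("y", 1).Build(relu_def);
  std::cout << op.op_proto_.op_type_name << " alpha=" << op.op_proto_.alpha << "\n";
  return 0;
}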
3 Executing the Functor
The heart of ReluFunctor is the call to OpInterpUtil::Dispatch, which hands the OpExpr and the inputs over to an interpreter.

3.1 Choosing an interpreter based on environment and inputs
Dispatch eventually reaches an OpExprInterpreter. GetInterpreter selects the Interpreter according to the execution mode and the inputs (a minimal selection sketch follows the list):
- LazyInterpreter: for the distributed static-graph execution mode under lazy mode
- EagerLocalInterpreter: for eager local mode, i.e. local single-device execution (aligned with PyTorch single-device or DDP)
- EagerGlobalInterpreter: for eager global mode, the distributed dynamic-graph execution mode
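A minimal sketch of such a selection, assuming the choice depends only on whether lazy mode is on and whether the inputs are global tensors; the class bodies are placeholders, not OneFlow's implementation.

#include <iostream>
#include <memory>

struct Interpreter { virtual ~Interpreter() = default; virtual const char* Name() const = 0; };
struct LazyInterp : Interpreter { const char* Name() const override { return "LazyInterpreter"; } };
struct EagerLocalInterp : Interpreter { const char* Name() const override { return "EagerLocalInterpreter"; } };
struct EagerGlobalInterp : Interpreter { const char* Name() const override { return "EagerGlobalInterpreter"; } };

// Selection: lazy mode wins; otherwise global inputs pick the global
// interpreter; plain local tensors fall through to the local one.
std::shared_ptr<Interpreter> ChooseInterpreter(bool lazy_mode, bool inputs_are_global) {
  if (lazy_mode) { return std::make_shared<LazyInterp>(); }
  if (inputs_are_global) { return std::make_shared<EagerGlobalInterp>(); }
  return std::make_shared<EagerLocalInterp>();
}

int main() {
  std::cout << ChooseInterpreter(false, false)->Name() << "\n";  // EagerLocalInterpreter
  std::cout << ChooseInterpreter(false, true)->Name() << "\n";   // EagerGlobalInterpreter
  return 0;
}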

What GetInterpreter actually returns is an AutogradInterpreter wrapping one of the interpreters above. AutogradInterpreter::Apply takes care of autograd bookkeeping and then delegates to the wrapped interpreter's Apply, as sketched below.
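A hedged sketch of that wrapping, with simplified, hypothetical signatures:

#include <iostream>
#include <memory>

struct InnerInterpreter {
  void Apply() const { std::cout << "run op via inner interpreter\n"; }
};

class AutogradInterpreterLite {
 public:
  explicit AutogradInterpreterLite(std::shared_ptr<InnerInterpreter> inner)
      : inner_(std::move(inner)) {}
  void Apply(bool requires_grad) const {
    inner_->Apply();  // forward computation first
    if (requires_grad) {
      std::cout << "record backward node for autograd\n";  // placeholder bookkeeping
    }
  }

 private:
  std::shared_ptr<InnerInterpreter> inner_;
};

int main() {
  AutogradInterpreterLite interp(std::make_shared<InnerInterpreter>());
  interp.Apply(/*requires_grad=*/true);
  return 0;
}

This is the classic decorator pattern: autograd concerns are layered on top without the inner interpreters having to know about them.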
3.2 Apply
Maybe<void> EagerInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& inputs,
                                    TensorTuple* outputs, const OpExprInterpContext& ctx) const {
#define APPLY_IF(op_type)                                              \
  if (const auto* op = dynamic_cast<const op_type##Expr*>(&op_expr)) { \
    return ApplyImpl(*op, inputs, outputs, ctx);                       \
  }

  APPLY_IF(UserOp);
  APPLY_IF(VariableOp);
  APPLY_IF(CastToLocalOp);
  APPLY_IF(CastFromLocalOp);
  APPLY_IF(GlobalToGlobalOp);
  APPLY_IF(CastToGlobalOp);
  APPLY_IF(CastFromGlobalOp);
  APPLY_IF(DistributeSplitOp);
  APPLY_IF(DistributeCloneOp);
  APPLY_IF(DistributeConcatOp);
  APPLY_IF(DistributeAddOp);
  APPLY_IF(FunctionOp);
  APPLY_IF(SelectTopNOp)
#undef APPLY_IF

  OF_UNIMPLEMENTED() << "The type " << op_expr.op_type_name()
                     << " has not been supported in EagerInterpreter::Apply.";
}

For a UserOpExpr, APPLY_IF(UserOp) expands to:
if (const auto* op = dynamic_cast<const UserOpExpr*>(&op_expr)) {
return ApplyImpl(*op, inputs, outputs, ctx);
}
EagerLocalInterpreter::ApplyImpl for a UserOpExpr simply forwards to NaiveInterpret:

Maybe<void> EagerLocalInterpreter::ApplyImpl(const UserOpExpr& op_expr, const TensorTuple& inputs,
                                             TensorTuple* outputs,
                                             const OpExprInterpContext& ctx) const {
  return NaiveInterpret(op_expr, inputs, outputs, ctx);
}
3.3 NaiveInterpret
NaiveInterpret mainly does the following:
- checks that the input tensors are on a consistent device
- creates the output tensors
- infers and checks shape/stride/dtype for the output tensors
- builds the op-execution instruction and dispatches it to the VM
Maybe<void> NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs,
                           const Symbol<Device>& default_device, TensorTuple* outputs,
                           const OpExprInterpContext& ctx) {
  const auto& attrs = ctx.attrs;
  // check that the input tensors are on the same device
  ...
  // infer the device of the output tensors
  // Infer devices
  if (!user_op_expr.has_device_and_stream_infer_fn()) {
    stream = JUST(GetDefaultStreamByDevice(default_device));
    for (int i = 0; i < outputs->size(); i++) {
      auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i)));
      *JUST(tensor_impl->mut_device()) = default_device;
    }
  } else {
    need_check_mem_case = false;
    stream = JUST(user_op_expr.InferDeviceAndStream(attrs, inputs, outputs));
  }

  // infer output tensor shapes and data types
  // Infer shapes and dtypes
  const auto& device_tag = stream->device()->type();
  JUST(user_op_expr.InferPhysicalTensorDesc(
      attrs, device_tag,
      [&](int32_t i) -> const TensorMeta* {
        return CHECK_JUST(TensorImpl4Tensor(inputs[i]))->mut_tensor_meta();
      },
      [&](int32_t i) -> TensorMeta* {
        // using thread_local TensorMeta pointer if inplace.
        // using tensor_impl TensorMeta pointer if not inplace.
        return output_tensor_metas->at(i);
      }));

  // initialize eager_blob_object for each output tensor
  for (int i = 0; i < output_eager_blob_objects->size(); i++) {
    auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i)));
    if (!output_eager_blob_objects->at(i)) {
      if (!JUST(user_op_expr.SupportNonContiguous())) {
        std::shared_ptr<Stride> stride(new Stride(*tensor_impl->shape()));
        tensor_impl->mut_tensor_meta()->set_stride(stride);
      }
      const auto& dep_object = NewLocalDepObject();
      JUST(tensor_impl->InitEagerBlobObject(dep_object));
      output_eager_blob_objects->at(i) = JUST(tensor_impl->eager_blob_object());
    } else {
      // output i is inplaced.
      // check thread_local TensorMeta and tensor_impl TensorMeta.
      CHECK_OR_RETURN(tensor_impl->tensor_meta()->shape() == output_tensor_metas->at(i)->shape());
      CHECK_OR_RETURN(tensor_impl->tensor_meta()->dtype() == output_tensor_metas->at(i)->dtype());
    }
  }

  // fetch the kernel from user_op_expr
  const auto& kernel = JUST(user_op_expr.MutKernel4Stream(stream));
  kernel->set_need_check_mem_case(need_check_mem_case);

  for (int64_t index : kernel->output_tuple_indexes4mut2_obns()) {
    output_eager_blob_objects->at(index)->set_is_shape_synced(false);
  }

  // dispatch the kernel to the VM; actual scheduling and execution happen later
  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
    return builder->Call(kernel, input_eager_blob_objects, output_eager_blob_objects, ctx, stream);
  }));
  return Maybe<void>::Ok();
}
builder->Call constructs an OpCall instruction; the VM schedules and executes it later, and only then does the kernel actually run. A simplified sketch of this decoupled dispatch follows.
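A minimal sketch of dispatch-then-execute, assuming a simple FIFO queue; OpCallLite and VmLite are hypothetical stand-ins for OneFlow's OpCall instruction and virtual machine.

#include <functional>
#include <iostream>
#include <queue>
#include <string>

struct OpCallLite {  // stand-in for the OpCall instruction
  std::string kernel_name;
  std::function<void()> run;
};

class VmLite {  // stand-in for the virtual machine
 public:
  void Receive(OpCallLite instr) { pending_.push(std::move(instr)); }
  void Schedule() {  // normally driven by a dedicated scheduler thread
    while (!pending_.empty()) {
      pending_.front().run();
      pending_.pop();
    }
  }

 private:
  std::queue<OpCallLite> pending_;
};

int main() {
  VmLite vm;
  // "builder->Call(...)" reduces to building an instruction and sending it off:
  vm.Receive(OpCallLite{"relu", [] { std::cout << "execute relu kernel\n"; }});
  vm.Schedule();  // execution happens later, decoupled from dispatch
  return 0;
}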
References
- OneFlow学习笔记:Op注册 (OneFlow study notes: op registration): https://mp.weixin.qq.com/s/eF-c2irraxnH4iAesURy0Q
- 从Functor到OpExprInterpreter (From Functor to OpExprInterpreter): https://zhuanlan.zhihu.com/p/523884650
- OneFlow v0.8.1 source code: https://github.com/Oneflow-Inc/oneflow/tree/v0.8.1