当前位置:网站首页>OneFlow源码解析:算子签名的自动推断
OneFlow源码解析:算子签名的自动推断
2022-06-28 17:55:00 【InfoQ】
# Importing oneflow triggers a series of initialization steps; ignored here for now.
import oneflow as flow
# The tensor implementation is actually quite complex, because it has to unify
# local tensors and distributed (global) tensors.
t = flow.tensor([-1, 0, 1])
# flow.relu is re-exported from oneflow._C, i.e. bound from C++ via pybind11.
r = flow.relu(t)
1、编译环境
# Directory layout: OneFlow source tree and an out-of-source build directory side by side.
# .
# ├── build
# └── oneflow
# Start a build container with both directories mounted (run on the host):
# docker run -itd -v $PWD/oneflow:/mnt/oneflow -v $PWD/build:/mnt/build \
# manylinux2014_x86_64_cuda11.2 bash
cd /mnt/build
cmake -S /mnt/oneflow
cmake --build . # --parallel 8
cd ../oneflow/python
python3 setup.py bdist_wheel
pip install ./dist/oneflow-0.7.0+cpu-cp38-cp38-linux_x86_64.whl
# Debug build, for stepping through the C++ code with gdb.
# The build type is passed as a cache variable (-D). A CMAKE_BUILD_TYPE
# *environment* variable is only honoured by CMake >= 3.22, and only when a new
# build tree is created, so it would be ignored when reconfiguring /mnt/build.
# -B makes the build directory explicit: the current directory at this point is
# /mnt/oneflow/python, not the build directory.
cmake -S /mnt/oneflow -B /mnt/build -DCMAKE_BUILD_TYPE=Debug
cmake --build /mnt/build --parallel 8
# Environment script from the build directory (presumably makes the in-tree
# oneflow importable; confirm against the generated file).
source /mnt/build/source.sh
gdb python3
# Typed at the gdb prompt. oneflow is not loaded yet, so gdb may offer to make
# the breakpoint pending; answer yes.
b oneflow::one::MakeLocalTensorFromData
run
# Then, at the Python prompt started by `run`:
import oneflow as flow
flow.Tensor([[1,2,3],[4,5,6]])
2、Python Binding
# python/oneflow/__init__.py
# relu is re-exported from the _C submodule.
from oneflow._C import relu
# python/oneflow/_C/__init__.py
# _C re-exports everything in the _C submodule of the _oneflow_internal extension.
from oneflow._oneflow_internal._C import *
// C++ side: definition of the _oneflow_internal extension module
// (excerpt; "// ..." marks elided code).
PYBIND11_MODULE(_oneflow_internal, m) {
  // ...
  py::class_<::oneflow::cfg::Message, std::shared_ptr<::oneflow::cfg::Message>>(m, "CfgMessage");
  ::oneflow::cfg::Pybind11ModuleRegistry().ImportAll(m);
  // Runs every builder registered with OneflowModuleRegistry (e.g. the builder
  // registered under "_C") against the module object m.
  ::oneflow::OneflowModuleRegistry().ImportAll(m);
}
using SubModuleMap = std::map<std::string, std::vector<std::function<void(pybind11::module&)>>>;
// Accessor for the process-wide registry that maps a module path (e.g. "_C") to
// the builder callbacks registered for it. The map is a function-local static:
// it is constructed on the first call. That makes it usable from static
// initializers, such as a registrar object's constructor calling Register().
SubModuleMap* GetSubModuleMap() {
  static SubModuleMap builders_by_path{};
  SubModuleMap* const instance = &builders_by_path;
  return instance;
}
// Modifies the map to perform the registration: records `BuildModule` under
// `module_path` in the global sub-module map, and ImportAll() later invokes
// every recorded builder. Several builders may share one path; they are kept
// in registration order (emplace_back).
void OneflowModuleRegistry::Register(std::string module_path,
                                     std::function<void(pybind11::module&)> BuildModule) {
  // Both parameters are taken by value (sink arguments), so move them into the
  // map instead of copying the key string and the type-erased callable again.
  (*GetSubModuleMap())[std::move(module_path)].emplace_back(std::move(BuildModule));
}
// Builds every registered sub-module into the root extension module `m`.
// Module paths are visited in std::map key order. For each path, its builders
// run in the order they were registered.
void OneflowModuleRegistry::ImportAll(pybind11::module& m) {
  const SubModuleMap& registry = *GetSubModuleMap();
  for (const auto& path_and_builders : registry) {
    const std::string& module_path = path_and_builders.first;
    const auto& builders = path_and_builders.second;
    for (const auto& build_module : builders) { BuildSubModule(module_path, m, build_module); }
  }
}
// Runs one registered builder against the root module `m`.
// (Excerpt: "// ..." marks elided code; module_path is not used in the lines shown.)
void OneflowModuleRegistry::BuildSubModule(
    const std::string& module_path, pybind11::module& m,
    const std::function<void(pybind11::module&)>& BuildModule) {
  // ...
  BuildModule(m);
  // ...
}
// Registration code for the "_C" module (apparently generated, note the numeric
// suffix). First the builder is forward-declared.
static void OneflowApiPythonModule9623(pybind11::module&);
namespace {
// The constructor registers the builder under the path "_C". The single
// namespace-scope instance below runs it when this translation unit's globals
// are initialized, i.e. before any Python code calls into the module.
struct OfApiRegistryInit {
  OfApiRegistryInit() {
    ::oneflow::OneflowModuleRegistry().Register("_C", &OneflowApiPythonModule9623);
  }
};
OfApiRegistryInit of_api_registry_init;
}
// Builder for the "_C" module. It binds each Python-visible function to a
// PyFunction instantiation whose template arguments are that function's
// candidate signature schemas.
static void OneflowApiPythonModule9623(pybind11::module & m) {
  // relu has exactly one signature schema.
  m.def("relu", &functional::PyFunction<functional::ReluSchema_TTB>);
  // pow has four candidate schemas; the dispatcher tries them in this order.
  m.def("pow", &functional::PyFunction<
      functional::PowSchema_TTT, functional::ScalarPowSchema_TTScB,
      functional::ScalarPowSchema_TTSc, functional::ScalarReversePowSchema_TScT
  >);
}
3、多个接口签名的自动推断
import oneflow as flow
r = flow.randn(1, 10)
# The exponent is a Python scalar.
flow.pow(r, 2)
# The exponent is a Tensor.
flow.pow(r, flow.ones(1, 10))
// Signature schema generated for relu (one struct per signature).
struct ReluSchema_TTB {
  // C++ type of the operator function this signature maps to.
  using FType = Maybe<one::Tensor> (const std::shared_ptr<one::Tensor>& x, bool inplace);
  // Return type of that function.
  using R = Maybe<one::Tensor>;
  // Function called once this signature matches: functional::Relu.
  static constexpr FType* func = &functional::Relu;
  // Total number of parameters; also the size of the parsed-argument vector.
  static constexpr size_t max_args = 2;
  // Passed to ParseArgs; presumably the number of parameters that may be given
  // positionally (parameters after "*" in a signature would be keyword-only).
  static constexpr size_t max_pos_args = 2;
  // Human-readable signature; presumably what appears in "valid signatures" error messages.
  static constexpr char const* signature = "Tensor (Tensor x, Bool inplace=False)";
  // Argument definitions ParseArgs uses to validate and parse (args, kwargs).
  static FunctionDef function_def;
};
// Tries a list of signature schemas in order. It calls the operator function of
// the first schema whose signature accepts the given Python arguments.
// SchemaT is e.g. ReluSchema_TTB (one type per signature of the Python function).
template<typename... SchemaT>
class PyFunctionDispatcher {
 public:
  // schema_t<I> is the I-th signature schema in SchemaT.
  template<size_t I>
  using schema_t = typename std::tuple_element<I, std::tuple<SchemaT...>>::type;
  // schema_size_ is the number of signatures, e.g. 1 for relu and 4 for pow.
  PyFunctionDispatcher() : schema_size_(sizeof...(SchemaT)) {
    signatures_.resize(schema_size_);
    // Copy each schema's signature string into signatures_.
    InitSignatures(std::make_index_sequence<sizeof...(SchemaT)>{});
  }
  // Checks schema I0 against (args, kwargs). On a match it converts the parsed
  // arguments and calls that schema's function. Otherwise it recurses with the
  // remaining indices, so the schemas are walked at compile time, one index per
  // overload call.
  template<size_t I0, size_t... I>
  py::object call(const py::args& args, const py::kwargs& kwargs,
                  std::index_sequence<I0, I...>) const {
    // T is the signature currently being checked, e.g. ReluSchema_TTB.
    using T = schema_t<I0>;
    std::vector<PythonArg> parsed_args(T::max_args);
    // ParseArgs may raise only when there is a single signature, which gives a
    // precise message such as "relu(): argument 'x' must be tensor, not int".
    // With several signatures, a mismatch just falls through to the next schema.
    if (ParseArgs(args, kwargs, &parsed_args, T::function_def, T::max_pos_args,
                  /*raise_exception*/ schema_size_ == 1)) {
      return detail::unpack_call(*T::func, parsed_args);
    }
    return call(args, kwargs, std::index_sequence<I...>{});
  }
  // Terminal case: reached when no schema matched.
  py::object call(const py::args& args, const py::kwargs& kwargs, std::index_sequence<>) const {
    // throw error ...
    // (Elided in this excerpt; the error presumably lists the entries of signatures_.)
    return py::none();
  }
 private:
  // Sets signatures_[I] = schema_t<I>::signature for every I. The braced
  // dummy-array initializer is the pre-C++17 way to expand a parameter pack into
  // a sequence of side effects (no fold expressions).
  template<size_t... I>
  void InitSignatures(std::index_sequence<I...>) {
    __attribute__((__unused__)) int dummy[] = {
        ((void)(signatures_[I] = schema_t<I>::signature), 0)...};
  }
 private:
  // Number of schemas, sizeof...(SchemaT).
  size_t schema_size_;
  // signatures_[I] holds schema_t<I>::signature.
  std::vector<const char*> signatures_;
};
// Python-facing entry point bound with m.def(); SchemaT is e.g. ReluSchema_TTB.
// Each instantiation owns one dispatcher, a function-local static created on the
// first call. That dispatcher tries every schema in declaration order.
template<typename... SchemaT>
inline py::object PyFunction(const py::args& args, const py::kwargs& kwargs) {
  static PyFunctionDispatcher<SchemaT...> schema_dispatcher;
  using AllSchemaIndices = std::make_index_sequence<sizeof...(SchemaT)>;
  return schema_dispatcher.call(args, kwargs, AllSchemaIndices{});
}
// Python module registration: binds relu and pow to PyFunction instantiations
// parameterized by their signature schemas.
static void OneflowApiPythonModule9623(pybind11::module & m) {
  // Single schema.
  m.def("relu", &functional::PyFunction<functional::ReluSchema_TTB>);
  // Four schemas, tried in the listed order.
  m.def("pow", &functional::PyFunction<
      functional::PowSchema_TTT, functional::ScalarPowSchema_TTScB,
      functional::ScalarPowSchema_TTSc, functional::ScalarReversePowSchema_TScT
  >);
}
- positional与keyword参数类型冲突
- 签名中的keyword参数名在kwargs中不存在且不接受默认值
- 参数类型不符合PythonArgCheck规定的内部类型检查要求
- kwargs包含function_def中未定义的参数
- 将args展开为各个PythonArg元素,通过index_sequence和变长模板参数包的展开实现;
- 利用function_traits推导得到函数参数类型列表ArgsType;
- As函数调用可简化为 `As<typename tuple_element<I, ArgsType>::type>()...`。

核心是拿到各个参数的实际类型并交给As处理,最终调用ObjectAs实现各种内部数据类型的转换。
// Bridge between a parsed Python argument and a C++ parameter
// (excerpt: only the conversion member is shown).
class PythonArg {
 public:
  // Converts this argument to T. cv/ref qualifiers are stripped from T before
  // selecting the ObjectAsHelper specialization, so As<const Foo&>() and
  // As<Foo>() share one converter. The helper's result is unwrapped with
  // GetOrThrow() (presumably throwing when the conversion fails).
  //
  // Must be public: unpack_call_dispatcher calls args[I].As<...>() from outside
  // the class, and members of a `class` are private by default.
  template<typename T>
  T As() const {
    return ObjectAsHelper<oneflow::detail::remove_cvref_t<T>>()(this).GetOrThrow();
  }
};
template<typename F, typename R>
struct unpack_call_dispatcher {
template<size_t... I>
static R apply(const F& f, const std::vector<PythonArg>& args, std::index_sequence<I...>) {
// 这里适当改写了一下,把ArgsType抽出来
using ArgsType = function_traits<F>::args_type;
return f(args[I]
.As<oneflow::detail::remove_cvref_t<typename std::tuple_element<
I, typename ArgsType>::type>>()...);
}
};
// Converts the parsed Python arguments to F's parameter types and invokes f.
// f's return value is converted back into a Python object with CastToPyObject.
template<typename F>
py::object unpack_call(const F& f, const std::vector<PythonArg>& args) {
  using Traits = function_traits<F>;
  using R = typename Traits::return_type;
  // One index per parameter of F; unpack_call_dispatcher expands args[0..nargs).
  using ArgIndices = std::make_index_sequence<Traits::nargs>;
  return CastToPyObject(unpack_call_dispatcher<F, R>::apply(f, args, ArgIndices{}));
}
File ".../oneflow/api/python/functional/py_function.h", line 76, in call
TypeError: pow(): received an invalid combination of arguments. The valid signatures are:
*0: Tensor (Tensor input, Tensor exponent)
*1: Tensor (Tensor input, Scalar exponent, *, Bool inplace=False)
*2: Tensor (Tensor input, Scalar exponent)
*3: Tensor (Scalar exponent, Tensor input)

flow.relu(1)

TypeException:
File ".../oneflow/api/python/functional/py_function.cpp", line 98, in ParseArgs
TypeError: relu(): argument 'x' must be tensor, not int

- PyFunction是pybind11的def定义的入口函数,并为算子保存一个dispatcher对象用于推断合适的签名;
- PyFunctionDispatcher通过模板函数的递归调用实现了签名的自动筛选,通过成员变量为参数校验和异常提示保存必要的信息;
- unpack_call在编译期就确定了具体执行的算子函数类型,这一点在PyFunctionDispatcher中是无法做到的;
- unpack_call_dispatcher的作用是将vector展开为多个元素、作为调用算子函数的参数,这在unpack_call中也是无法做到的;
- PythonArg是Python与C++类型转换的桥梁,同时承担类型检查的职能;
- 基于yaml生成的2组文件中,yaml.pybind.cpp调用pybind11的m.def指定模块调用的函数,并定义了函数签名的Schema结构作为PyFunction的模板参数;yaml.cpp中则定义了具体的执行函数,如Relu。将二者衔接起来的就是Schema的字段func:对于Relu算子来说,签名Schema的func字段就是函数functional::Relu。
4、算子Functor的注册与执行
// Forward declaration of the library-registration function defined below.
static void _oneflow_function_library_0(FunctionLibrary & m);
// The registration function is called by defining a static variable: the
// initializer is an immediately-invoked lambda, so registration runs while this
// translation unit's globals are initialized. The value 0 itself is unused.
static int _oneflow_function_library_dummy_0 = []() {
  FunctionLibrary* library = FunctionLibrary::Global();
  _oneflow_function_library_0(*library);
  return 0;
}();
// Registers impl::ReluFunctor in the global FunctionLibrary under the name "Relu".
// The ';' after the closing brace is an empty declaration (harmless).
void _oneflow_function_library_0(FunctionLibrary & m) {
  m.add_functor<impl::ReluFunctor>("Relu");
};
// Excerpt, not a standalone definition: a creator lambda that default-constructs
// the functor and wraps it with PackedFunctorMaker. Func, func_type and func_name
// come from the enclosing template (presumably FunctionLibrary::add_functor; confirm).
[=]() {
  // Func is e.g. impl::ReluFunctor.
  Func func;
  // func_name comes from the lambda's capture, e.g. "Relu".
  return PackedFunctorMaker<func_type>::make(func_name, func);
}
// Functional entry point for the "Relu" op, defined in the yaml-generated .cpp file.
// It resolves the functor registered under "Relu" with exactly this signature
// and forwards the call to it.
Maybe<one::Tensor> Relu(const std::shared_ptr<one::Tensor>& x, bool inplace) {
  // The lookup result is cached in a thread_local static, so FunctionLibrary::find
  // runs at most once per thread. CHECK_JUST unwraps the Maybe returned by find
  // (presumably failing hard if no functor with this name/signature exists).
  // The local is not named `__op`: identifiers containing a double underscore
  // are reserved for the implementation. NOTE(review): since this file is
  // generated, the same rename belongs in the code generator's template.
  static thread_local const auto& relu_op = CHECK_JUST(
      FunctionLibrary::Global()
          ->find<Maybe<one::Tensor>, const std::shared_ptr<one::Tensor>&, bool>("Relu"));
  return relu_op->call(x, inplace);
}
// Excerpt, not a standalone definition: a creator lambda that default-constructs
// the functor and wraps it with PackedFunctorMaker. Func, func_type and func_name
// come from the enclosing template (presumably FunctionLibrary::add_functor; confirm).
[=]() {
  // Func is e.g. impl::ReluFunctor.
  Func func;
  // func_name comes from the lambda's capture, e.g. "Relu".
  return PackedFunctorMaker<func_type>::make(func_name, func);
}
// func is a variable captured by copy; its type is e.g. impl::ReluFunctor.
// The lambda is not `mutable`, so the captured func is const inside the body
// and the functor's call operator must be callable on a const object.
// std::forward to `const T&` just passes each const lvalue reference through
// unchanged (no move happens).
[func](const remove_cvref_t<Args>&... args) -> R {
  return func(std::forward<const remove_cvref_t<Args>&>(args)...);
}
- 同一个名字可能对应多个Functor。所以不能只用名字作为Functor的key,需要结合签名。
- FunctionLibrary负责管理所有的Functor。但是单例不适合作为模板类,所以通过内嵌的PackedFuncCreatorMap保存签名各异的Functor。
- 每种签名都会特化一个PackedFuncCreatorMap模板类,再通过名字区分不同的Functor。
- 首先,yaml生成的2个cpp文件,都没有Functor信息,只有Relu这个名字、以及Functor的签名信息。Functor是在各个模块根据名字注册的。yaml与FunctionLibrary通过名字和签名进行交互。
- 其次,FunctionLibrary::find返回的PackedFunctor是带模板参数的(参数就是Functor签名)。find能否直接返回Functor对象呢?不能,主要是因为map不便存储不同类型的Functor。即使Functor都有共同的虚基类、map的value存储指针,也不能要求所有Functor的执行接口是一致的,虚函数不满足这个场景的需求。所以find不能直接返回Functor对象。
- PackedFunctor的作用就在于,它把真正的Functor包在自己的结构里面;它的模板参数与Functor的调用接口一致;它的call方法将Op的所有入参通过lambda转发给Functor。
- Functor能直接作为PackedFunctor的成员变量吗?应该是可以的。PackedFunctorMaker::make的模板参数也包含Functor。但是这样每个Functor都要特化一个PackedFunctor,编译后的可执行程序容易膨胀。而现在的实现,PackedFunctor只根据Functor执行函数签名特化,代价是要做一次调用转发(编译器有优化空间?)。
- 从Python到C++调用过程分析
- https://github.com/Oneflow-Inc/oneflow/tree/release/0.7.0
- 深度学习概述
- 一个算子在深度学习框架中的旅程
- 手把手推导分布式矩阵乘的最优并行策略
- 训练千亿参数大模型,离不开四种并行策略
- 解读Pathways(二):向前一步是OneFlow
- 关于并发和并行,Go和Erlang之父都弄错了?
- OneFlow v0.7.0发布:全新分布式接口,LiBai、Serving等一应俱全
边栏推荐
- Introduction to apifox
- From Mogao Grottoes to the Pacific Ocean, massive data have found new homes
- Small program graduation project based on wechat examination small program graduation project opening report function reference
- oracle cdc 但是使用的服务名没有sid 该怎么配置呢?
- Does the dataworks SQL script support if else judgment of statement blocks
- JQ plug-in analysis
- Go 降序排序 取 Top N
- Node foundation ~ node level
- About the solution of "modulenotfounderror: no module named 'flask.\u compat'"
- Mycat+分库分表
猜你喜欢

2022 practice questions and mock examination of Shandong Province safety officer C certificate examination

Squid proxy server application (I came from afar to make an appointment with you)

Database mysql statement final review CTGU

Applet graduation design based on wechat conference room reservation applet graduation design opening report function reference

DNSLog注入

An error is reported when ActiveMQ is started. The 1883 port occupation problem is solved

io模型初探

GCC getting started manual

Visio use

Small program graduation project based on wechat subscription water supply mall small program graduation project opening report function reference
随机推荐
2022年6月27日-2022年7月3日(ue4视频教程)
Why insert is configured with'select last_ INSERT_ What if id() 'returns 0?
Go 降序排序 取 Top N
Industrial digitalization and new generation digitalization system design platform -- Lecture
如何高效优雅地管理接口文档
The MySQL installed in Alibaba cloud server is version 8. Is it because the MySQL driver version of dataworks does not support it? Now mention
Google launches advanced API security to protect APIs from security threats
如何使用 SAP CDS view 中的 currency conversion 功能
Exploration and practice of reinforcement learning in yellow page merchants' intelligent chat assistant
Small program graduation project based on wechat agricultural and sideline products agricultural products mall small program graduation project opening report function reference
Visio use
工业数字化与新一代数字化系统设计平台----讲座
Can data sources only be connected to Alibaba cloud cloud databases? Can't you connect the databases installed in Alibaba cloud servers?
io模型初探
From Mogao Grottoes to the Pacific Ocean, massive data have found new homes
EasyExcel 学习笔记
How much is the data delay when you collect Oracle data? I can't keep it down for 3 seconds. Is there an industry reference
2022 chemical automation control instrument test simulation 100 questions simulation test platform operation
DMS的SQL结果集导出支持传参数吗?
How to upgrade from RHEL 8 to RHEL 9