Current location: Home > CANN operators: using iterators to efficiently slice and partition tensor data into blocks
CANN operators: using iterators to efficiently slice and partition tensor data into blocks
2022-07-04 19:53:00 【InfoQ】
Task scenario and objectives
Conventional approach:
Background knowledge and analysis
1. Strides
2. Iterators
/// Odometer-style iterator over every coordinate of an N-dimensional index
/// space. Positions are visited in row-major order: the last dimension varies
/// fastest and carries into earlier dimensions.
///
/// Typical use:
///   for (PositionIterator<int64_t> it(start, shape); !it.End(); ++it) {
///     auto pos = it.GetPos();
///   }
///
/// An iterator built from invalid arguments is created already exhausted, so
/// End() is immediately true and the loop body never runs. Invalid arguments
/// are: a rank mismatch between start and shape, an empty shape, or any start
/// coordinate outside [0, shape[i]). The last case includes every shape that
/// has a zero-sized dimension. In that case GetPos()/GetShape() return empty
/// vectors.
template <typename T>
class PositionIterator {
 public:
  /// Creates an exhausted iterator (End() is true).
  PositionIterator() = default;

  /// @param stt starting coordinate; must have the same rank as sh.
  /// @param sh  extent of each dimension.
  PositionIterator(std::vector<T> stt, std::vector<T> sh)
      : pos_(std::move(stt)), shape_(std::move(sh)) {
    if (!IsValid()) {
      // Put the iterator in the exhausted state instead of silently iterating
      // over out-of-range positions (or dividing by a zero extent).
      pos_.clear();
      shape_.clear();
    }
  }

  /// Advances to the next position. After the last position, pos_[0] equals
  /// shape_[0] (the other coordinates are 0) and End() becomes true. Further
  /// increments are no-ops.
  PositionIterator &operator++() {
    if (End()) {
      return *this;
    }
    for (std::size_t i = shape_.size(); i-- > 0;) {
      ++pos_[i];
      if (pos_[i] < shape_[i] || i == 0) {
        break;  // no carry needed, or dimension 0 overflowed: iteration ends
      }
      pos_[i] = 0;  // carry into dimension i - 1
    }
    return *this;
  }

  /// True once every position has been visited, or if the iterator is invalid.
  bool End() const { return pos_.empty() || pos_[0] >= shape_[0]; }

  /// Current coordinate (one entry per dimension).
  std::vector<T> GetPos() const { return pos_; }
  /// Extent of each dimension being iterated.
  std::vector<T> GetShape() const { return shape_; }

 private:
  /// Same rank, non-empty, and every start coordinate within [0, shape[i]).
  bool IsValid() const {
    if (shape_.empty() || pos_.size() != shape_.size()) {
      return false;
    }
    for (std::size_t i = 0; i < shape_.size(); ++i) {
      if (pos_[i] < T{} || pos_[i] >= shape_[i]) {
        return false;
      }
    }
    return true;
  }

  std::vector<T> pos_;
  std::vector<T> shape_;
};
Implementation of the Diagonal operator

/**
 * @brief Copies the diagonal selected by (offset, dim1, dim2) from input 0 (x)
 *        into output 0 (y).
 *
 * A 2-D input is a single (dim1, dim2) plane, and its diagonal is written to
 * y[0, dsize). For any other rank, PositionIterator walks every coordinate of
 * the remaining ("outer") dimensions. The diagonal of each plane is written
 * along the last dimension of y, whose shape is the outer shape followed by
 * dsize.
 *
 * @param ctx    kernel context providing Input(0) and Output(0).
 * @param offset diagonal offset, forwarded to OffsetSize and get_data. Per the
 *               original article, get_data uses it to take elements above or
 *               below the main diagonal.
 * @param dim1   first dimension of the diagonal plane. Presumably already
 *               validated/normalized to [0, rank) by the caller; TODO confirm.
 * @param dim2   second dimension of the diagonal plane (same assumption).
 * @return KERNEL_STATUS_OK.
 */
template <typename T>
uint32_t DiagonalCpuKernel::DoComputeType(CpuKernelContext &ctx,
                                          const int64_t &offset,
                                          const int64_t &dim1,
                                          const int64_t &dim2) {
  // Get the input and output tensors.
  Tensor *input_x = ctx.Input(0);
  Tensor *y = ctx.Output(0);
  // Shape information of the input.
  auto x_shape = input_x->GetTensorShape();
  std::vector<int64_t> x_shape_ = x_shape->GetDimSizes();
  const int64_t x_dim = x_shape->GetDims();
  auto dataptr = reinterpret_cast<T *>(input_x->GetData());
  auto y_dataptr = reinterpret_cast<T *>(y->GetData());

  // Number of diagonal elements in each (dim1, dim2) plane.
  const int64_t dsize = OffsetSize(offset, dim1, dim2, x_shape_);
  if (dsize <= 0) {
    // Nothing to copy. Returning here also avoids indexing x_stride with
    // dim1/dim2 below when no element would be read.
    return KERNEL_STATUS_OK;
  }

  // Stride vector of the input tensor x.
  const std::vector<int64_t> x_stride = ConstructStride<int64_t>(x_shape_);
  // Loop-invariant values, hoisted out of the per-element loops. One step along
  // the diagonal advances one element in both dim1 and dim2.
  const int64_t diag_stride = x_stride[dim1] + x_stride[dim2];
  int64_t arr[N2] = {x_stride[dim1], x_stride[dim2]};

  // Copies the dsize diagonal elements of the plane whose first element sits at
  // flat offset in_base in x into y[out_base, out_base + dsize).
  auto copy_diagonal = [&](int64_t in_base, int64_t out_base) {
    for (int64_t i = 0; i < dsize; ++i) {
      y_dataptr[out_base + i] =
          get_data(in_base + i * diag_stride, offset, arr, dataptr);
    }
  };

  if (x_dim == N2) {
    // 2-D input: x is a single plane and y is the diagonal itself.
    copy_diagonal(0, 0);
    return KERNEL_STATUS_OK;
  }

  // Other ranks: drop dim1 and dim2 from x's shape and stride. This gives the
  // outer index space (vx_shape) and the matching input strides (vx_stride).
  std::vector<int64_t> vx_shape;
  std::vector<int64_t> vx_stride;
  const int64_t rank = static_cast<int64_t>(x_shape_.size());
  for (int64_t d = 0; d < rank; ++d) {
    if (d != dim1 && d != dim2) {
      vx_shape.push_back(x_shape_[d]);
      vx_stride.push_back(x_stride[d]);
    }
  }
  // The output shape is the outer shape followed by the diagonal length. The
  // outer output strides (vy_stride) are y's strides without the last one.
  std::vector<int64_t> y_shape = vx_shape;
  y_shape.push_back(dsize);
  std::vector<int64_t> vy_stride = ConstructStride<int64_t>(y_shape);
  vy_stride.pop_back();

  // Visit every outer coordinate. MulSum combines it with the stride vectors
  // to get the base offsets of the plane in x and of its diagonal run in y.
  std::vector<int64_t> v_start(vx_shape.size(), 0);
  for (PositionIterator<int64_t> myiter(v_start, vx_shape); !myiter.End();
       ++myiter) {
    auto p = myiter.GetPos();
    copy_diagonal(MulSum<int64_t>(p, vx_stride), MulSum<int64_t>(p, vy_stride));
  }
  return KERNEL_STATUS_OK;
}
Other uses of iterators
// Example 1 (excerpt): for every position of the outer index space described
// by v_start/v_shape, gather the 1-D slice of x that runs along tmp_axis into
// data_. Each value is paired with its index along that axis.
// NOTE(review): this excerpt spells the API position_iterator / end() /
// get_pos() / mul_sum. The PositionIterator class earlier in this article uses
// End() / GetPos() / MulSum. This may be an older spelling of the same API;
// confirm before reusing.
// NOTE(review): the closing brace of this outer loop is missing. The rest of
// its body appears to have been cut from the excerpt.
for (position_iterator<int64_t> mit(v_start, v_shape); !mit.end(); ++mit) {
// Coordinate of the current slice in the outer index space.
auto p = mit.get_pos();
// Number of elements along the gathered axis.
int axis_len = input_shape_[tmp_axis];
std::vector<ValueIndex<T>> data_(axis_len);
// Flat offset of the slice's first element in x, computed from p and
// v_stride (presumably their element-wise product summed; TODO confirm).
int base_pos = mul_sum<int64_t>(p, v_stride);
// Consecutive elements of the slice are input_stride[tmp_axis] apart in x.
for (int32_t i = 0; i < axis_len; i++) {
data_[i].value = x_dataptr[base_pos + i * input_stride[tmp_axis]];
data_[i].index = i;
}
// Example 2 (separate excerpt): for each i in [0, dim0), collect the element
// at offset i * input_stride[axis] from every outer position into tmp_v1.
// Each tmp_v1 is then appended to data_.
// NOTE(review): because Example 1's loop is never closed, this declaration of
// data_ lands in the same scope as the one above. The two excerpts therefore
// do not compile together as written.
std::vector<std::vector<T1>> data_;
for (int64_t i = 0; i < dim0; i++) {
std::vector<T1> tmp_v1;
for (PositionIterator<int64_t> mit(v_start, v_shape); !mit.End(); ++mit) {
auto pos = mit.GetPos();
tmp_v1.push_back(
x_dataptr[MulSum<int64_t>(pos, v_stride) + i * input_stride[axis]]);
}
data_.push_back(tmp_v1);
}
边栏推荐
- 需求开发思考
- Kotlin cycle control
- c# . Net MVC uses Baidu ueditor rich text box to upload files (pictures, videos, etc.)
- 1002. A+B for Polynomials (25) (PAT Advanced Level)
- 求2的n次方
- ABC229 summary (longest run of consecutive characters in an interval; counting the connected components of a graph)
- 多表操作-内连接查询
- "Only one trip", active recommendation and exploration of community installation and maintenance tasks
- BCG 使用之CBCGPTabWnd控件(相当于MFC TabControl)
- 如何使用Async-Await异步任務處理代替BackgroundWorker?
猜你喜欢
There are multiple divs in the large div, which are displayed on the same line. After overflow, scroll bars are generated without line breaks
Creation of JVM family objects
Explore the contour drawing function drawcontours() of OpenCV in detail with practical examples
HMM隐马尔可夫模型最详细讲解与代码实现
C语言-入门-基础-语法-流程控制(七)
MySQL database basics: DDL | Dark Horse Programmer
记一次 .NET 某工控数据采集平台 线程数 爆高分析
Online sql to excel (xls/xlsx) tool
TCP两次挥手,你见过吗?那四次握手呢?
Chrome开发工具:VMxxx文件是什么鬼
随机推荐
The company needs to be monitored. How do ZABBIX and Prometheus choose? That's the right choice!
abc229 总结(区间最长连续字符 图的联通分量计数)
Pythagorean triples: the conditions three numbers must satisfy to form a Pythagorean triple
1011 World Cup Betting (20 points) (PAT Advanced Level)
有关架构设计的个人思考(本文后续不断修改更新)
LM10 cosine-wave trend-following grid strategy
1009 Product of Polynomials(25 分)(PAT甲级)
Using BCG's CBCGPProgressDlgCtrl progress bar
记一次 .NET 某工控数据采集平台 线程数 爆高分析
HMM hidden Markov model and code implementation
Crawler (6) - Web page data parsing (2) | the use of beautifulsoup4 in Crawlers
Master the use of auto analyze in data warehouse
How test engineers "attack the city" (Part I)
Educational Codeforces Round 22 E. Army Creation
1002. A+B for Polynomials (25)(PAT甲级)
【问题】druid报异常sql injection violation, part alway true condition not allow 解决方案
【毕业季】绿蚁新醅酒,红泥小火炉。晚来天欲雪,能饮一杯无?
求2的n次方
Comment utiliser le traitement de tâches asynchrones async/await au lieu de BackgroundWorker ?
勾股数规律(任意三个数能够满足勾股定理需要满足的条件)