Skip to content

Commit 777d9c8

Browse files
committed
feat: 优化Pad算子
1 parent f344ae8 commit 777d9c8

File tree

16 files changed

+426
-196
lines changed

16 files changed

+426
-196
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#ifndef KERNEL_CUDA_PAD_CUH
#define KERNEL_CUDA_PAD_CUH

#include "threads_distributer.cuh"
#include <cstdint>

namespace refactor::kernel::cuda {

    /// Per-axis layout descriptor consumed by the pad kernel.
    /// All quantities are measured as set up by the host-side `PadInfo`:
    /// - strideI/strideO: input/output strides in units of blocks;
    /// - padS: number of leading pad steps along this axis;
    /// - dimI: input extent along this axis.
    struct DimInfo {
        unsigned int strideI, strideO, padS, dimI;
    };

    /// Launches the pad kernel on the stream carried by the launch parameters.
    ///
    /// \param src       input tensor data (device memory)
    /// \param src_const one block holding the constant pad value (device memory)
    /// \param dims      per-axis descriptors, `rank` entries (device memory)
    /// \param output    destination tensor data (device memory)
    /// \param rank      number of entries in `dims`
    /// \param blockSize bytes copied per output block
    void launchPad(
        KernelLaunchParameters const &,
        uint8_t const *src, uint8_t const *src_const,
        DimInfo const *dims, void *output,
        unsigned int rank,
        unsigned int blockSize);

}// namespace refactor::kernel::cuda

#endif// KERNEL_CUDA_PAD_CUH

src/04kernel/cuda/src/pad.cu

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#include "kernel/cuda/pad.cuh"
2+
#include "macro.cuh"
3+
#include <cstdint>
4+
5+
namespace refactor::kernel::cuda {
6+
7+
/// Pad kernel: each logical iteration handles one output block of
/// `blockSize` bytes. The output block index `tid` is decomposed into
/// per-axis coordinates via `strideO`; an axis coordinate outside
/// [padS, padS + dimI) means the block lies in the padding and is filled
/// from `src_const`, otherwise it is copied from the corresponding input
/// block at offset `j`. Works with any grid/block configuration
/// (grid-stride loop).
__global__ static void padKernel(
    unsigned long long n,
    uint8_t const *__restrict__ src,
    uint8_t const *__restrict__ src_const,
    DimInfo const *__restrict__ dims,
    uint8_t *__restrict__ dst,
    unsigned int rank,
    unsigned int blockSize) {
    // Fix: compute the grid-stride index in 64 bits. The original
    // `auto tid = blockIdx.x * blockDim.x + threadIdx.x` deduced
    // `unsigned int`, which wraps for n > 2^32 (the comparison `tid < n`
    // then never becomes false once promoted, and large indices are
    // never reached).
    auto step = static_cast<unsigned long long>(blockDim.x) * gridDim.x;
    for (auto tid = static_cast<unsigned long long>(blockIdx.x) * blockDim.x + threadIdx.x;
         tid < n;
         tid += step) {
        // 64-bit on all platforms (`long` is 32-bit on LLP64/Windows).
        long long rem = static_cast<long long>(tid), j = 0;
        bool flag = false;
        for (unsigned int i = 0; i < rank; ++i) {
            auto strideO = __ldg(&(dims[i].strideO));
            auto strideI = __ldg(&(dims[i].strideI));
            auto padS = __ldg(&(dims[i].padS));
            auto dimI = __ldg(&(dims[i].dimI));
            // Coordinate along axis i, shifted back by the leading pads;
            // negative or >= dimI means this block is inside the padding.
            auto pos = rem / strideO - padS;
            if (pos < 0 || pos >= dimI) {
                flag = true;
                break;
            }
            j += pos * strideI;
            rem %= strideO;
        }
        if (flag) {
            optimizedMemcpy(dst + tid * blockSize, src_const, blockSize);
        } else {
            optimizedMemcpy(dst + tid * blockSize, src + j * blockSize, blockSize);
        }
    }
}
41+
42+
/// Host wrapper: launches `padKernel` with the grid/block sizes and
/// stream supplied in `params`.
/// NOTE(review): no cudaGetLastError() after the launch — launch-config
/// errors surface only at the next synchronizing call; consider adding a
/// check if the surrounding codebase adopts one.
void launchPad(
    KernelLaunchParameters const &params,
    uint8_t const *src, uint8_t const *src_const,
    DimInfo const *dims, void *output,
    unsigned int rank,
    unsigned int blockSize) {
    auto stream = reinterpret_cast<cudaStream_t>(params.stream);
    auto dst = reinterpret_cast<uint8_t *>(output);
    padKernel<<<params.gridSize, params.blockSize, 0, stream>>>(
        params.n,
        src,
        src_const,
        dims,
        dst,
        rank,
        blockSize);
}
63+
64+
}// namespace refactor::kernel::cuda

src/04kernel/include/kernel/attributes/pad_info.h

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,23 +37,28 @@ namespace refactor::kernel {
3737
}
3838
};
3939

40-
using PadsShape = absl::InlinedVector<int64_t, 4>;
40+
namespace pad {
41+
struct Dim {
42+
int64_t dimI, dimO, pads;
43+
};
44+
}// namespace pad
4145

46+
using PadDimension = std::vector<pad::Dim>;
4247

4348
struct PadInfo {
44-
int rank;
45-
PadType mode;
46-
PadsShape pads;
47-
PadsShape wholeNDim;
48-
PadsShape partNDim;
49-
PadsShape partStride;
50-
DataType type;
51-
bool have_value;
52-
size_t size;
53-
54-
explicit PadInfo(PadsShape, PadType, Tensor const &, Tensor const &, bool) noexcept;
55-
};
49+
struct Dim {
50+
dim_t strideI, strideO, padS, dimI;
51+
52+
// bool operator==(Dim const &) const noexcept;
53+
// bool operator!=(Dim const &) const noexcept;
54+
};
55+
std::vector<Dim> dims;
56+
dim_t blockCount, blockSize;
5657

58+
PadInfo(decltype(dims), dim_t, dim_t) noexcept;
59+
PadInfo(PadDimension, Tensor const &);
60+
void reform(dim_t) noexcept;
61+
};
5762

5863
}// namespace refactor::kernel
5964

src/04kernel/include/kernel/collectors/pad.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
namespace refactor::kernel {
88

99
struct PadCollector final : public InfoCollector {
10-
PadsShape pads;
10+
PadDimension dims;
1111
PadType mode;
1212

13-
explicit PadCollector(decltype(_target) target, PadsShape const &pads_, PadType mode_) noexcept
14-
: InfoCollector(target), pads(std::move(pads_)), mode(mode_) {}
13+
explicit PadCollector(decltype(_target) target, PadDimension const &dims_, PadType mode_) noexcept
14+
: InfoCollector(target), dims(std::move(dims_)), mode(mode_) {}
1515

1616
std::vector<KernelBox>
1717
filter(TensorRefs inputs, TensorRefs outputs) const final;
Lines changed: 79 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,88 @@
11
#include "kernel/attributes/pad_info.h"
2-
#include <iostream>
32
#include <numeric>
43

54
namespace refactor::kernel {
5+
using PI = PadInfo;
66

7-
PadInfo::PadInfo(
8-
PadsShape pads_,
9-
PadType mode_,
10-
Tensor const &x,
11-
Tensor const &y,
12-
bool have_value_) noexcept : rank(x.rank()), mode(mode_), pads(std::move(pads_)), wholeNDim(rank, 0),
13-
partNDim(rank, 0), partStride(rank, 1), type(x.dataType), have_value(have_value_),
14-
size(0) {
15-
int64_t p = 1;
16-
for (auto i = rank - 1; i >= 0; --i) {
17-
wholeNDim[i] = y.shape[i];
18-
partNDim[i] = x.shape[i];
19-
partStride[i] = p;
20-
p = p * partNDim[i];
7+
// bool PI::Dim::operator==(Dim const &rhs) const noexcept {
8+
// return strideI == rhs.strideI &&
9+
// strideO == rhs.strideO &&
10+
// padStride == rhs.padStride &&
11+
// dimt.dimI == rhs.dimI &&;
12+
// }
13+
// bool PI::Dim::operator!=(Dim const &rhs) const noexcept {
14+
// return !operator==(rhs);
15+
// }
16+
17+
PI::PadInfo(decltype(dims) dims_, dim_t blockCount_, dim_t blockSize_) noexcept
18+
: dims(std::move(dims_)), blockCount(blockCount_), blockSize(blockSize_) {}
19+
20+
PI::PadInfo(PadDimension dims_, Tensor const &input) : dims{}, blockCount(1),
21+
blockSize(input.dataType.size()) {
22+
size_t rank = input.rank();
23+
ASSERT(dims_.size() == rank, "Invalid to get PadInfo.");
24+
25+
// std::vector<dim_t> shape;
26+
size_t j = 0;
27+
for (auto i : range0_(rank)) {
28+
if (dims_[i].dimI != dims_[i].dimO || dims_[i].dimI != 1) {
29+
if (j < i) { dims_[j] = dims_[i]; }
30+
//shape.push_back(dims_[i].dimI);
31+
j++;
32+
}
33+
}
34+
dims_.resize(rank = j);
35+
// 合并末尾连续维度
36+
for (auto i : range0_(rank).rev()) {
37+
if (auto d = dims_[i].dimI; d == dims_[i].dimO) {
38+
blockSize *= d;
39+
dims_.pop_back();
40+
} else {
41+
dims.reserve(rank = dims_.size());
42+
auto &dim = dims_[i];
43+
if (auto times = std::gcd(std::gcd(dims_[i].dimI, dims_[i].pads), dims_[i].dimO); times > 1) {
44+
blockSize *= times;
45+
dim.dimI /= times;
46+
dim.dimO /= times;
47+
dim.pads /= times;
48+
}
49+
break;
50+
}
51+
}
52+
53+
dim_t strideI = 1, strideO = 1;
54+
for (auto i : range0_(rank).rev()) {
55+
auto const &dim = dims_[i];
56+
dims.push_back({
57+
strideI,
58+
strideO,
59+
static_cast<dim_t>(dim.pads),
60+
static_cast<dim_t>(dim.dimI),
61+
});
62+
strideI *= dim.dimI;
63+
strideO *= dim.dimO;
64+
}
65+
std::reverse(dims.begin(), dims.end());
66+
// for (auto i : range0_(rank)) {
67+
// fmt::println("strideI = {}, strideO = {}, padS = {}, dimI = {}", dims[i].strideI, dims[i].strideO, dims[i].padS, dims[i].dimI);
68+
// }
69+
blockCount = strideO;
70+
}
71+
72+
void PI::reform(dim_t maxblockSize) noexcept {
73+
auto blockSize_ = std::gcd(blockSize, maxblockSize);
74+
if (blockSize_ == blockSize) { return; }
75+
auto t = blockSize / blockSize_;
76+
blockCount *= t;
77+
blockSize = blockSize_;
78+
for (auto &d : dims) {
79+
d.strideI *= t;
80+
d.strideO *= t;
81+
d.padS *= t;
82+
d.dimI *= t;
2183
}
22-
size = std::accumulate(wholeNDim.begin(), wholeNDim.end(), 1, std::multiplies<>());
84+
dims.resize(dims.size() + 1);
85+
dims.back() = {1, 1, 0, t};
2386
}
2487

2588
}// namespace refactor::kernel

src/04kernel/src/collectors/pad.cc

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,32 @@
1-
#include "../kernels/pad/cpu_kernel.hh"
2-
// #include "../kernels/pad/cuda_kernel.hh"
31
#include "kernel/collectors/pad.h"
2+
#include "../kernels/pad/cpu_kernel.hh"
3+
#include "../kernels/pad/cuda_kernel.hh"
44

55
namespace refactor::kernel {
66

77
std::vector<KernelBox>
88
PadCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
99
auto const &input = inputs[0];
10-
auto const &output = outputs[0];
11-
bool have_value = inputs.size() >= 3 ? true : false;
12-
PadInfo info(pads, mode, input, output, have_value);
10+
PadInfo info(dims, input);
11+
auto const_value = inputs.size() >= 3 ? std::make_optional(inputs[2]) : std::nullopt;
1312

1413
std::vector<KernelBox> ans;
1514
switch (_target) {
1615
case decltype(_target)::Cpu:
17-
if (auto ptr = PadCpu::build(std::move(info)); ptr) {
16+
if (auto ptr = PadCpu::build(std::move(info), mode, const_value); ptr) {
17+
ans.emplace_back(std::move(ptr));
18+
}
19+
break;
20+
case decltype(_target)::Nvidia:
21+
if (auto ptr = PadCuda::build(std::move(info), mode, const_value); ptr) {
1822
ans.emplace_back(std::move(ptr));
1923
}
2024
break;
21-
// case decltype(_target)::Nvidia:
22-
// if (auto ptr = PadCuda::build(); ptr) {
23-
// ans.emplace_back(std::move(ptr));
24-
// }
25-
// break;
2625
default:
2726
UNREACHABLEX(void, "Unknown target");
2827
}
2928
return ans;
3029
}
3130

32-
}// namespace refactor::kernel
31+
}// namespace refactor::kernel
32+

0 commit comments

Comments
 (0)