概述
简介
本章节将介绍笔者学习pytorch源码中遇到的一些并行算法(基于pytorch)。当然这些源码都可以在官方的github上获取。
入口
整体思路
入口为gather_topk函数,首先会经过RadixSelect函数找到第K大小的数(可能不是唯一的)。然后再并行找到大于这个值的所有数并按序前插以确定Index(本节举例最大Topk)。
首先需要了解的前提知识为基数排序,在Pytorch中topk是基于基数排序的变体实现的,了解其原理有助于了解topk的实现。
基数排序
RadixSelect
说这个函数为核心也不为过,这一节主要介绍这个函数。
寻常的基数排序是按位的,而Pytorch中是按两位的。具体在代码中体现为:
#define RADIX_SIZE 4
#define RADIX_BITS 2
#define RADIX_MASK 3
(这里举例最大Topk,且最大topk为按照11,10,01,00的顺序入桶,最小topk反之)
为啥说它是变体呢:topk其实并未将所有数据排序而是分组慢慢挑出最大的K个数据,随之找出这个第K大的数据。
-
灵魂函数 : getBitfield、setBitfield
如何将最大的k个数慢慢挑出来呢? 这里介绍一下灵魂函数getBitfield、setBitfield
作用:前者获取当前数据第i和第i+1位的数据(00,01,10,11按两位,其实就是对应基数排序的4个桶,step = RADIX_BITS),很容易理解。后者的作用是为了标记我们现在关注的数据,这个可能不太好理解,这里详细讲一下:
比如现在有1111,1110,1101,1100,0000,0001
这6个数而我们现在要找topk3。首先按照基数排序的思路
最大桶11
对应的元素数量为4:1111,1110,1101,1100
。此时最大的4
个数据大于3(topk3)也就是说我们只需从这4个数中继续找就行了,可以跳过当前剩余的入桶操作,这个setBitdield的作用就是记录这4个数的前缀11
帮助我们跳过其他数(之后只找前缀为11
的数)。在代码中具体表现为desired和desired_mask。这里简单提一下,后面会举例详细讲。
源码:
static __device__ __forceinline__
unsigned int getBitfield(unsigned int val, int pos, int len) {
#if defined(__HIP_PLATFORM_HCC__)
pos &= 0xff;
len &= 0xff;
unsigned int m = (1u << len) - 1u;
return (val >> pos) & m;
#else
unsigned int ret;
asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
return ret;
#endif
}
static __device__ __forceinline__
unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) {
#if defined(__HIP_PLATFORM_HCC__)
pos &= 0xff;
len &= 0xff;
unsigned int m = (1u << len) - 1u;
toInsert &= m;
toInsert <<= pos;
m <<= pos;
return (val & ~m) | toInsert;
#else
unsigned int ret;
asm("bfi.b32 %0, %1, %2, %3, %4;" :
"=r"(ret) : "r"(toInsert), "r"(val), "r"(pos), "r"(len));
return ret;
#endif
}
- 主要函数
countRadixUsingMask作用从字面上其实就很容易理解:往4个桶内放入我们关注的那些数的数量(对应基数排序的入桶操作),理解基数排序后这里也应该很容易理解。
不过Pytorch写的就比较高级(高性能)了:一个线程束(warp)为单位计算出每个桶内的数量。
(我在下面的源码关键处给了中文注释,慢慢理解吧,convert我在上面的基数排序也提到过)。
源码:
template <typename scalar_t,
typename bitwise_t,
typename index_t,
typename CountType,
int RadixSize,
int RadixBits>
__device__ void countRadixUsingMask(
CountType counts[RadixSize],
CountType* smem,
bitwise_t desired,
bitwise_t desiredMask,
int radixDigitPos,
index_t sliceSize,
index_t withinSliceStride,
scalar_t* data) {
// Clear out per-thread counts from a previous round
#pragma unroll
for (int i = 0; i < RadixSize; ++i) {
counts[i] = 0;
}
if (threadIdx.x < RadixSize) {
smem[threadIdx.x] = 0;
}
__syncthreads();
// Scan over all the data. Upon a read, the warp will accumulate
// counts per each digit in the radix using warp voting.
for (index_t i = threadIdx.x; i < sliceSize; i += blockDim.x) {
bitwise_t val =
TopKTypeConfig<scalar_t>::convert(doLdg(&data[i * withinSliceStride]));
bool hasVal = ((val & desiredMask) == desired);
bitwise_t digitInRadix =
Bitfield<bitwise_t>::getBitfield(val, radixDigitPos, RadixBits);
#pragma unroll
for (uint32_t j = 0; j < RadixSize; ++j) {
bool vote = hasVal && (digitInRadix == j);
//计算当前程束内所有需要入j桶的数的数量
#if defined(__HIP_PLATFORM_HCC__)
counts[j] += __popcll(WARP_BALLOT(vote));
#else
counts[j] += __popc(WARP_BALLOT(vote, ACTIVE_MASK()));
#endif
}
}
// Now, for each warp, sum values
//first thread of warp
if (getLaneId() == 0) {
//线程束内的第一个线程写入共享内存
#pragma unroll
for (uint32_t i = 0; i < RadixSize; ++i) {
gpuAtomicAdd(&smem[i], counts[i]);
}
}
__syncthreads();
// For each thread, read in the total counts
#pragma unroll
//将结果分享给每个线程
for (uint32_t i = 0; i < RadixSize; ++i) {
counts[i] = smem[i];
//printf("%u ",counts[i]);
}
__syncthreads();
}
RadixSelect的剩余部分:
在算出每个桶内的数量后就需要挑了,这里讲一下怎么挑。
主要分为两种情况(后两者和为i一种):
(1)当前桶内数量等于1且k=1,那么我们要找的那个数就是这个数(利用desiredmask并调用findpattern找出这个数)。
(2)当前桶内数量小于k, 那么执行k -= num[current_bucket]
(就是将这些数挑出来)
(3) 当前桶内数量大于k,那么我们要找的数就在这几个数中间,用setbitfield(关注这些数)提取他们的前缀继续在这几个数内找。
源码:
auto found_unique = [&](int i, int count) -> bool {
/* All threads have the same value in counts here, so all */
/* threads will return from the function. */
if (count == 1 && kToFind == 1) {
/* There is a unique answer. */
desired =
Bitfield<bitwise_t>::setBitfield(desired, i, digitPos, RADIX_BITS);
desiredMask = Bitfield<bitwise_t>::setBitfield(
desiredMask, RADIX_MASK, digitPos, RADIX_BITS);
/* The answer is now the unique element v such that: */
/* (v & desiredMask) == desired */
/* However, we do not yet know what the actual element is. We */
/* need to perform a search through the data to find the */
/* element that matches this pattern. */
*topK = findPattern<scalar_t, bitwise_t, index_t>(
(scalar_t*)smem,
data,
sliceSize,
withinSliceStride,
desired,
desiredMask);
return true;
}
return false;
};
auto found_non_unique = [&](int i, int count) -> bool {
if (count >= kToFind) {
desired =
Bitfield<bitwise_t>::setBitfield(desired, i, digitPos, RADIX_BITS);
desiredMask = Bitfield<bitwise_t>::setBitfield(
desiredMask, RADIX_MASK, digitPos, RADIX_BITS);
/* The top-Kth element v must now be one such that: */
/* (v & desiredMask == desired) */
/* but we haven't narrowed it down; we must check the next */
/* least-significant digit */
return true;
}
kToFind -= count;
return false; // continue the loop
};
// All threads participate in the comparisons below to know the
// final result
if (Order) {//最大topk
// Process in descending order
#pragma unroll
for (int i = RADIX_SIZE - 1; i >= 0; --i) {
int count = counts[i];
//
if (found_unique(i, count)) {
return;
}
if (found_non_unique(i, count)) {
break;
}
}
} else {//最小topk
// Process in ascending order
#pragma unroll
for (int i = 0; i < RADIX_SIZE; ++i) {
int count = counts[i];
if (found_unique(i, count)) {
return;
}
if (found_non_unique(i, count)) {
break;
}
}
}
} // end digitPos for
// There is no unique result, but there is a non-unique result
// matching `desired` exactly
*topK = TopKTypeConfig<scalar_t>::deconvert(desired);
上面讲的可能不够直观这里举个例子:
还是上面的1111,1110,1101,1100,0000,0001
这6个数
1.找topk4(k = 4),
首先看前两位
首先入桶 11
元素数量为4大于等于k=4.那么在1111,1110,1101,1100
这4个数(desired_mask = 11
)内继续找:
看第3,4位
11
桶内元素数量:1, 执行 k-=1。此时k = 3。
10
桶内元素数量: 1, 执行 k-=1。此时k = 2。
01
桶内元素数量: 1,执行 k-=1。此时k = 1。
00
桶内元素数量: 1,并且 k = 1,我们要找的值就为00桶对应的值(前缀为1100
的数这个前缀由11
+ 00
得来)
还有其他情况比如找到最后两位时桶内元素数量还大于k那么这个topk其实就等于desired_mask(第k个数不唯一)。
我将整个代码都提取了出来(float)并自己写了份无convert的(unsigned int)版本。
完整代码
最后
以上就是粗暴啤酒为你收集整理的Pytorch 源码浅析1.Topk(上)简介的全部内容,希望文章能够帮你解决Pytorch 源码浅析1.Topk(上)简介所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复