我是靠谱客的博主 粗暴啤酒,最近开发中收集的这篇文章主要介绍Pytorch 源码浅析1.Topk(上)简介,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

简介

本章节将介绍笔者学习pytorch源码中遇到的一些并行算法(基于pytorch)。当然这些源码都可以在官方的github上获取。

入口

整体思路

入口为gather_topk函数,首先会经过RadixSelect函数找到第K大小的数(可能不是唯一的)。然后再并行找到大于这个值的所有数并按序前插以确定Index(本节举例最大Topk)。

首先需要了解的前提知识为基数排序,在Pytorch中topk是基于基数排序的变体实现的,了解其原理有助于了解topk的实现。
基数排序

RadixSelect

说这个函数为核心也不为过,这一节主要介绍这个函数。
寻常的基数排序是按位的,而Pytorch中是按两位的。具体在代码中体现为:

#define RADIX_SIZE 4
#define RADIX_BITS 2
#define RADIX_MASK 3

(这里举例最大Topk,且最大topk为按照11,10,01,00的顺序入桶,最小topk反之)

为啥说它是变体呢:topk其实并未将所有数据排序而是分组慢慢挑出最大的K个数据,随之找出这个第K大的数据。

  1. 灵魂函数 : getBitfieldsetBitfield
    如何将最大的k个数慢慢挑出来呢? 这里介绍一下灵魂函数getBitfield、setBitfield
    作用:前者获取当前数据第i和第i+1位的数据(00,01,10,11按两位,其实就是对应基数排序的4个桶,step = RADIX_BITS),很容易理解。

    后者的作用是为了标记我们现在关注的数据,这个可能不太好理解,这里详细讲一下:
    比如现在有1111,1110,1101,1100,0000,0001这6个数而我们现在要找topk3。首先按照基数排序的思路
    最大桶11对应的元素数量为4:1111,1110,1101,1100。此时最大的4个数据大于3(topk3)也就是说我们只需从这4个数中继续找就行了,可以跳过当前剩余的入桶操作,这个setBitdield的作用就是记录这4个数的前缀11帮助我们跳过其他数(之后只找前缀为11的数)。在代码中具体表现为desireddesired_mask。这里简单提一下,后面会举例详细讲。
    源码

static __device__ __forceinline__
    unsigned int getBitfield(unsigned int val, int pos, int len) {
#if defined(__HIP_PLATFORM_HCC__)
      pos &= 0xff;
      len &= 0xff;

      unsigned int m = (1u << len) - 1u;
      return (val >> pos) & m;
#else
      unsigned int ret;
      asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
      return ret;
#endif
    }

  static __device__ __forceinline__
    unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) {
#if defined(__HIP_PLATFORM_HCC__)
      pos &= 0xff;
      len &= 0xff;

      unsigned int m = (1u << len) - 1u;
      toInsert &= m;
      toInsert <<= pos;
      m <<= pos;

      return (val & ~m) | toInsert;
#else
      unsigned int ret;
      asm("bfi.b32 %0, %1, %2, %3, %4;" :
          "=r"(ret) : "r"(toInsert), "r"(val), "r"(pos), "r"(len));
      return ret;
#endif
    }

  1. 主要函数
    countRadixUsingMask作用从字面上其实就很容易理解:往4个桶内放入我们关注的那些数的数量(对应基数排序的入桶操作),理解基数排序后这里也应该很容易理解。
    不过Pytorch写的就比较高级(高性能)了:一个线程束(warp)为单位计算出每个桶内的数量。
    (我在下面的源码关键处给了中文注释,慢慢理解吧,convert我在上面的基数排序也提到过)。
    源码
template <typename scalar_t,
         typename bitwise_t,
         typename index_t,
         typename CountType,
         int RadixSize,
         int RadixBits>
         __device__ void countRadixUsingMask(
             CountType counts[RadixSize],
             CountType* smem,
             bitwise_t desired,
             bitwise_t desiredMask,
             int radixDigitPos,
             index_t sliceSize,
             index_t withinSliceStride,
             scalar_t* data) {
           // Clear out per-thread counts from a previous round
#pragma unroll
           for (int i = 0; i < RadixSize; ++i) {
             counts[i] = 0;
           }

           if (threadIdx.x < RadixSize) {
             smem[threadIdx.x] = 0;
           }
           __syncthreads();

           // Scan over all the data. Upon a read, the warp will accumulate
           // counts per each digit in the radix using warp voting.
           for (index_t i = threadIdx.x; i < sliceSize; i += blockDim.x) {
             bitwise_t val =
               TopKTypeConfig<scalar_t>::convert(doLdg(&data[i * withinSliceStride]));

             bool hasVal = ((val & desiredMask) == desired);
             bitwise_t digitInRadix =
               Bitfield<bitwise_t>::getBitfield(val, radixDigitPos, RadixBits);

#pragma unroll
             for (uint32_t j = 0; j < RadixSize; ++j) {
               bool vote = hasVal && (digitInRadix == j);
               //计算当前程束内所有需要入j桶的数的数量
#if defined(__HIP_PLATFORM_HCC__)
               counts[j] += __popcll(WARP_BALLOT(vote));
#else
               counts[j] += __popc(WARP_BALLOT(vote, ACTIVE_MASK()));
#endif
             }
           }

           // Now, for each warp, sum values
           //first thread of warp
           if (getLaneId() == 0) {
           //线程束内的第一个线程写入共享内存
#pragma unroll
             for (uint32_t i = 0; i < RadixSize; ++i) {
               gpuAtomicAdd(&smem[i], counts[i]);
             }
           }

           __syncthreads();

           // For each thread, read in the total counts
#pragma unroll
		 //将结果分享给每个线程
           for (uint32_t i = 0; i < RadixSize; ++i) {
             counts[i] = smem[i];
             //printf("%u  ",counts[i]);
           }

           __syncthreads();
         }

RadixSelect的剩余部分:
在算出每个桶内的数量后就需要了,这里讲一下怎么挑。
主要分为两种情况(后两者和为i一种):
(1)当前桶内数量等于1且k=1,那么我们要找的那个数就是这个数(利用desiredmask并调用findpattern找出这个数)。
(2)当前桶内数量小于k, 那么执行k -= num[current_bucket](就是将这些数挑出来)
(3) 当前桶内数量大于k,那么我们要找的数就在这几个数中间,用setbitfield(关注这些数)提取他们的前缀继续在这几个数内找。
源码

auto found_unique = [&](int i, int count) -> bool {
      /* All threads have the same value in counts here, so all */
      /* threads will return from the function. */
      if (count == 1 && kToFind == 1) {
        /* There is a unique answer. */
        desired =
            Bitfield<bitwise_t>::setBitfield(desired, i, digitPos, RADIX_BITS);
        desiredMask = Bitfield<bitwise_t>::setBitfield(
            desiredMask, RADIX_MASK, digitPos, RADIX_BITS);

        /* The answer is now the unique element v such that: */
        /* (v & desiredMask) == desired */
        /* However, we do not yet know what the actual element is. We */
        /* need to perform a search through the data to find the */
        /* element that matches this pattern. */
        *topK = findPattern<scalar_t, bitwise_t, index_t>(
            (scalar_t*)smem,
            data,
            sliceSize,
            withinSliceStride,
            desired,
            desiredMask);
        return true;
      }
      return false;
    };
    auto found_non_unique = [&](int i, int count) -> bool {
      if (count >= kToFind) {
        desired =
            Bitfield<bitwise_t>::setBitfield(desired, i, digitPos, RADIX_BITS);
        desiredMask = Bitfield<bitwise_t>::setBitfield(
            desiredMask, RADIX_MASK, digitPos, RADIX_BITS);

        /* The top-Kth element v must now be one such that: */
        /* (v & desiredMask == desired) */
        /* but we haven't narrowed it down; we must check the next */
        /* least-significant digit */
        return true;
      }
      kToFind -= count;
      return false; // continue the loop
    };

    // All threads participate in the comparisons below to know the
    // final result
    if (Order) {//最大topk
      // Process in descending order
#pragma unroll
      for (int i = RADIX_SIZE - 1; i >= 0; --i) {
        int count = counts[i];
        //
        if (found_unique(i, count)) {
          return;
        }
        if (found_non_unique(i, count)) {
          break;
        }
      }
    } else {//最小topk
      // Process in ascending order
#pragma unroll
      for (int i = 0; i < RADIX_SIZE; ++i) {
        int count = counts[i];
        if (found_unique(i, count)) {
          return;
        }
        if (found_non_unique(i, count)) {
          break;
        }
      }
    }
  } // end digitPos for

  // There is no unique result, but there is a non-unique result
  // matching `desired` exactly
  *topK = TopKTypeConfig<scalar_t>::deconvert(desired);

上面讲的可能不够直观这里举个例子:
还是上面的1111,1110,1101,1100,0000,0001这6个数
1.找topk4(k = 4),
首先看前两位
首先入桶 11元素数量为4大于等于k=4.那么在1111,1110,1101,1100这4个数(desired_mask = 11)内继续找:
看第3,4位
11桶内元素数量:1, 执行 k-=1。此时k = 3。
10桶内元素数量: 1, 执行 k-=1。此时k = 2。
01桶内元素数量: 1,执行 k-=1。此时k = 1。
00桶内元素数量: 1,并且 k = 1,我们要找的值就为00桶对应的值(前缀为1100的数这个前缀由11+ 00得来)

还有其他情况比如找到最后两位时桶内元素数量还大于k那么这个topk其实就等于desired_mask(第k个数不唯一)。

我将整个代码都提取了出来(float)并自己写了份无convert的(unsigned int)版本。
完整代码

最后

以上就是粗暴啤酒为你收集整理的Pytorch 源码浅析1.Topk(上)简介的全部内容,希望文章能够帮你解决Pytorch 源码浅析1.Topk(上)简介所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(54)

评论列表共有 0 条评论

立即
投稿
返回
顶部