// 05360171 - created 2022-03-18 (commit history)
/******************************************************************************
*
* Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*

 ******************************************************************************/

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCNumerics.cuh>
#include <THC/THC.h>

#include <cuda.h>

#include <torch/torch.h>
#include <torch/extension.h>

namespace nms_internal {

// Intersection-over-union of two boxes.
// Each float4 is read as x = left, y = top, z = right, w = bottom.
// The overlap extent on each axis is clamped to zero, so disjoint boxes give
// an intersection of 0. The denominator is not guarded: when the union area
// is zero the division produces a non-finite value (0/0 -> NaN).
__device__
float calc_single_iou(const float4 b1, const float4 b2) {
  // Overlap rectangle: innermost left/top and right/bottom edges
  const float left   = max(b1.x, b2.x);
  const float top    = max(b1.y, b2.y);
  const float right  = min(b1.z, b2.z);
  const float bottom = min(b1.w, b2.w);

  // Negative extents mean no overlap along that axis
  float overlap_w = right - left;
  if (overlap_w < 0) {
    overlap_w = 0;
  }
  float overlap_h = bottom - top;
  if (overlap_h < 0) {
    overlap_h = 0;
  }

  const float intersection = overlap_w * overlap_h;

  const float area1 = (b1.w - b1.y) * (b1.z - b1.x);
  const float area2 = (b2.w - b2.y) * (b2.z - b2.x);
  const float union_area = area1 + area2 - intersection;

  return intersection / union_area;
}

// Decide whether box2 should be suppressed by box1.
// Returns 1 (delete) when box2 is already flagged as deleted or when its IoU
// with box1 is not below `criteria`; returns 0 (keep) otherwise.
// A NaN IoU compares false against `criteria` and therefore also yields 1.
__device__
uint8_t masked_iou(const float4 box1,
                const float4 box2,
                const uint8_t box2_deleted,
                const float criteria) {
  // The IoU is only evaluated for boxes that are not already deleted
  const bool keep = (box2_deleted != 1) &&
                    (calc_single_iou(box1, box2) < criteria);

  return keep ? 0 : 1;
}

// Scan the first `num_to_consider` deletion flags and report
//   *first_non_deleted : lowest index whose flag is 0 (INT_MAX if there is none)
//   *remaining         : how many flags are 0
// The scan is serial; every calling thread walks the whole range itself.
__device__
void get_current_num_idx(const uint8_t *deleted,
                         const int num_to_consider,
                         int *first_non_deleted,
                         int *remaining) {
  // TODO: Not dumb, actually parallel
  int lowest_live = INT_MAX;
  int live_count = 0;

  for (int pos = 0; pos < num_to_consider; ++pos) {
    // skip anything already deleted
    if (deleted[pos] != 0) {
      continue;
    }
    // indices are visited in ascending order, so the first live one is the lowest
    if (live_count == 0) {
      lowest_live = pos;
    }
    ++live_count;
  }

  *remaining = live_count;
  *first_non_deleted = lowest_live;
}


// Greedy non-maximum suppression, one thread block per class.
//
// Launch layout: block b processes class b + 1 (class 0 is skipped as
// background). Threads of a block cooperate through block-stride loops, so
// any blockDim.x works. Static shared memory: two int[MAX_CANDIDATES] arrays.
//
// score_offsets[cls] .. score_offsets[cls + 1] delimits the slice of `scores`,
// `score_idx`, `deleted` and of the three output buffers that belongs to
// class `cls`. score_idx holds indices into `bboxes`. The slice is walked from
// its first entry onwards, so scores are presumably sorted best-first within
// each slice (TODO confirm against the caller).
//
// Per class, the kernel repeatedly takes the first non-deleted entry as a
// kept candidate, then deletes every later entry whose IoU with it is not
// below `criteria`. Kept candidates are written to the start of the class's
// output slice and their count to num_candidates_out[cls].
//
// At most min(num_scores, max_num) entries are examined and at most
// MAX_CANDIDATES are kept; selection stops once the shared buffers are full.
// N and num_classes are not read by this kernel.
__global__
void nms_kernel(const int N,
                const int num_classes,
                const int *score_offsets,
                const float *scores,
                const long *score_idx,
                const float4 *bboxes,
                const float criteria,      // IoU threshold
                const int max_num,         // maximum number of candidate boxes to use
                uint8_t *deleted,          // assume initialised to false for all values
                long *num_candidates_out,   // number of outputs for this class
                float *score_out,          // output scores
                float4 *bboxes_out,        // output bboxes
                long *labels_out) {        // output labels
  // Capacity of the shared-memory candidate buffers declared below
  const int MAX_CANDIDATES = 200;

  // launch one block per class for now
  // Ignore class 0 (background) by starting at 1
  const int cls = blockIdx.x + 1;

  // offsets into scores and their indices
  const int offset_start = score_offsets[cls];
  const int offset_end   = score_offsets[cls+1];
  const int num_scores = offset_end - offset_start;
  // alias into local scores, indices and deleted buffers
  const float *local_scores = &scores[offset_start];
  const long *local_indices = &score_idx[offset_start];
  uint8_t *local_deleted = &deleted[offset_start];

  // aliases into output buffers
  float *local_score_out = &score_out[offset_start];
  float4 *local_bbox_out = &bboxes_out[offset_start];
  long *local_labels_out = &labels_out[offset_start];

  // Nothing to do here - early exit (taken by the whole block, so no thread
  // is left waiting at a later barrier)
  if (num_scores == 0) {
    if (threadIdx.x == 0) {
      num_candidates_out[cls] = 0;
    }
    return;
  }

  // how many scores we care about
  const int num_to_consider = min(num_scores, max_num);
  int current_num = num_to_consider;

  // always start by looking at the first (highest) score
  int first_score_idx = 0;

  // store _global_ bbox candidate indices in shmem
  __shared__ int local_candidates[MAX_CANDIDATES];
  // also store _local_ indices for scores
  __shared__ int local_score_indices[MAX_CANDIDATES];

  // Number of candidates kept so far. Every thread maintains its own copy;
  // the copies stay identical because first_score_idx and current_num are
  // derived by each thread from the same data after a barrier.
  int current_candidate_idx = 0;

  // initialise all shmem values to sentinels for sanity
  for (int i = threadIdx.x; i < MAX_CANDIDATES; i += blockDim.x) {
    local_candidates[i] = -1;
    local_score_indices[i] = -1;
  }
  // Shouldn't be necessary, make sure that no entries are
  // coming in deleted from poor initialisation.
  for (int i = threadIdx.x; i < num_scores; i += blockDim.x) {
    local_deleted[i] = 0;
  }
  __syncthreads();

  // While there's more scores/boxes to process and room left to store them.
  // Without the capacity test, more than MAX_CANDIDATES survivors would be
  // written past the end of the shared arrays.
  while (current_num > 0 && current_candidate_idx < MAX_CANDIDATES) {

    // get the candidate index & bbox
    // first_score_idx is _local_ into the aliased index-storing buffer
    // candidate_idx is _global_ into the bbox buffer
    const long candidate_idx = local_indices[first_score_idx];
    const float4 candidate_bbox = bboxes[candidate_idx];
    // Now we've looked at this candidate, remove it from consideration
    // (all threads store the same value to the same byte)
    local_deleted[first_score_idx] = 1;

    // calculate the IoUs of candidate vs. remaining boxes & manipulate delete array
    // standard block-stride loop over boxes; each entry i is owned by one thread
    for (int i = threadIdx.x; i < num_to_consider; i += blockDim.x) {
      // Know we've already looked at all entries before the candidate, so we can ignore them
      // TODO: handle this loop more efficiently w.r.t. skipped entries
      if (i > first_score_idx) {
        const long test_idx = local_indices[i];
        const float4 test_bbox = bboxes[test_idx];
        // Note if we need to delete this box
        local_deleted[i] = masked_iou(candidate_bbox, test_bbox, local_deleted[i], criteria);
      }
    }

    // record the candidate in shmem; a single writer is enough
    if (threadIdx.x == 0) {
      // idx into global bbox array
      local_candidates[current_candidate_idx] = static_cast<int>(candidate_idx);
      // idx into local scores
      local_score_indices[current_candidate_idx] = first_score_idx;
    }
    // advance the storage location on every thread to keep the count uniform
    current_candidate_idx++;

    // all deletion flags and the shmem record must be visible before the scan
    __syncthreads();

    // Now, get the number of remaining boxes and the first non-deleted idx
    get_current_num_idx(local_deleted, num_to_consider, &first_score_idx, &current_num);

    // nobody may start modifying flags again until every thread finished scanning
    __syncthreads();
  }

  // at this point we should have all candidate indices for this class
  // use them to write out scores, bboxes and labels.
  // The shmem entries read here were written before the last in-loop barrier.
  for (int i = threadIdx.x; i < current_candidate_idx; i += blockDim.x) {
    local_score_out[i] = local_scores[local_score_indices[i]];
    local_bbox_out[i] = bboxes[local_candidates[i]];
    local_labels_out[i] = cls;
  }

  // write the final number of candidates from this class to a buffer
  if (threadIdx.x == 0) {
    num_candidates_out[cls] = current_candidate_idx;
  }
}

// Compact per-class results into contiguous output arrays.
// Block b handles class b + 1. It copies num_candidates[cls] entries of
// scores, boxes and labels from position output_offsets[cls] in the
// unsquashed buffers to position squashed_offsets[cls] in the squashed ones.
// Threads of the block share the copy through a block-stride loop.
// N is not read by this kernel.
__global__
void squash_outputs(const int N, // number of sets of outputs
                    const long *num_candidates, // number of candidates per entry
                    const int *output_offsets, // offsets into outputs
                    const float *output_scores,
                    const float4 *output_boxes,
                    const long* output_labels,
                    const long* squashed_offsets,
                    float *squashed_scores,
                    float4 *squashed_boxes,
                    long *squashed_labels) {
  // block per output; class 0 is skipped
  const int cls = blockIdx.x + 1;

  const int num_to_write = num_candidates[cls];
  const long src_base = output_offsets[cls];
  const long dst_base = squashed_offsets[cls];

  // source and destination windows for this class
  const float  *src_scores = output_scores + src_base;
  const float4 *src_boxes  = output_boxes + src_base;
  const long   *src_labels = output_labels + src_base;
  float  *dst_scores = squashed_scores + dst_base;
  float4 *dst_boxes  = squashed_boxes + dst_base;
  long   *dst_labels = squashed_labels + dst_base;

  for (int k = threadIdx.x; k < num_to_write; k += blockDim.x) {
    // load the whole record first ...
    const float  s = src_scores[k];
    const float4 b = src_boxes[k];
    const long   l = src_labels[k];

    // ... then store it at its compacted position
    dst_scores[k] = s;
    dst_boxes[k] = b;
    dst_labels[k] = l;
  }
}

};  // namespace nms_internal

// Host-side driver for per-class GPU non-maximum suppression.
//
// Launches nms_kernel with one block per non-background class, then
// squash_outputs to pack the per-class results into contiguous tensors.
//
// Parameters
//   N                  forwarded to the kernels, which do not read it
//   num_classes        number of classes including background class 0
//   score_offsets      int tensor; entries [cls] and [cls + 1] delimit the
//                      scores of class cls, the last entry is the total
//                      number of scores. Needs at least num_classes + 1 entries.
//   sorted_scores      float tensor of scores
//   sorted_scores_idx  long tensor of box indices matching sorted_scores
//   bboxes             float tensor reinterpreted as float4 boxes; assumes
//                      4 floats per box (TODO confirm with caller)
//   criteria           IoU threshold
//   max_num            maximum number of scores examined per class
//
// Returns {boxes [M, 4], scores [M], labels [M]} where M is the total number
// of kept candidates. All three are empty when num_classes <= 1.
//
// The call synchronises with the device twice through .item().
std::vector<at::Tensor> nms(const int N, // number of images
                            const int num_classes,
                            const at::Tensor score_offsets,
                            const at::Tensor sorted_scores,
                            const at::Tensor sorted_scores_idx,
                            const at::Tensor bboxes,
                            const float criteria,
                            const int max_num) {
  // Run all classes in different blocks, ignore background class 0
  const int num_blocks = num_classes - 1;

  // No foreground class: nothing to suppress. A grid of zero blocks is an
  // invalid launch configuration, so hand back empty results instead.
  if (num_blocks <= 0) {
    return {torch::empty({0, 4}, torch::CUDA(at::kFloat)),
            torch::empty({0}, torch::CUDA(at::kFloat)),
            torch::empty({0}, torch::CUDA(at::kLong))};
  }

  // The kernels read score_offsets[cls + 1] for cls up to num_classes - 1
  TORCH_CHECK(score_offsets.numel() >= static_cast<int64_t>(num_classes) + 1,
              "nms: score_offsets must hold at least num_classes + 1 entries, got ",
              score_offsets.numel());

  // The kernels index raw pointers linearly, so make sure the memory is dense.
  // contiguous() returns the same tensor when it already is.
  const at::Tensor offsets_c = score_offsets.contiguous();
  const at::Tensor scores_c = sorted_scores.contiguous();
  const at::Tensor scores_idx_c = sorted_scores_idx.contiguous();
  const at::Tensor bboxes_c = bboxes.contiguous();

  const int total_scores = score_offsets[score_offsets.numel()-1].item<int>();

  // track which elements have been deleted in each iteration
  at::Tensor deleted = torch::zeros({total_scores}, torch::CUDA(at::kByte));
  // track how many outputs we have for each class
  at::Tensor num_candidates_out = torch::zeros({num_classes}, torch::CUDA(at::kLong));
  // outputs
  at::Tensor score_out = torch::empty({total_scores}, torch::CUDA(at::kFloat));
  at::Tensor label_out = torch::empty({total_scores}, torch::CUDA(at::kLong));
  at::Tensor bbox_out = torch::empty({total_scores, 4}, torch::CUDA(at::kFloat));

  // Run the kernel
  const int THREADS_PER_BLOCK = 64;
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  nms_internal::nms_kernel<<<num_blocks, THREADS_PER_BLOCK, 0, stream>>>(
      N,
      num_classes,
      offsets_c.data_ptr<int>(),
      scores_c.data_ptr<float>(),
      scores_idx_c.data_ptr<long>(),
      (float4*)bboxes_c.data_ptr<float>(),
      criteria,
      max_num,
      deleted.data_ptr<uint8_t>(),
      num_candidates_out.data_ptr<long>(),
      score_out.data_ptr<float>(),
      (float4*)bbox_out.data_ptr<float>(),
      label_out.data_ptr<long>());
  THCudaCheck(cudaGetLastError());

  // Now need to squash the output so it's contiguous.
  // get prefix sum of num_candidates_out
  // Note: Still need lengths
  auto output_offsets = num_candidates_out.cumsum(0);
  auto total_outputs = output_offsets[output_offsets.numel()-1].item<long>();
  // inclusive -> exclusive prefix sum: start position of each class.
  // Held in a named tensor so the buffer outlives the launch expression.
  const at::Tensor squashed_offsets = (output_offsets - num_candidates_out).contiguous();

  // allocate final outputs
  at::Tensor squashed_scores = torch::empty({total_outputs}, torch::CUDA(at::kFloat));
  at::Tensor squashed_bboxes = torch::empty({total_outputs, 4}, torch::CUDA(at::kFloat));
  at::Tensor squashed_labels = torch::empty({total_outputs}, torch::CUDA(at::kLong));

  // Copy non-squashed outputs -> squashed.
  // The unsquashed buffers are laid out by score_offsets, hence it is passed
  // as the read offsets.
  nms_internal::squash_outputs<<<num_blocks, THREADS_PER_BLOCK, 0, stream>>>(
      N,
      num_candidates_out.data_ptr<long>(),
      offsets_c.data_ptr<int>(),
      score_out.data_ptr<float>(),
      (float4*)bbox_out.data_ptr<float>(),
      label_out.data_ptr<long>(),
      squashed_offsets.data_ptr<long>(),
      squashed_scores.data_ptr<float>(),
      (float4*)squashed_bboxes.data_ptr<float>(),
      squashed_labels.data_ptr<long>());
  THCudaCheck(cudaGetLastError());

  return {squashed_bboxes, squashed_scores, squashed_labels};
}