diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp index 1e0330294..e2fe58b7a 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp @@ -38,15 +38,19 @@ namespace nd4j { REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved."); REQUIRE_TRUE(boxes->rankOf() == 2, 0, "image.non_max_suppression: The rank of boxes array should be 2, but %i is given", boxes->rankOf()); - REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0, "image.non_max_suppression: The rank of boxes array should be 2, but %i is given", boxes->rankOf()); + REQUIRE_TRUE(boxes->sizeAt(1) == 4, 0, "image.non_max_suppression: The last dimension of boxes array should be 4, but %lld is given", boxes->sizeAt(1)); + REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0, "image.non_max_suppression: The rank of scales array should be 1, but %i is given", scales->rankOf()); if (scales->lengthOf() < maxOutputSize) maxOutputSize = scales->lengthOf(); - double threshold = 0.5; + double overlapThreshold = 0.5; + double scoreThreshold = - DataTypeUtils::infOrMax(); // default: negated infOrMax, overridden by T_ARG(1) when supplied if (block.getTArguments()->size() > 0) - threshold = T_ARG(0); + overlapThreshold = T_ARG(0); + if (block.getTArguments()->size() > 1) + scoreThreshold = T_ARG(1); - helpers::nonMaxSuppressionV2(block.launchContext(), boxes, scales, maxOutputSize, threshold, output); + helpers::nonMaxSuppression(block.launchContext(), boxes, scales, maxOutputSize, overlapThreshold, scoreThreshold, output); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression_overlaps.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression_overlaps.cpp new file mode 100644 index 000000000..4f405d8c8 --- 
/dev/null +++ b/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression_overlaps.cpp @@ -0,0 +1,93 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by GS at 10/17/2019 +// + +#include +#include + +#if NOT_EXCLUDED(OP_image_non_max_suppression_overlaps) + +namespace nd4j { + namespace ops { + CUSTOM_OP_IMPL(non_max_suppression_overlaps, 2, 1, false, 0, 0) { + auto boxes = INPUT_VARIABLE(0); + auto scales = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + int maxOutputSize; // = INT_ARG(0); + if (block.width() > 2) + maxOutputSize = INPUT_VARIABLE(2)->e(0); + else if (block.getIArguments()->size() == 1) + maxOutputSize = INT_ARG(0); + else + REQUIRE_TRUE(false, 0, "image.non_max_suppression_overlaps: Max output size argument cannot be retrieved."); + REQUIRE_TRUE(boxes->rankOf() == 2, 0, "image.non_max_suppression_overlaps: The rank of boxes array should be 2, but %i is given", boxes->rankOf()); + REQUIRE_TRUE(boxes->sizeAt(0) == boxes->sizeAt(1), 0, "image.non_max_suppression_overlaps: The boxes array should be square, but {%lld, %lld} is given", boxes->sizeAt(0), boxes->sizeAt(1)); + REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0, "image.non_max_suppression_overlaps: The rank 
of scales array should be 1, but %i is given", scales->rankOf()); + +// if (scales->lengthOf() < maxOutputSize) +// maxOutputSize = scales->lengthOf(); + double overlapThreshold = 0.5; + double scoreThreshold = -DataTypeUtils::infOrMax(); + if (block.getTArguments()->size() > 0) + overlapThreshold = T_ARG(0); + if (block.getTArguments()->size() > 1) + scoreThreshold = T_ARG(1); + + // TODO: refactor helpers to multithreaded facility + helpers::nonMaxSuppressionGeneric(block.launchContext(), boxes, scales, maxOutputSize, overlapThreshold, + scoreThreshold, output); + return Status::OK(); + } + + DECLARE_SHAPE_FN(non_max_suppression_overlaps) { + auto in = inputShape->at(0); + int outRank = shape::rank(in); + Nd4jLong *outputShape = nullptr; + + int maxOutputSize; + if (block.width() > 2) + maxOutputSize = INPUT_VARIABLE(2)->e(0); + else if (block.getIArguments()->size() == 1) + maxOutputSize = INT_ARG(0); + else + REQUIRE_TRUE(false, 0, "image.non_max_suppression_overlaps: Max output size argument cannot be retrieved."); + + double overlapThreshold = block.getTArguments()->size() > 0 ? T_ARG(0) : 0.5; // same thresholds and defaults as the op body, so the inferred length matches the number of indices the op writes + double scoreThreshold = block.getTArguments()->size() > 1 ? T_ARG(1) : -DataTypeUtils::infOrMax(); + + Nd4jLong boxSize = helpers::nonMaxSuppressionGeneric(block.launchContext(), INPUT_VARIABLE(0), + INPUT_VARIABLE(1), maxOutputSize, overlapThreshold, scoreThreshold, nullptr); //shape::sizeAt(in, 0); + if (boxSize < maxOutputSize) + maxOutputSize = boxSize; + + outputShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(maxOutputSize, DataType::INT32); + + return SHAPELIST(outputShape); + } + DECLARE_TYPES(non_max_suppression_overlaps) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_FLOATS}) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedInputTypes(2, {ALL_INTS}) + ->setAllowedOutputTypes({ALL_INDICES}); + } + + } +} +#endif diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h index cbc7e56da..3660ee229 100644 --- a/libnd4j/include/ops/declarable/headers/parity_ops.h +++ 
b/libnd4j/include/ops/declarable/headers/parity_ops.h @@ -1691,15 +1691,38 @@ namespace nd4j { * 1 - scales - 1D-tensor with shape (num_boxes) by float type * 2 - output_size - 0D-tensor by int type (optional) * float args: - * 0 - threshold - threshold value for overlap checks (optional, by default 0.5) + * 0 - overlap_threshold - threshold value for overlap checks (optional, by default 0.5) + * 1 - score_threshold - the threshold for deciding when to remove boxes based on score (optional, by default -inf) * int args: * 0 - output_size - as arg 2 used for same target. Eigher this or arg 2 should be provided. * + * output: + * - vector with size M, where M <= output_size by int type + * * */ #if NOT_EXCLUDED(OP_image_non_max_suppression) DECLARE_CUSTOM_OP(non_max_suppression, 2, 1, false, 0, 0); #endif + /* + * image.non_max_suppression_overlaps op. + * input: + * 0 - overlaps - square 2D-tensor with shape (num_boxes, num_boxes) of pairwise overlap values, by float type + * 1 - scales - 1D-tensor with shape (num_boxes) by float type + * 2 - output_size - 0D-tensor by int type (optional) + * float args: + * 0 - overlap_threshold - threshold value for overlap checks (optional, by default 0.5) + * 1 - score_threshold - the threshold for deciding when to remove boxes based on score (optional, by default -inf) + * int args: + * 0 - output_size - as arg 2 used for same target. Either this or arg 2 should be provided. + * + * output: + * 0 - 1D integer tensor with shape [M], representing the selected indices from the overlaps tensor, where M <= output_size + * */ + #if NOT_EXCLUDED(OP_image_non_max_suppression_overlaps) + DECLARE_CUSTOM_OP(non_max_suppression_overlaps, 2, 1, false, 0, 0); + #endif + /* * cholesky op - decomposite positive square symetric matrix (or matricies when rank > 2). 
* input: diff --git a/libnd4j/include/ops/declarable/helpers/cpu/image_suppression.cpp b/libnd4j/include/ops/declarable/helpers/cpu/image_suppression.cpp index f90974a9f..f4fb98b2a 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/image_suppression.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/image_suppression.cpp @@ -22,21 +22,26 @@ //#include #include #include +#include namespace nd4j { namespace ops { namespace helpers { template - static void nonMaxSuppressionV2_(NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) { - std::vector indices(scales->lengthOf()); + static void nonMaxSuppressionV2_(NDArray* boxes, NDArray* scales, int maxSize, double overlapThreshold, + double scoreThreshold, NDArray* output) { + std::vector indices(scales->lengthOf()); std::iota(indices.begin(), indices.end(), 0); - + for (auto e = 0; e < scales->lengthOf(); e++) { + if (scales->e(e) < scoreThreshold) indices[e] = -1; + } std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e(i) > scales->e(j);}); // std::vector selected(output->lengthOf()); std::vector selectedIndices(output->lengthOf(), 0); auto needToSuppressWithThreshold = [] (NDArray& boxes, int previousIndex, int nextIndex, T threshold) -> bool { + if (previousIndex < 0 || nextIndex < 0) return true; T minYPrev = nd4j::math::nd4j_min(boxes.e(previousIndex, 0), boxes.e(previousIndex, 2)); T minXPrev = nd4j::math::nd4j_min(boxes.e(previousIndex, 1), boxes.e(previousIndex, 3)); T maxYPrev = nd4j::math::nd4j_max(boxes.e(previousIndex, 0), boxes.e(previousIndex, 2)); @@ -70,7 +75,7 @@ namespace helpers { PRAGMA_OMP_PARALLEL_FOR //_ARGS(firstprivate(numSelected)) for (int j = numSelected - 1; j >= 0; --j) { if (shouldSelect) - if (needToSuppressWithThreshold(*boxes, indices[i], indices[selectedIndices[j]], T(threshold))) { + if (needToSuppressWithThreshold(*boxes, indices[i], indices[selectedIndices[j]], T(overlapThreshold))) { shouldSelect = false; } } @@ -80,11 
+85,119 @@ namespace helpers { } } } +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + template + static Nd4jLong + nonMaxSuppressionGeneric_(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int outputSize, + double overlapThreshold, double scoreThreshold, NDArray* output) { - void nonMaxSuppressionV2(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) { - BUILD_SINGLE_SELECTOR(boxes->dataType(), nonMaxSuppressionV2_, (boxes, scales, maxSize, threshold, output), NUMERIC_TYPES); +// const int outputSize = maxSize->e(0); + auto numBoxes = boxes->sizeAt(0); + //std::vector scoresData(numBoxes); + T* scoresData = scores->dataBuffer()->primaryAsT(); + //std::copy_n(scores->getDataBuffer()->primaryAsT(), numBoxes, scoresData.begin()); + + // Data structure for a selection candidate in NMS. + struct Candidate { + int _boxIndex; + T _score; + int _suppressBeginIndex; + }; + + auto cmp = [](const Candidate& bsI, const Candidate& bsJ) -> bool{ + return ((bsI._score == bsJ._score) && (bsI._boxIndex > bsJ._boxIndex)) || + (bsI._score < bsJ._score); + }; + std::priority_queue, decltype(cmp)> candidatePriorityQueue(cmp); + for (auto i = 0; i < scores->lengthOf(); ++i) { + if (scoresData[i] > scoreThreshold) { + candidatePriorityQueue.emplace(Candidate({i, scoresData[i], 0})); + } + } + + std::vector selected; + T similarity, originalScore; + Candidate nextCandidate; + + while (selected.size() < outputSize && !candidatePriorityQueue.empty()) { + nextCandidate = candidatePriorityQueue.top(); + originalScore = nextCandidate._score; + candidatePriorityQueue.pop(); + + // Overlapping boxes are likely to have similar scores, therefore we + // iterate through the previously selected boxes backwards in order to + // see if `nextCandidate` should be suppressed. 
We also enforce a property + // that a candidate can be suppressed by another candidate no more than + // once via `suppress_begin_index` which tracks which previously selected + // boxes have already been compared against next_candidate prior to a given + // iteration. These previous selected boxes are then skipped over in the + // following loop. + bool shouldHardSuppress = false; + for (int j = static_cast(selected.size()) - 1; j >= nextCandidate._suppressBeginIndex; --j) { + similarity = boxes->t(nextCandidate._boxIndex, selected[j]); + nextCandidate._score *= T(similarity <= overlapThreshold?1.0:0.); //suppressWeightFunc(similarity); + + // First decide whether to perform hard suppression + if (similarity >= static_cast(overlapThreshold)) { + shouldHardSuppress = true; + break; + } + + // If next_candidate survives hard suppression, apply soft suppression + if (nextCandidate._score <= scoreThreshold) break; + } + // If `nextCandidate._score` has not dropped below `scoreThreshold` + // by this point, then we know that we went through all of the previous + // selections and can safely update `suppress_begin_index` to + // `selected.size()`. If on the other hand `next_candidate.score` + // *has* dropped below the score threshold, then since `suppressWeight` + // always returns values in [0, 1], further suppression by items that were + // not covered in the above for loop would not have caused the algorithm + // to select this item. We thus do the same update to + // `suppressBeginIndex`, but really, this element will not be added back + // into the priority queue in the following. 
+ nextCandidate._suppressBeginIndex = selected.size(); + + if (!shouldHardSuppress) { + if (nextCandidate._score == originalScore) { + // Suppression has not occurred, so select next_candidate + selected.push_back(nextCandidate._boxIndex); +// selected_scores.push_back(nextCandidate._score); + } + if (nextCandidate._score > scoreThreshold) { + // Soft suppression has occurred and current score is still greater than + // score_threshold; add next_candidate back onto priority queue. + candidatePriorityQueue.push(nextCandidate); + } + } + } + + if (output) { + DataBuffer buf(selected.data(), selected.size() * sizeof(I), DataTypeUtils::fromT()); + output->dataBuffer()->copyBufferFrom(buf, buf.getLenInBytes()); + } + + return (Nd4jLong)selected.size(); } - BUILD_SINGLE_TEMPLATE(template void nonMaxSuppressionV2_, (NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output), NUMERIC_TYPES); + + Nd4jLong + nonMaxSuppressionGeneric(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize, + double overlapThreshold, double scoreThreshold, NDArray* output) { + BUILD_DOUBLE_SELECTOR(boxes->dataType(), output == nullptr?DataType::INT32:output->dataType(), return nonMaxSuppressionGeneric_, (context, boxes, scores, maxSize, overlapThreshold, scoreThreshold, output), FLOAT_TYPES, INTEGER_TYPES); + return 0; + } + + BUILD_DOUBLE_TEMPLATE(template Nd4jLong nonMaxSuppressionGeneric_, (nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize, + double overlapThreshold, double scoreThreshold, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); + + void + nonMaxSuppression(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, + double overlapThreshold, double scoreThreshold, NDArray* output) { + BUILD_SINGLE_SELECTOR(boxes->dataType(), nonMaxSuppressionV2_, (boxes, scales, maxSize, + overlapThreshold, scoreThreshold, output), NUMERIC_TYPES); + } + BUILD_SINGLE_TEMPLATE(template void nonMaxSuppressionV2_, 
(NDArray* boxes, NDArray* scales, int maxSize, + double overlapThreshold, double scoreThreshold, NDArray* output), NUMERIC_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu b/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu index 3393a61e3..a96db0195 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu @@ -22,6 +22,7 @@ #include #include #include +#include namespace nd4j { namespace ops { @@ -121,24 +122,40 @@ namespace helpers { for (auto i = tid; i < len; i += step) indexBuf[i] = (I)srcBuf[i]; } -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + template + static __global__ void suppressScores(T* scores, I* indices, Nd4jLong length, T scoreThreshold) { + auto start = blockIdx.x * blockDim.x; + auto step = gridDim.x * blockDim.x; + + for (auto e = start + threadIdx.x; e < (int)length; e += step) { + if (scores[e] < scoreThreshold) { + scores[e] = scoreThreshold; + indices[e] = -1; + } + else { + indices[e] = I(e); + } + } + } + + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // nonMaxSuppressionV2 algorithm - given from TF NonMaxSuppressionV2 implementation // template - static void nonMaxSuppressionV2_(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) { + static void nonMaxSuppressionV2_(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, double scoreThreshold, NDArray* output) { auto stream = context->getCudaStream(); NDArray::prepareSpecialUse({output}, {boxes, scales}); std::unique_ptr indices(NDArrayFactory::create_('c', {scales->lengthOf()})); // - 1, scales->lengthOf()); //, scales->getContext()); - indices->linspace(0); - indices->syncToDevice(); // linspace 
only on CPU, so sync to Device as well NDArray scores(*scales); Nd4jPointer extras[2] = {nullptr, stream}; - + auto indexBuf = indices->dataBuffer()->specialAsT();///reinterpret_cast(indices->specialBuffer()); + auto scoreBuf = scores.dataBuffer()->specialAsT(); + suppressScores<<<128, 128, 128, *stream>>>(scoreBuf, indexBuf, scores.lengthOf(), T(scoreThreshold)); + indices->tickWriteDevice(); sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true); - - auto indexBuf = reinterpret_cast(indices->specialBuffer()); - + indices->tickWriteDevice(); NDArray selectedIndices = NDArrayFactory::create('c', {output->lengthOf()}); int numSelected = 0; int numBoxes = boxes->sizeAt(0); @@ -180,10 +197,156 @@ namespace helpers { } } + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + template + static __device__ bool checkOverlapBoxes(T* boxes, Nd4jLong* shape, T* scores, I* indices, I* selectedIndices, I* startIndices, I selectedSize, I nextCandidateIndex, T overlapThreshold, T scoreThreshold) { + bool shouldHardSuppress = false; + T& nextCandidateScore = scores[nextCandidateIndex]; + I selectedIndex = indices[nextCandidateIndex]; + I finish = startIndices[nextCandidateIndex]; + + for (int j = selectedSize; j > finish; --j) { + Nd4jLong xPos[] = {selectedIndex, selectedIndices[j - 1]}; + auto xShift = shape::getOffset(shape, xPos, 0); + nextCandidateScore *= (boxes[xShift] <= static_cast(overlapThreshold)?T(1.):T(0.));// + // First decide whether to perform hard suppression + if (boxes[xShift] >= overlapThreshold) { + shouldHardSuppress = true; + break; + } + + // If nextCandidate survives hard suppression, apply soft suppression + if (nextCandidateScore <= scoreThreshold) break; + } + + return shouldHardSuppress; + } 
+//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + template + static __global__ void + suppressNonMaxOverlapKernel(T* boxes, Nd4jLong* boxesShape, T* scoresData, I* indices, I* startIndices, Nd4jLong length, I maxOutputLen, + T overlapThreshold, T scoreThreshold, I* output, Nd4jLong* outputShape, I* outputLength) { + + __shared__ I selectedSize; + __shared__ I* tempOutput; + + if (threadIdx.x == 0) { + selectedSize = outputLength?*outputLength:maxOutputLen; + extern __shared__ unsigned char shmem[]; + tempOutput = (I*)shmem; + } + __syncthreads(); + + auto start = blockIdx.x * blockDim.x; + auto step = blockDim.x * gridDim.x; + + for (I nextCandidateIndex = start + threadIdx.x; selectedSize < maxOutputLen && nextCandidateIndex < (I)length; ) { + auto originalScore = scoresData[nextCandidateIndex];//nextCandidate._score; + I nextCandidateBoxIndex = indices[nextCandidateIndex]; + auto selectedSizeMark = selectedSize; + + // skip for cases when index is less than 0 (under score threshold) + if (nextCandidateBoxIndex < 0) { + nextCandidateIndex += step; + continue; + } + // check for overlaps + bool shouldHardSuppress = checkOverlapBoxes(boxes, boxesShape, scoresData, indices, tempOutput, startIndices, selectedSize, + nextCandidateIndex, overlapThreshold, scoreThreshold);//false; + T nextCandidateScore = scoresData[nextCandidateIndex]; + + startIndices[nextCandidateIndex] = selectedSize; + if (!shouldHardSuppress) { + if (nextCandidateScore == originalScore) { + // Suppression has not occurred, so select nextCandidate + if (output) + output[selectedSize] = nextCandidateBoxIndex; + tempOutput[selectedSize] = nextCandidateBoxIndex; + math::atomics::nd4j_atomicAdd(&selectedSize, (I)1); + } + + if (nextCandidateScore > scoreThreshold) { + // Soft suppression has occurred and current score is still greater than + // scoreThreshold; add nextCandidate back onto priority queue. 
+ continue; // in some cases, this index not 0 + } + } + nextCandidateIndex += step; + } + + if (threadIdx.x == 0) { + if (outputLength) + *outputLength = selectedSize; + } + } + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + template + static Nd4jLong + nonMaxSuppressionGeneric_(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int outputSize, + double overlapThreshold, double scoreThreshold, NDArray* output) { + auto stream = context->getCudaStream(); + if (output) + NDArray::prepareSpecialUse({output}, {boxes, scores}); + else { + if (!boxes->isActualOnDeviceSide()) + boxes->syncToDevice(); + if (!scores->isActualOnDeviceSide()) + scores->syncToDevice(); + } + + NDArray indices = NDArrayFactory::create('c', {scores->lengthOf()}); // - 1, scales->lengthOf()); //, scales->getContext()); + NDArray startPositions = NDArrayFactory::create('c', {scores->lengthOf()}); + NDArray selectedScores(*scores); + Nd4jPointer extras[2] = {nullptr, stream}; + auto indexBuf = indices.dataBuffer()->specialAsT();///reinterpret_cast(indices->specialBuffer()); + + suppressScores<<<128, 128, 128, *stream>>>(selectedScores.dataBuffer()->specialAsT(), indexBuf, selectedScores.lengthOf(), T(scoreThreshold)); + + sortByValue(extras, indices.buffer(), indices.shapeInfo(), indices.specialBuffer(), indices.specialShapeInfo(), selectedScores.buffer(), selectedScores.shapeInfo(), selectedScores.specialBuffer(), selectedScores.specialShapeInfo(), true); + indices.tickWriteDevice(); + selectedScores.tickWriteDevice(); + + auto scoresData = selectedScores.dataBuffer()->specialAsT();//, numBoxes, scoresData.begin()); + + auto startIndices = startPositions.dataBuffer()->specialAsT(); + I selectedSize = 0; + Nd4jLong res = 0; + if (output) { // this part used when output shape already calculated to fill up values on output + DataBuffer selectedSizeBuf(&selectedSize, sizeof(I), DataTypeUtils::fromT()); + 
suppressNonMaxOverlapKernel <<<1, 1, 1024, *stream >>> (boxes->dataBuffer()->specialAsT(), + boxes->specialShapeInfo(), scoresData, indexBuf, startIndices, scores->lengthOf(), (I) outputSize, + T(overlapThreshold), T(scoreThreshold), output->dataBuffer()->specialAsT(), output->specialShapeInfo(), + selectedSizeBuf.specialAsT()); + } + else { // this case used on calculation of output shape. Output and output shape shoulde be nullptr. + DataBuffer selectedSizeBuf(&selectedSize, sizeof(I), DataTypeUtils::fromT()); + suppressNonMaxOverlapKernel <<<1, 1, 1024, *stream >>> (boxes->dataBuffer()->specialAsT(), + boxes->specialShapeInfo(), scoresData, indexBuf, startIndices, scores->lengthOf(), (I)outputSize, + T(overlapThreshold), T(scoreThreshold), (I*)nullptr, (Nd4jLong*) nullptr, selectedSizeBuf.specialAsT()); + selectedSizeBuf.syncToPrimary(context, true); + res = *selectedSizeBuf.primaryAsT(); + } + + if (output) + NDArray::registerSpecialUse({output}, {boxes, scores}); + + return res; + } +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + void nonMaxSuppression(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, double scoreThreshold, NDArray* output) { + BUILD_DOUBLE_SELECTOR(boxes->dataType(), output->dataType(), nonMaxSuppressionV2_, + (context, boxes, scales, maxSize, threshold, scoreThreshold, output), + FLOAT_TYPES, INDEXING_TYPES); + } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - void nonMaxSuppressionV2(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) { - BUILD_DOUBLE_SELECTOR(boxes->dataType(), output->dataType(), nonMaxSuppressionV2_, (context, boxes, scales, maxSize, threshold, output), FLOAT_TYPES, INDEXING_TYPES); + Nd4jLong nonMaxSuppressionGeneric(nd4j::LaunchContext * context, NDArray* boxes, 
NDArray* scales, int maxSize, double threshold, double scoreThreshold, NDArray* output) { + BUILD_DOUBLE_SELECTOR(boxes->dataType(), output ? output->dataType():DataType::INT32, return nonMaxSuppressionGeneric_, + (context, boxes, scales, maxSize, threshold, scoreThreshold, output), + FLOAT_TYPES, INDEXING_TYPES); + return boxes->sizeAt(0); } } diff --git a/libnd4j/include/ops/declarable/helpers/image_suppression.h b/libnd4j/include/ops/declarable/helpers/image_suppression.h index afce399a6..85224e0f5 100644 --- a/libnd4j/include/ops/declarable/helpers/image_suppression.h +++ b/libnd4j/include/ops/declarable/helpers/image_suppression.h @@ -26,7 +26,10 @@ namespace nd4j { namespace ops { namespace helpers { - void nonMaxSuppressionV2(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output); + void nonMaxSuppression(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, + double overlapThreshold, double scoreThreshold, NDArray* output); + Nd4jLong nonMaxSuppressionGeneric(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize, + double overlapThreshold, double scoreThreshold, NDArray* output); } } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index c1e9ca5e2..84e3b4e8f 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -1960,7 +1960,82 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) { ASSERT_EQ(ND4J_STATUS_OK, results->status()); NDArray* result = results->at(0); - result->printBuffer("NonMaxSuppression OUtput2"); +// result->printBuffer("NonMaxSuppression OUtput2"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, 
Image_NonMaxSuppressingOverlap_1) { + + NDArray boxes = NDArrayFactory::create('c', {4,4}, { + 0, 0, 1, 1, + 0, 0.1, 1, 1.1, + 0, -0.1, 1, 0.9, + 0, 10, 1, 11}); + NDArray scores = NDArrayFactory::create('c', {4}, {0.9, .75, .6, .95}); //3 + NDArray max_num = NDArrayFactory::create(3); + NDArray expected = NDArrayFactory::create('c', {1,}, {3}); + + nd4j::ops::non_max_suppression_overlaps op; + auto results = op.execute({&boxes, &scores, &max_num}, {0.5, 0.}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); +// result->printBuffer("NonMaxSuppressionOverlap1 Output"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_2) { + + NDArray boxes = NDArrayFactory::create('c', {4,4}, { + 0, 0, 1, 1, + 0, 0.1, 1, 1.1, + 0, -0.1, 1, 0.9, + 0, 10, 1, 11}); + NDArray scores = NDArrayFactory::create('c', {4}, {0.9, .95, .6, .75}); //3 + NDArray max_num = NDArrayFactory::create(3); + NDArray expected = NDArrayFactory::create('c', {3,}, {1,1,1}); + + nd4j::ops::non_max_suppression_overlaps op; + auto results = op.execute({&boxes, &scores, &max_num}, {0.5, 0.}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); +// result->printBuffer("NonMaxSuppressionOverlap Output"); + ASSERT_TRUE(expected.isSameShapeStrict(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_3) { + + NDArray boxes = NDArrayFactory::create('c', {4,4}, { + 0, 0, 1, 1, + 0, 0.1, 1, 1.1, + 0, -0.1, 1, 0.9, + 0, 10, 1, 11}); + NDArray scores = NDArrayFactory::create('c', {4}, {0.5, .95, -.6, .75}); //3 + NDArray max_num = NDArrayFactory::create(5); + NDArray expected = 
NDArrayFactory::create('c', {5,}, {1,1,1,1,1}); + + nd4j::ops::non_max_suppression_overlaps op; + auto results = op.execute({&boxes, &scores, &max_num}, {0.5, 0.}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); +// result->printBuffer("NonMaxSuppressionOverlap Output"); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1984,7 +2059,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_1) { ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); - result->printIndexedBuffer("Cropped and Resized"); +// result->printIndexedBuffer("Cropped and Resized"); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result));