diff --git a/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp b/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp index a8cd17131..3cf088ae9 100644 --- a/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp @@ -197,8 +197,7 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) { // ***** calculations ***** // // notations: - // f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta) -> means dLdO * ff_output - // g = dLdO + // f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta) -> means dLdO * ff_output, g = dLdO // stdInv = 1 / (v + eps)^0.5 // N - batch size (product of spatial dimensions) diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool2d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool2d.cpp index 4e3314897..873ac545a 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool2d.cpp @@ -31,31 +31,28 @@ namespace ops { CUSTOM_OP_IMPL(avgpool2d, 1, 1, false, 0, 10) { auto input = INPUT_VARIABLE(0); - - REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead", input->rankOf()); + auto output = OUTPUT_VARIABLE(0); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - auto argI = *(block.getIArguments()); - auto output = OUTPUT_VARIABLE(0); const auto kH = INT_ARG(0); const auto kW = INT_ARG(1); const auto sH = INT_ARG(2); const auto sW = INT_ARG(3); - int pH = INT_ARG(4); - int pW = INT_ARG(5); + auto pH = INT_ARG(4); + auto pW = INT_ARG(5); const auto dH = INT_ARG(6); const auto dW = INT_ARG(7); const auto isSameMode = static_cast(INT_ARG(8)); const auto extraParam0 = INT_ARG(9); + const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D op: input should have rank of 4, but got %i instead", input->rankOf()); REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}", dH, dW); int oH = 0; int oW = 0; - int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - const int iH = static_cast(isNCHW ? input->sizeAt(2) : input->sizeAt(1)); const int iW = static_cast(isNCHW ? input->sizeAt(3) : input->sizeAt(2)); @@ -207,7 +204,6 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) { } return Status::OK(); - } DECLARE_SHAPE_FN(avgpool2d_bp) { diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool3d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool3d.cpp index 3f118e002..b72a1f6f7 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/avgpool3d.cpp @@ -51,14 +51,14 @@ CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) { int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2})); - REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, "AVGPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str()); + REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, "AVGPOOL3DNEW OP: wrong shape of output array, expected is %s, but got %s instead !", expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str()); if(!isNCDHW) { input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] @@ -176,8 +176,8 @@ CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) { std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2})); std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2})); - REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "AVGPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str()); + REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "AVGPOOL3DNEW_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL3DNEW_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str()); if(!isNCDHW) { input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool2d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool2d.cpp index eb535a098..13ba252e7 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool2d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool2d.cpp @@ -32,6 +32,7 @@ namespace ops { ////////////////////////////////////////////////////////////////////////// // maxpool2d corresponds to poolingMode=0 CUSTOM_OP_IMPL(maxpool2d, 1, 1, false, 0, 9) { + auto input = INPUT_VARIABLE(0); REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D OP: input array should have rank of 4, but got %i instead", input->rankOf()); diff --git a/libnd4j/include/ops/declarable/platform/cudnn/avgpool2d.cu b/libnd4j/include/ops/declarable/platform/cudnn/avgpool2d.cu new file mode 100644 index 000000000..8ff0bafb1 --- /dev/null +++ b/libnd4j/include/ops/declarable/platform/cudnn/avgpool2d.cu @@ -0,0 +1,138 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + + +#include "cudnnUtils.h" +#include + +namespace nd4j { +namespace ops { +namespace platforms { + + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(avgpool2d, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + const auto kH = INT_ARG(0); + const auto kW = INT_ARG(1); + const auto sH = INT_ARG(2); + const auto sW = INT_ARG(3); + auto pH = INT_ARG(4); + auto pW = INT_ARG(5); + const auto dH = INT_ARG(6); + const auto dW = INT_ARG(7); + const auto paddingMode = static_cast(INT_ARG(8)); + const auto extraParam0 = INT_ARG(9); + const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D CUDNN op: input should have rank of 4, but got %i instead", input->rankOf()); + REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D CUDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); + + int oH = 0; + int oW = 0; + + const int iH = static_cast(isNCHW ? input->sizeAt(2) : input->sizeAt(1)); + const int iW = static_cast(isNCHW ? input->sizeAt(3) : input->sizeAt(2)); + + ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); + + if (paddingMode) + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + + const cudnnPoolingMode_t mode = (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + + pooling2dCUDNN(block.launchContext(), input, output, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, mode); + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(avgpool2d, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; + + return goodType && input->dataType() == output->dataType(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(avgpool2d_bp, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + const auto kH = INT_ARG(0); // filter(kernel) height + const auto kW = INT_ARG(1); // filter(kernel) width + const auto sH = INT_ARG(2); // strides height + const auto sW = INT_ARG(3); // strides width + auto pH = INT_ARG(4); // paddings height + auto pW = INT_ARG(5); // paddings width + const auto dH = INT_ARG(6); // dilations height + const auto dW = INT_ARG(7); // dilations width + const auto paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + const auto extraParam0 = INT_ARG(9); + const auto isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D_BP CUDNN op: input should have rank of 4, but got %i instead", input->rankOf()); + REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D_BP CUDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); + + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + + std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); + std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL2D_BP CUDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "AVGPOOL2D_BP CUDNN op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str()); + + if(paddingMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + + const cudnnPoolingMode_t mode = (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + + pooling2dBpCUDNN(block.launchContext(), input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, mode); + + return Status::OK(); +} + +PLATFORM_CHECK(avgpool2d_bp, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; + + return goodType && (input->dataType() == gradO->dataType()) + && (input->dataType() == gradI->dataType()) + && shape::haveSameShapeAndStrides(input->getShapeInfo(), gradI->getShapeInfo()); +} + + +} +} +} diff --git a/libnd4j/include/ops/declarable/platform/cudnn/avgpool3d.cu b/libnd4j/include/ops/declarable/platform/cudnn/avgpool3d.cu new file mode 100644 index 000000000..878f306b3 --- /dev/null +++ b/libnd4j/include/ops/declarable/platform/cudnn/avgpool3d.cu @@ -0,0 +1,144 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + + +#include "cudnnUtils.h" +#include + +namespace nd4j { +namespace ops { +namespace platforms { + + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(avgpool3dnew, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) + + int kD = INT_ARG(0); // filter(kernel) depth + int kH = INT_ARG(1); // filter(kernel) height + int kW = INT_ARG(2); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + int extraParam0 = INT_ARG(13); + int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW CUDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW CUDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); + REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "AVGPOOL3DNEW CUDNN OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str()); + + if(paddingMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); + + const cudnnPoolingMode_t mode = (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + + pooling3dCUDNN(block.launchContext(), input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, mode); + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(avgpool3dnew, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; + + return goodType && input->dataType() == output->dataType(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(avgpool3dnew_bp, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon + + const int kD = INT_ARG(0); // filter(kernel) depth + const int kH = INT_ARG(1); // filter(kernel) height + const int kW = INT_ARG(2); // filter(kernel) width + const int sD = INT_ARG(3); // strides depth + const int sH = INT_ARG(4); // strides height + const int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + const int dD = INT_ARG(9); // dilations depth + const int dH = INT_ARG(10); // dilations height + const int dW = INT_ARG(11); // dilations width + const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID + const int extraParam0 = INT_ARG(13); // define what divisor to use while averaging + const int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW_BP CUDNN OP: input should have rank of 5, but got %i instead", input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW_BP CUDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); + std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL3DNEW_BP CUDNN: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "AVGPOOL3DNEW_BP CUDNN: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str()); + + if(isSameMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); + + const cudnnPoolingMode_t mode = (extraParam0 == 0) ? CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + + pooling3dBpCUDNN(block.launchContext(), input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, mode); + + return Status::OK(); +} + +PLATFORM_CHECK(avgpool3dnew_bp, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon + + const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; + + return goodType && (input->dataType() == gradO->dataType()) + && (input->dataType() == gradI->dataType()) + && shape::haveSameShapeAndStrides(input->getShapeInfo(), gradI->getShapeInfo()); +} + + +} +} +} diff --git a/libnd4j/include/ops/declarable/platform/cudnn/batchnorm.cu b/libnd4j/include/ops/declarable/platform/cudnn/batchnorm.cu index 3bd1357bf..1177d1a3c 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/batchnorm.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/batchnorm.cu @@ -97,9 +97,6 @@ static void batchnormCUDNN(const LaunchContext* context, err = cudnnSetTensorNdDescriptor(params, dataType, xRank, paramsShape.data(), paramsStrides.data()); if (err != 0) throw nd4j::cuda_exception::build("batchnormCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for mean/variance/gamma/beta failed", err); - - if (err != 0) throw nd4j::cuda_exception::build("batchnormCUDNN: cudnnSetConvolutionNdDescriptor failed", err); - // provide scaling parameters const float alpha32(1), beta32(0); const double alpha64(1), beta64(0); @@ -114,20 +111,127 @@ static void batchnormCUDNN(const LaunchContext* context, x, input->getSpecialBuffer(), z, output->getSpecialBuffer(), params, - gamma ? gamma->getSpecialBuffer(): nullptr, - beta ? beta->getSpecialBuffer() : nullptr, + gamma->getSpecialBuffer(), beta->getSpecialBuffer(), mean->getSpecialBuffer(), variance->getSpecialBuffer(), epsilon); if (err != 0) throw nd4j::cuda_exception::build("batchnormCUDNN: cudnnBatchNormalizationForwardInference failed", err); - // cudaErr = cudaStreamSynchronize(*context->getCudaStream()); - // if (cudaErr != 0) - // throw cuda_exception::build("batchnormCUDNN: cudaStreamSynchronize failed !", cudaErr); - + auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + if (cudaErr != 0) + throw cuda_exception::build("batchnormCUDNN: cudaStreamSynchronize failed !", cudaErr); NDArray::registerSpecialUse({output}, {input, mean, variance, gamma, beta}); } +////////////////////////////////////////////////////////////////////////// +static void batchnormBpCUDNN(const LaunchContext* context, + const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* gradO, + NDArray* gradI, NDArray* gradG, NDArray* gradB, + const double epsilon, const bool isSpatialMode) { + + // input, gradO, gradI -> 4D:nchw, 5D:ncdhw + // mean, variance, gamma, beta, gradM, gradV, gradG, gradB -> 1xCx1x1 for 4D and 1xCx1x1x1 for 5D for BATCHNORM_MODE_SPATIAL mode + // -> 1xCxHxW for 4D and 1xCxDxHxW for 5D for BATCHNORM_MODE_PER_ACTIVATION mode + + const cudnnDataType_t dataType = cudnnDataType(input->dataType()); + + const int xRank = input->rankOf(); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: can't set stream for cuDNN", err); + + const std::vector xShape = input->getShapeAsVectorInt(); // input and output have same shapes + + std::vector paramsShape, paramsStrides; // mean, variance, gamma and beta have same shapes + if(isSpatialMode) { // 1xCx1x1 + const int iC = mean->lengthOf(); + const int stride0 = mean->strideAt(0); + paramsShape = xRank == 4 ? std::vector({1, iC, 1, 1}) : std::vector({1, iC, 1, 1, 1}); + paramsStrides = xRank == 4 ? std::vector({iC*stride0, stride0, 1, 1}) : std::vector({iC*stride0, stride0, 1, 1, 1}); + } + else { + paramsShape = mean->getShapeAsVectorInt(); + paramsStrides = xRank == 4 ? std::vector({(int)mean->strideAt(0), (int)mean->strideAt(1), (int)mean->strideAt(2), (int)mean->strideAt(3)}) : std::vector({(int)mean->strideAt(0), (int)mean->strideAt(1), (int)mean->strideAt(2), (int)mean->strideAt(3), (int)mean->strideAt(4)}); + } + + std::vector xStrides = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3)}; + std::vector dxStrides = {(int)gradI->strideAt(0), (int)gradI->strideAt(1), (int)gradI->strideAt(2), (int)gradI->strideAt(3)}; + std::vector dzStrides = {(int)gradO->strideAt(0), (int)gradO->strideAt(1), (int)gradO->strideAt(2), (int)gradO->strideAt(3)}; + + if(xRank > 4) { // 5D + xStrides.push_back((int)input->strideAt(4)); + dxStrides.push_back((int)gradI->strideAt(4)); + dzStrides.push_back((int)gradO->strideAt(4)); + } + + cudnnTensorFormat_t format = CUDNN_TENSOR_NCHW; + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if(input->ews() == 1) + err = cudnnSetTensorNdDescriptorEx(x, format, dataType, xRank, xShape.data()); + else + err = cudnnSetTensorNdDescriptor(x, dataType, xRank, xShape.data(), xStrides.data()); + if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input failed", err); + + // gradO descriptor + cudnnTensorDescriptor_t dz; + cudnnCreateTensorDescriptor(&dz); + if(gradO->ews() == 1) + err = cudnnSetTensorNdDescriptorEx(dz, format, dataType, xRank, xShape.data()); + else + err = cudnnSetTensorNdDescriptor(dz, dataType, xRank, xShape.data(), dzStrides.data()); + if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradO failed", err); + + // gradI descriptor + cudnnTensorDescriptor_t dx; + cudnnCreateTensorDescriptor(&dx); + if(input->ews() == 1) + err = cudnnSetTensorNdDescriptorEx(dx, format, dataType, xRank, xShape.data()); + else + err = cudnnSetTensorNdDescriptor(dx, dataType, xRank, xShape.data(), dxStrides.data()); + if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradI failed", err); + + // mean, variance, gamma, gradG and gradB descriptor, the same descriptor for all of them + cudnnTensorDescriptor_t params; + cudnnCreateTensorDescriptor(¶ms); + if(mean->ews() == 1) + err = cudnnSetTensorNdDescriptorEx(params, format, dataType, xRank, paramsShape.data()); + else + err = cudnnSetTensorNdDescriptor(params, dataType, xRank, paramsShape.data(), paramsStrides.data()); + if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for mean/variance/gamma/gradG/gradB failed", err); + + // provide scaling parameters + const float alpha32(1), beta32(0); + double alpha64(1), beta64(0); + const void* ptrAlpha = input->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); + const void* ptrBeta = input->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({gradI, gradG, gradB}, {input, mean, variance, gamma, gradO}); + + // calculations + // TODO: we can use cache here + err = cudnnBatchNormalizationBackward(*handle, isSpatialMode ? CUDNN_BATCHNORM_SPATIAL : CUDNN_BATCHNORM_PER_ACTIVATION, + ptrAlpha, ptrBeta, ptrAlpha, ptrBeta, + x, input->getSpecialBuffer(), + dz, gradO->getSpecialBuffer(), + dx, gradI->getSpecialBuffer(), + params, + gamma->getSpecialBuffer(), gradG->getSpecialBuffer(), gradB->getSpecialBuffer(), + epsilon, + nullptr/*mean->getSpecialBuffer()*/, nullptr/*variance->getSpecialBuffer()*/); + + if (err != 0) throw nd4j::cuda_exception::build("batchnormBpCUDNN: cudnnBatchNormalizationBackward failed", err); + + auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + if (cudaErr != 0) + throw cuda_exception::build("batchnormBpCUDNN: cudaStreamSynchronize failed !", cudaErr); + + NDArray::registerSpecialUse({gradI, gradG, gradB}, {input, mean, variance, gamma, gradO}); +} + ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(batchnorm, ENGINE_CUDA) { @@ -189,11 +293,21 @@ PLATFORM_IMPL(batchnorm, ENGINE_CUDA) { const bool needPermut = axes.size() == 1 && mean->lengthOf() == input->sizeAt(-1); if(needPermut) { // if NHWC - std::vector perm = {0, 3, 1, 2}; // NHWC -> NCHW + std::vector perm = inRank == 4 ? std::vector({0, 3, 1, 2}) : std::vector({0, 4, 1, 2, 3}); // NHWC -> NCHW input = new NDArray(input->permute(perm)); output = new NDArray(output->permute(perm)); } + // cudnn requires gamma and beta to be non-nullptr + if(!applyScale) { + gamma = new NDArray(mean); + *gamma = 1; + } + if(!applyOffset) { + beta = new NDArray(mean); + *beta = 0; + } + // calculations batchnormCUDNN(block.launchContext(), input, mean, variance, gamma, beta, output, epsilon, axes.size() == 1); @@ -202,6 +316,12 @@ PLATFORM_IMPL(batchnorm, ENGINE_CUDA) { delete output; } + if(!applyScale) + delete gamma; + + if(!applyOffset) + delete beta; + return Status::OK(); } @@ -220,9 +340,6 @@ PLATFORM_CHECK(batchnorm, ENGINE_CUDA) { const int numOfIntArgs = block.getIArguments()->size(); const int xRank = input->rankOf(); - // disable cudnn batchnorm so far - return false; - // *********************************** // if(xRank != 4 && xRank != 5) return false; @@ -269,6 +386,182 @@ PLATFORM_CHECK(batchnorm, ENGINE_CUDA) { return true; } +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(batchnorm_bp, ENGINE_CUDA) { + + NDArray* input = INPUT_VARIABLE(0); + NDArray* mean = INPUT_VARIABLE(1); + NDArray* variance = INPUT_VARIABLE(2); + NDArray* gamma = nullptr; + NDArray* beta = nullptr; + NDArray* gradO = INPUT_VARIABLE(block.width() - 1); // next epsilon + + NDArray* gradI = OUTPUT_VARIABLE(0); + NDArray* gradM = OUTPUT_VARIABLE(1); + NDArray* gradV = OUTPUT_VARIABLE(2); + NDArray* gradG = nullptr; + NDArray* gradB = nullptr; + + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + const float epsilon = T_ARG(0); + + if(applyScale) { + gamma = INPUT_VARIABLE(3); + gradG = OUTPUT_VARIABLE(3); + } + if(applyOffset) { + beta = INPUT_VARIABLE(3 + (int)applyScale); + gradB = OUTPUT_VARIABLE(3 + (int)applyScale); + } + + const int numOfIntArgs = block.getIArguments()->size(); + const int inRank = input->rankOf(); + + // get axes args to normalize input array over + std::vector axes; + if(numOfIntArgs > 2) + for(int i = 2; i < numOfIntArgs; ++i) + axes.push_back(INT_ARG(i)); + else + axes.push_back(inRank-1); // default dimension to reduce along is last dimension + + const int numOfAxes = axes.size(); + REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM_BP CUDNN op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank); + + // evaluate expected shape for mean, variance and gamma. These 3 arrays should have identical shapes + // for example if input shape is {2,3,4,5,6} and axes = {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then expected shape would be {5} + std::vector expShape; + if(numOfAxes == 1) + expShape.push_back(input->sizeAt(axes[0])); + else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if axes = {1, 3} + expShape = std::vector(inRank, 1); + for(uint i = 0; i < numOfAxes; ++i) + expShape[axes[i]] = input->sizeAt(axes[i]); + } + + REQUIRE_TRUE(mean->isSameShape(expShape), 0, "BATCHNORM_BP CUDNN op: wrong shape of mean array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(mean).c_str()); + REQUIRE_TRUE(variance->isSameShape(expShape), 0, "BATCHNORM_BP CUDNN op: wrong shape of variance array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(variance).c_str()); + if(gamma) + REQUIRE_TRUE(gamma->isSameShape(expShape), 0, "BATCHNORM_BP CUDNN op: wrong shape of gamma array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(gamma).c_str()); + if(beta) + REQUIRE_TRUE(beta->isSameShape(expShape), 0, "BATCHNORM_BP CUDNN op: wrong shape of beta array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(beta).c_str()); + + REQUIRE_TRUE(input->isSameShape(gradO), 0, "BATCHNORM_BP CUDNN op: wrong shape of output gradients array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + + // types of all input arrays should be the same (except gradO) + for(int i = 1; i < block.width() - 2; ++i) + REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP CUDNN op: types of arrays (input, mean, variance, gamma, beta) should be the same !"); + + // cudnn supports NCHW format only + const bool needPermut = axes.size() == 1 && mean->lengthOf() != input->sizeAt(1); + + if(needPermut) { // if NHWC + std::vector perm = inRank == 4 ? std::vector({0, 3, 1, 2}) : std::vector({0, 4, 1, 2, 3}); // NHWC -> NCHW + input = new NDArray(input->permute(perm)); + gradO = new NDArray(gradO->permute(perm)); + gradI = new NDArray(gradI->permute(perm)); + } + + // cudnn requires gamma, gradG, gradB to be non-nullptr + if(!applyScale) { + gamma = new NDArray(mean); + gradG = new NDArray(mean); + *gamma = 1; + } + if(!applyOffset) + gradB = new NDArray(mean); + + // calculations + batchnormBpCUDNN(block.launchContext(), input, mean, variance, gamma, gradO, gradI, gradG, gradB, epsilon, axes.size() == 1); + + *gradM = 0; // put zeros so far + *gradV = 0; // put zeros so far + + if(needPermut) { + delete input; + delete gradO; + delete gradI; + } + + if(!applyScale) { + delete gamma; + delete gradG; + } + + if(!applyOffset) + delete gradB; + + return Status::OK(); + +} + +PLATFORM_CHECK(batchnorm_bp, ENGINE_CUDA) { + + NDArray* input = INPUT_VARIABLE(0); + NDArray* mean = INPUT_VARIABLE(1); + NDArray* variance = INPUT_VARIABLE(2); + NDArray* gamma = nullptr; + NDArray* beta = nullptr; + NDArray* gradO = INPUT_VARIABLE(block.width() - 1); // next epsilon + + NDArray* gradI = OUTPUT_VARIABLE(0); + NDArray* gradM = OUTPUT_VARIABLE(1); + NDArray* gradV = OUTPUT_VARIABLE(2); + NDArray* gradG = nullptr; + NDArray* gradB = nullptr; + + const int numOfIntArgs = block.getIArguments()->size(); + const int xRank = input->rankOf(); + + // *********************************** // + if(xRank != 4 && xRank != 5) + return false; + + // *********************************** // + const bool badType = input->dataType() != DataType::DOUBLE && input->dataType() != DataType::FLOAT32 && input->dataType() != DataType::HALF; + if(badType) + return false; + + // *********************************** // + // get axes args to normalize input array over + std::vector axes; + if(numOfIntArgs > 2) + for(int i = 2; i < numOfIntArgs; ++i) + axes.push_back(INT_ARG(i)); + else + axes.push_back(xRank-1); // default dimension to reduce along is last dimension + + if(axes.size() != 1 && axes.size() != 3 && axes.size() != 4) + return false; + + // *********************************** // + bool allParamsHaveSameShapeAndStrides = shape::haveSameShapeAndStrides(mean->getShapeInfo(), variance->getShapeInfo()); + if(gamma) + allParamsHaveSameShapeAndStrides &= shape::haveSameShapeAndStrides(mean->getShapeInfo(), gamma->getShapeInfo()); + if(gradG) + allParamsHaveSameShapeAndStrides &= shape::haveSameShapeAndStrides(mean->getShapeInfo(), gradG->getShapeInfo()); + if(gradB) + allParamsHaveSameShapeAndStrides &= shape::haveSameShapeAndStrides(mean->getShapeInfo(), gradB->getShapeInfo()); + + if(!allParamsHaveSameShapeAndStrides) + return false; + + // *********************************** // + bool isFormatGood = false; + if(axes.size() == 1) + isFormatGood = mean->lengthOf() == input->sizeAt(1) || mean->lengthOf() == input->sizeAt(-1); // mean [C] + else { + auto inputShapeModif = input->getShapeAsVector(); // [dim0,dim1,dim2,dim3] 4D or [dim0,dim1,dim2,dim3,dim4] + inputShapeModif[0] = 1; + isFormatGood = mean->isSameShape(inputShapeModif); // mean [1,dim1,dim2,dim3] 4D or [1,dim1,dim2,dim3,dim4] + } + if(!isFormatGood) + return false; + + return true; +} + } } diff --git a/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.cu b/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.cu new file mode 100644 index 000000000..fa7b1ecfa --- /dev/null +++ b/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.cu @@ -0,0 +1,412 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + + +#include "cudnnUtils.h" +#include + +namespace nd4j { +namespace ops { +namespace platforms { + +////////////////////////////////////////////////////////////////////////// +void checkConv2dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI, + const int iH, const int iW, + const int oH, const int oW, + const int kH, const int kW, + const int sH, const int sW, + const int pH, const int pW, + const int dH, const int dW, + const bool isNCHW) { + + const auto pHsum = ((oH - 1) * sH + ((kH - 1) * dH + 1) - iH); + const auto pWsum = ((oW - 1) * sW + ((kW - 1) * dW + 1) - iW); + + const bool isPHasymm = pH != (pHsum - pH); + const bool isPWasymm = pW != (pWsum - pW); + + if(!isPHasymm && !isPWasymm) + return; + + std::vector newShape = input->getShapeAsVector(); + + const int iHposition = isNCHW ? 2 : 1; + + if(isPHasymm) + newShape[iHposition] += 1; + if(isPWasymm) + newShape[iHposition + 1] += 1; + + NDArray* newInput = new NDArray(input->ordering(), newShape, input->dataType(), input->getContext()); + + if(isNCHW) + (*newInput)({0,0, 0,0, 0,input->sizeAt(2), 0,input->sizeAt(3)}).assign(input); + else + (*newInput)({0,0, 0,input->sizeAt(1), 0,input->sizeAt(2), 0,0}).assign(input); + + input = newInput; + + if(gradI != nullptr) + gradI = new NDArray(gradI->ordering(), newShape, gradI->dataType(), gradI->getContext()); +} + + +////////////////////////////////////////////////////////////////////////// +void checkConv3dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI, + const int iD, const int iH, const int iW, + const int oD, const int oH, const int oW, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const bool isNCDHW) { + + const auto pDsum = ((oD - 1) * sD + ((kD - 1) * dD + 1) - iD); + const auto pHsum = ((oH - 1) * sH + ((kH - 1) * dH + 1) - iH); + const auto pWsum = ((oW - 1) * sW + ((kW - 1) * dW + 1) - iW); + + const bool isPDasymm = pD != (pDsum - pD); + const bool isPHasymm = pH != (pHsum - pH); + const bool isPWasymm = pW != (pWsum - pW); + + if(!isPDasymm && !isPHasymm && !isPWasymm) + return; + + std::vector newShape = input->getShapeAsVector(); + + const int iDposition = isNCDHW ? 2 : 1; + + if(isPDasymm) + newShape[iDposition] += 1; + if(isPHasymm) + newShape[iDposition + 1] += 1; + if(isPWasymm) + newShape[iDposition + 2] += 1; + + NDArray* newInput = new NDArray(input->ordering(), newShape, input->dataType(), input->getContext()); + + if(isNCDHW) + (*newInput)({0,0, 0,0, 0,input->sizeAt(2), 0,input->sizeAt(3), 0,input->sizeAt(4)}).assign(input); + else + (*newInput)({0,0, 0,input->sizeAt(1), 0,input->sizeAt(2), 0,input->sizeAt(3), 0,0}).assign(input); + + input = newInput; + + if(gradI != nullptr) + gradI = new NDArray(gradI->ordering(), newShape, gradI->dataType(), gradI->getContext()); +} + +////////////////////////////////////////////////////////////////////////// +void pooling2dCUDNN(const LaunchContext* context, + const NDArray* input, NDArray* output, + const int kH, const int kW, + const int sH, const int sW, + const int pH, const int pW, + const int dH, const int dW, + const bool isNCHW, const cudnnPoolingMode_t mode) { + + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) throw nd4j::cuda_exception::build("pooling2dCUDNN: can't set stream for cuDNN", err); + + cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if(input->ews() == 1) + err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); + else + err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); + if (err != 0) throw nd4j::cuda_exception::build("pooling2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input failed", err); + + // output descriptor + cudnnTensorDescriptor_t z; + cudnnCreateTensorDescriptor(&z); + if(output->ews() == 1) + err = cudnnSetTensor4dDescriptor(z, format, cudnnDataType(output->dataType()), bS, oC, oH, oW); + else + err = cudnnSetTensor4dDescriptorEx(z, cudnnDataType(output->dataType()), bS, oC, oH, oW, output->strideAt(0), output->strideAt(indIOioC), output->strideAt(indOoH), output->strideAt(indOoH + 1)); + if (err != 0) throw nd4j::cuda_exception::build("pooling2dCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for output failed", err); + + // description of pooling + cudnnPoolingDescriptor_t pooling; + cudnnCreatePoolingDescriptor(&pooling); + err = cudnnSetPooling2dDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, kH, kW, pH, pW, sH, sW); + if (err != 0) throw nd4j::cuda_exception::build("pooling2dCUDNN: cudnnSetPooling2dDescriptor failed", err); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = output->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); + const void* beta = output->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({output}, {input}); + + // run calculation + err = cudnnPoolingForward(*handle, pooling, alpha, x, input->getSpecialBuffer(), beta, z, output->specialBuffer()); + if (err != 0) throw nd4j::cuda_exception::build("pooling2dCUDNN: cudnnPoolingForward failed", err); + + auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + if (cudaErr != 0) + throw cuda_exception::build("pooling2dCUDNN: cudaStreamSynchronize failed !", cudaErr); + + NDArray::registerSpecialUse({output}, {input}); +} + +////////////////////////////////////////////////////////////////////////// +void pooling2dBpCUDNN(const LaunchContext* context, + const NDArray* input, const NDArray* gradO, + NDArray* gradI, + const int kH, const int kW, + const int sH, const int sW, + const int pH, const int pW, + const int dH, const int dW, + const bool isNCHW, const cudnnPoolingMode_t mode) { + + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: can't set stream for cuDNN", err); + + cudnnTensorFormat_t format = isNCHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + + // input and gradI descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if(input->ews() == 1) + err = cudnnSetTensor4dDescriptor(x, format, cudnnDataType(input->dataType()), bS, iC, iH, iW); + else + err = cudnnSetTensor4dDescriptorEx(x, cudnnDataType(input->dataType()), bS, iC, iH, iW, input->strideAt(0), input->strideAt(indIOioC), input->strideAt(indIiH), input->strideAt(indIiH + 1)); + if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for input/gradI failed", err); + + // gradO descriptor + cudnnTensorDescriptor_t dz; + cudnnCreateTensorDescriptor(&dz); + if(gradO->ews() == 1) + err = cudnnSetTensor4dDescriptor(dz, format, cudnnDataType(gradO->dataType()), bS, oC, oH, oW); + else + err = cudnnSetTensor4dDescriptorEx(dz, cudnnDataType(gradO->dataType()), bS, oC, oH, oW, gradO->strideAt(0), gradO->strideAt(indIOioC), gradO->strideAt(indOoH), gradO->strideAt(indOoH + 1)); + if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: cudnnSetTensor4dDescriptor/cudnnSetTensor4dDescriptorEx for gradO failed", err); + + // description of pooling + cudnnPoolingDescriptor_t pooling; + cudnnCreatePoolingDescriptor(&pooling); + err = cudnnSetPooling2dDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, kH, kW, pH, pW, sH, sW); + if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: cudnnSetPooling2dDescriptor failed", err); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = gradO->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); + const void* beta = gradO->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({gradI}, {input, gradO}); + + // run calculation for gradI + err = cudnnPoolingBackward(*handle, pooling, alpha, dz, gradO->getSpecialBuffer(), dz, gradO->getSpecialBuffer(), x, input->getSpecialBuffer(), beta, x, gradI->getSpecialBuffer()); + if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: cudnnPoolingBackward failed", err); + + auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + if (cudaErr != 0) + throw cuda_exception::build("pooling2dBpCUDNN: cudaStreamSynchronize failed !", cudaErr); + + NDArray::registerSpecialUse({gradI}, {input, gradO}); +} + +////////////////////////////////////////////////////////////////////////// +void pooling3dCUDNN(const LaunchContext* context, + const NDArray* input, NDArray* output, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const bool isNCDHW, const cudnnPoolingMode_t mode) { + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: can't set stream for cuDNN", err); +printf("fffffffffff\n"); + const int numDims = 5; + + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + + const int pSizes[] = {pD, pH, pW}; + const int sSizes[] = {sD, sH, sW}; + const int kSizes[] = {kD, kH, kW}; + + const int xShape[] = {bS, iC, iD, iH, iW}; + const int zShape[] = {bS, oC, oD, oH, oW}; + + const int xStrides[] = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3), (int)input->strideAt(4)}; + const int zStrides[] = {(int)output->strideAt(0), (int)output->strideAt(1), (int)output->strideAt(2), (int)output->strideAt(3), (int)output->strideAt(4)}; + + cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + + // input descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if(input->ews() == 1) + err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape); + else + err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape, xStrides); + if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input failed", err); + + // output descriptor + cudnnTensorDescriptor_t z; + cudnnCreateTensorDescriptor(&z); + if(output->ews() == 1) + err = cudnnSetTensorNdDescriptorEx(z, format, cudnnDataType(output->dataType()), numDims, zShape); + else + err = cudnnSetTensorNdDescriptor(z, cudnnDataType(output->dataType()), numDims, zShape, zStrides); + if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for output failed", err); + + // description of pooling + cudnnPoolingDescriptor_t pooling; + cudnnCreatePoolingDescriptor(&pooling); + err = cudnnSetPoolingNdDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, numDims - 2, kSizes, pSizes, sSizes); + if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: cudnnSetPoolingNdDescriptor failed", err); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = output->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); + const void* beta = output->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); + + NDArray::prepareSpecialUse({output}, {input}); + + // run calculation + err = cudnnPoolingForward(*handle, pooling, alpha, x, input->getSpecialBuffer(), beta, z, output->specialBuffer()); + if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: cudnnPoolingForward failed", err); + + auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + if (cudaErr != 0) + throw cuda_exception::build("pooling3dCUDNN: cudaStreamSynchronize failed !", cudaErr); + + NDArray::registerSpecialUse({output}, {input}); +} + +////////////////////////////////////////////////////////////////////////// +void pooling3dBpCUDNN(const LaunchContext* context, + const NDArray* input, const NDArray* gradO, + NDArray* gradI, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const bool isNCDHW, const cudnnPoolingMode_t mode) { + + auto handle = reinterpret_cast(context->getCuDnnHandle()); + cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); + if (err != 0) throw nd4j::cuda_exception::build("pooling3dBpCUDNN: can't set stream for cuDNN", err); + + const int numDims = 5; + + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + + const int pSizes[] = {pD, pH, pW}; + const int sSizes[] = {sD, sH, sW}; + const int kSizes[] = {kD, kH, kW}; + + const int xShape[] = {bS, iC, iD, iH, iW}; + const int dzShape[] = {bS, oC, oD, oH, oW}; + + const int xStrides[] = {(int)input->strideAt(0), (int)input->strideAt(1), (int)input->strideAt(2), (int)input->strideAt(3), (int)input->strideAt(4)}; + const int dzStrides[] = {(int)gradO->strideAt(0), (int)gradO->strideAt(1), (int)gradO->strideAt(2), (int)gradO->strideAt(3), (int)gradO->strideAt(4)}; + + cudnnTensorFormat_t format = isNCDHW ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + + // input and gradI descriptor + cudnnTensorDescriptor_t x; + cudnnCreateTensorDescriptor(&x); + if(input->ews() == 1) + err = cudnnSetTensorNdDescriptorEx(x, format, cudnnDataType(input->dataType()), numDims, xShape); + else + err = cudnnSetTensorNdDescriptor(x, cudnnDataType(input->dataType()), numDims, xShape, xStrides); + if (err != 0) throw nd4j::cuda_exception::build("pooling3dBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for input/gradI failed", err); + + // gradO descriptor + cudnnTensorDescriptor_t dz; + cudnnCreateTensorDescriptor(&dz); + if(gradO->ews() == 1) + err = cudnnSetTensorNdDescriptorEx(dz, format, cudnnDataType(gradO->dataType()), numDims, dzShape); + else + err = cudnnSetTensorNdDescriptor(dz, cudnnDataType(gradO->dataType()), numDims, dzShape, dzStrides); + if (err != 0) throw nd4j::cuda_exception::build("pooling3dBpCUDNN: cudnnSetTensorNdDescriptor/cudnnSetTensorNdDescriptorEx for gradO failed", err); + + // description of pooling + cudnnPoolingDescriptor_t pooling; + cudnnCreatePoolingDescriptor(&pooling); + err = cudnnSetPoolingNdDescriptor(pooling, mode, CUDNN_PROPAGATE_NAN, numDims - 2, kSizes, pSizes, sSizes); + if (err != 0) throw nd4j::cuda_exception::build("pooling3dBpCUDNN: cudnnSetPoolingNdDescriptor failed", err); + + // provide scaling parameters + const float alpha32(1), beta32(0); + const double alpha64(1), beta64(0); + const void* alpha = gradO->sizeOfT() <= 4 ? reinterpret_cast(&alpha32) : reinterpret_cast(&alpha64); + const void* beta = gradO->sizeOfT() <= 4 ? reinterpret_cast(&beta32) : reinterpret_cast(&beta64); + + // cudnn maxpool2d_bp api requires ff output as one of input arguments + if(mode == CUDNN_POOLING_MAX) { + + NDArray temp(gradO); + + NDArray::prepareSpecialUse({gradI}, {input, gradO, &temp}); + + // run ff calculation + err = cudnnPoolingForward(*handle, pooling, alpha, x, input->getSpecialBuffer(), beta, dz, temp.specialBuffer()); + if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: cudnnPoolingForward failed", err); + + // run bp calculation for gradI + err = cudnnPoolingBackward(*handle, pooling, alpha, dz, temp.getSpecialBuffer(), dz, gradO->getSpecialBuffer(), x, input->getSpecialBuffer(), beta, x, gradI->getSpecialBuffer()); + if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: cudnnPoolingBackward failed", err); + + NDArray::registerSpecialUse({gradI}, {input, gradO, &temp}); + } + else { + + NDArray::prepareSpecialUse({gradI}, {input, gradO}); + + // run bp calculation for gradI + err = cudnnPoolingBackward(*handle, pooling, alpha, dz, gradO->getSpecialBuffer(), dz, gradO->getSpecialBuffer(), x, input->getSpecialBuffer(), beta, x, gradI->getSpecialBuffer()); + if (err != 0) throw nd4j::cuda_exception::build("pooling2dBpCUDNN: cudnnPoolingBackward failed", err); + + NDArray::registerSpecialUse({gradI}, {input, gradO}); + } + + auto cudaErr = cudaStreamSynchronize(*context->getCudaStream()); + if (cudaErr != 0) + throw cuda_exception::build("pooling3dBpCUDNN: cudaStreamSynchronize failed !", cudaErr); +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.h b/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.h index bdff86e24..5c46fb7b0 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.h +++ b/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.h @@ -30,8 +30,8 @@ #include -namespace nd4j { -namespace ops { +namespace nd4j { +namespace ops { namespace platforms { DECLARE_PLATFORM(conv2d, ENGINE_CUDA); @@ -46,6 +46,18 @@ namespace platforms { DECLARE_PLATFORM(batchnorm, ENGINE_CUDA); DECLARE_PLATFORM(batchnorm_bp, ENGINE_CUDA); + DECLARE_PLATFORM(avgpool2d, ENGINE_CUDA); + DECLARE_PLATFORM(avgpool2d_bp, ENGINE_CUDA); + + DECLARE_PLATFORM(maxpool2d, ENGINE_CUDA); + DECLARE_PLATFORM(maxpool2d_bp, ENGINE_CUDA); + + DECLARE_PLATFORM(avgpool3dnew, ENGINE_CUDA); + DECLARE_PLATFORM(avgpool3dnew_bp, ENGINE_CUDA); + + DECLARE_PLATFORM(maxpool3dnew, ENGINE_CUDA); + DECLARE_PLATFORM(maxpool3dnew_bp, ENGINE_CUDA); + ////////////////////////////////////////////////////////////////////////// FORCEINLINE cudnnDataType_t cudnnDataType(nd4j::DataType dataType) { switch (dataType) { @@ -65,91 +77,62 @@ FORCEINLINE cudnnDataType_t cudnnDataType(nd4j::DataType dataType) { } ////////////////////////////////////////////////////////////////////////// -FORCEINLINE void checkConv2dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI, - const int iH, const int iW, - const int oH, const int oW, - const int kH, const int kW, - const int sH, const int sW, - const int pH, const int pW, - const int dH, const int dW, - const bool isNCHW) { - - const auto pHsum = ((oH - 1) * sH + ((kH - 1) * dH + 1) - iH); - const auto pWsum = ((oW - 1) * sW + ((kW - 1) * dW + 1) - iW); - - const bool isPHasymm = pH != (pHsum - pH); - const bool isPWasymm = pW != (pWsum - pW); - - if(!isPHasymm && !isPWasymm) - return; - - std::vector newShape = input->getShapeAsVector(); - - const int iHposition = isNCHW ? 2 : 1; - - if(isPHasymm) - newShape[iHposition] += 1; - if(isPWasymm) - newShape[iHposition + 1] += 1; - - NDArray* newInput = new NDArray(input->ordering(), newShape, input->dataType(), input->getContext()); - - if(isNCHW) - (*newInput)({0,0, 0,0, 0,input->sizeAt(2), 0,input->sizeAt(3)}).assign(input); - else - (*newInput)({0,0, 0,input->sizeAt(1), 0,input->sizeAt(2), 0,0}).assign(input); - - input = newInput; - - if(gradI != nullptr) - gradI = new NDArray(gradI->ordering(), newShape, gradI->dataType(), gradI->getContext()); -} - +void checkConv2dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI, + const int iH, const int iW, + const int oH, const int oW, + const int kH, const int kW, + const int sH, const int sW, + const int pH, const int pW, + const int dH, const int dW, + const bool isNCHW); ////////////////////////////////////////////////////////////////////////// -FORCEINLINE void checkConv3dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI, - const int iD, const int iH, const int iW, - const int oD, const int oH, const int oW, - const int kD, const int kH, const int kW, - const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW, - const bool isNCDHW) { +void checkConv3dCUDNNPadAsymmetric(NDArray* &input, NDArray* &gradI, + const int iD, const int iH, const int iW, + const int oD, const int oH, const int oW, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const bool isNCDHW); - const auto pDsum = ((oD - 1) * sD + ((kD - 1) * dD + 1) - iD); - const auto pHsum = ((oH - 1) * sH + ((kH - 1) * dH + 1) - iH); - const auto pWsum = ((oW - 1) * sW + ((kW - 1) * dW + 1) - iW); +////////////////////////////////////////////////////////////////////////// +void pooling2dCUDNN(const LaunchContext* context, + const NDArray* input, NDArray* output, + const int kH, const int kW, + const int sH, const int sW, + const int pH, const int pW, + const int dH, const int dW, + const bool isNCHW, const cudnnPoolingMode_t mode); - const bool isPDasymm = pD != (pDsum - pD); - const bool isPHasymm = pH != (pHsum - pH); - const bool isPWasymm = pW != (pWsum - pW); +////////////////////////////////////////////////////////////////////////// +void pooling2dBpCUDNN(const LaunchContext* context, + const NDArray* input, const NDArray* gradO, + NDArray* gradI, + const int kH, const int kW, + const int sH, const int sW, + const int pH, const int pW, + const int dH, const int dW, + const bool isNCHW, const cudnnPoolingMode_t mode); - if(!isPDasymm && !isPHasymm && !isPWasymm) - return; +////////////////////////////////////////////////////////////////////////// +void pooling3dCUDNN(const LaunchContext* context, + const NDArray* input, NDArray* output, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const bool isNCDHW, const cudnnPoolingMode_t mode); - std::vector newShape = input->getShapeAsVector(); - - const int iDposition = isNCDHW ? 2 : 1; - - if(isPDasymm) - newShape[iDposition] += 1; - if(isPHasymm) - newShape[iDposition + 1] += 1; - if(isPWasymm) - newShape[iDposition + 2] += 1; - - NDArray* newInput = new NDArray(input->ordering(), newShape, input->dataType(), input->getContext()); - - if(isNCDHW) - (*newInput)({0,0, 0,0, 0,input->sizeAt(2), 0,input->sizeAt(3), 0,input->sizeAt(4)}).assign(input); - else - (*newInput)({0,0, 0,input->sizeAt(1), 0,input->sizeAt(2), 0,input->sizeAt(3), 0,0}).assign(input); - - input = newInput; - - if(gradI != nullptr) - gradI = new NDArray(gradI->ordering(), newShape, gradI->dataType(), gradI->getContext()); -} +////////////////////////////////////////////////////////////////////////// +void pooling3dBpCUDNN(const LaunchContext* context, + const NDArray* input, const NDArray* gradO, + NDArray* gradI, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const bool isNCDHW, const cudnnPoolingMode_t mode); } } diff --git a/libnd4j/include/ops/declarable/platform/cudnn/maxpool2d.cu b/libnd4j/include/ops/declarable/platform/cudnn/maxpool2d.cu new file mode 100644 index 000000000..6d5affe79 --- /dev/null +++ b/libnd4j/include/ops/declarable/platform/cudnn/maxpool2d.cu @@ -0,0 +1,132 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + + +#include "cudnnUtils.h" +#include + +namespace nd4j { +namespace ops { +namespace platforms { + + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(maxpool2d, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - paddingModee; + const auto kH = INT_ARG(0); + const auto kW = INT_ARG(1); + const auto sH = INT_ARG(2); + const auto sW = INT_ARG(3); + auto pH = INT_ARG(4); + auto pW = INT_ARG(5); + const auto dH = INT_ARG(6); + const auto dW = INT_ARG(7); + const auto paddingMode = static_cast(INT_ARG(8)); + const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D CUDNN op: input should have rank of 4, but got %i instead", input->rankOf()); + REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D CUDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); + + int oH = 0; + int oW = 0; + + const int iH = static_cast(isNCHW ? input->sizeAt(2) : input->sizeAt(1)); + const int iW = static_cast(isNCHW ? input->sizeAt(3) : input->sizeAt(2)); + + ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); + + if (paddingMode) + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + + pooling2dCUDNN(block.launchContext(), input, output, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, CUDNN_POOLING_MAX); + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(maxpool2d, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; + + return goodType && input->dataType() == output->dataType(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(maxpool2d_bp, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + const auto kH = INT_ARG(0); // filter(kernel) height + const auto kW = INT_ARG(1); // filter(kernel) width + const auto sH = INT_ARG(2); // strides height + const auto sW = INT_ARG(3); // strides width + auto pH = INT_ARG(4); // paddings height + auto pW = INT_ARG(5); // paddings width + const auto dH = INT_ARG(6); // dilations height + const auto dW = INT_ARG(7); // dilations width + const auto paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + const auto isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D_BP CUDNN op: input should have rank of 4, but got %i instead", input->rankOf()); + REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D_BP CUDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); + + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + + std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); + std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL2D_BP CUDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "MAXPOOL2D_BP CUDNN op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str()); + + if(paddingMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + + pooling2dBpCUDNN(block.launchContext(), input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, CUDNN_POOLING_MAX); + + return Status::OK(); +} + +PLATFORM_CHECK(maxpool2d_bp, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; + + return goodType && (input->dataType() == gradO->dataType()) + && (input->dataType() == gradI->dataType()) + && shape::haveSameShapeAndStrides(input->getShapeInfo(), gradI->getShapeInfo()); +} + + +} +} +} diff --git a/libnd4j/include/ops/declarable/platform/cudnn/maxpool3d.cu b/libnd4j/include/ops/declarable/platform/cudnn/maxpool3d.cu new file mode 100644 index 000000000..fc2e38577 --- /dev/null +++ b/libnd4j/include/ops/declarable/platform/cudnn/maxpool3d.cu @@ -0,0 +1,140 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + + +#include "cudnnUtils.h" +#include + +namespace nd4j { +namespace ops { +namespace platforms { + + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(maxpool3dnew, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) + + int kD = INT_ARG(0); // filter(kernel) depth + int kH = INT_ARG(1); // filter(kernel) height + int kW = INT_ARG(2); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + // int extraParam0 = INT_ARG(13); + int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW CUDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW CUDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); + REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "MAXPOOL3DNEW CUDNN OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str()); + + if(paddingMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); + + pooling3dCUDNN(block.launchContext(), input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, CUDNN_POOLING_MAX); + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(maxpool3dnew, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; + + return goodType && input->dataType() == output->dataType(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon + + const int kD = INT_ARG(0); // filter(kernel) depth + const int kH = INT_ARG(1); // filter(kernel) height + const int kW = INT_ARG(2); // filter(kernel) width + const int sD = INT_ARG(3); // strides depth + const int sH = INT_ARG(4); // strides height + const int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + const int dD = INT_ARG(9); // dilations depth + const int dH = INT_ARG(10); // dilations height + const int dW = INT_ARG(11); // dilations width + const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID + // const int extraParam0 = INT_ARG(13); // define what divisor to use while averaging + const int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW_BP CUDNN OP: input should have rank of 5, but got %i instead", input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW_BP CUDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); + std::vector expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL3DNEW_BP CUDNN: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "MAXPOOL3DNEW_BP CUDNN: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str()); + + if(isSameMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); + + pooling3dBpCUDNN(block.launchContext(), input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW, CUDNN_POOLING_MAX); + + return Status::OK(); +} + +PLATFORM_CHECK(maxpool3dnew_bp, ENGINE_CUDA) { + + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon + + const auto goodType = input->dataType() == DataType::DOUBLE || input->dataType() == DataType::FLOAT32 || input->dataType() == DataType::HALF || input->dataType() == DataType::INT32; + + return goodType && (input->dataType() == gradO->dataType()) + && (input->dataType() == gradI->dataType()) + && shape::haveSameShapeAndStrides(input->getShapeInfo(), gradI->getShapeInfo()); +} + + +} +} +} diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp index bf614bfab..1c1e9d6a4 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp @@ -30,111 +30,231 @@ using namespace dnnl; using namespace samediff; -namespace nd4j { - namespace ops { - namespace platforms { - PLATFORM_IMPL(avgpool2d, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); +namespace nd4j { +namespace ops { +namespace platforms { - REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead", - input->rankOf()); +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(avgpool2d, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - auto argI = *(block.getIArguments()); - auto output = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead", + input->rankOf()); - const auto kH = INT_ARG(0); - const auto kW = INT_ARG(1); - const auto sH = INT_ARG(2); - const auto sW = INT_ARG(3); - int pH = INT_ARG(4); - int pW = INT_ARG(5); - const auto dH = INT_ARG(6); - const auto dW = INT_ARG(7); - const auto isSameMode = static_cast(INT_ARG(8)); - const auto extraParam0 = INT_ARG(9); + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + auto argI = *(block.getIArguments()); + auto output = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}", - dH, dW); + const auto kH = INT_ARG(0); + const auto kW = INT_ARG(1); + const auto sH = INT_ARG(2); + const auto sW = INT_ARG(3); + int pH = INT_ARG(4); + int pW = INT_ARG(5); + const auto dH = INT_ARG(6); + const auto dW = INT_ARG(7); + const auto isSameMode = static_cast(INT_ARG(8)); + const auto extraParam0 = INT_ARG(9); - int oH = 0; - int oW = 0; + REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}", + dH, dW); - int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + int oH = 0; + int oW = 0; - const int iH = static_cast(isNCHW ? input->sizeAt(2) : input->sizeAt(1)); - const int iW = static_cast(isNCHW ? input->sizeAt(3) : input->sizeAt(2)); + int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - if (!isNCHW) { - input = new NDArray( - input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - output = new NDArray( - output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] - } + const int iH = static_cast(isNCHW ? input->sizeAt(2) : input->sizeAt(1)); + const int iW = static_cast(isNCHW ? input->sizeAt(3) : input->sizeAt(2)); - ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - if (isSameMode) - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - const int bS = input->sizeAt(0); - const int iC = input->sizeAt(1); - const int oC = output->sizeAt(1); - - auto poolingMode = PoolingType::AVG_POOL; - - dnnl_memory_desc_t empty; - dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_dst_md(empty); - dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; - dnnl::algorithm algorithm; - mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, - true, - bS, iC, iH, iW, oC, oH, oW, input, nullptr, output, - algorithm, - &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr, - &user_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md, - pool_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); - auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); - auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); - auto pool_src_memory = user_src_memory; - dnnl::stream stream(engine); - if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { - pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); - reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); - } - auto pool_dst_memory = user_dst_memory; - if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); - } - pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, - {DNNL_ARG_DST, pool_dst_memory}}); - if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory); - } - stream.wait(); - - //streams[0].submitAndWait(); - - if (!isNCHW) { - delete input; - delete output; - } - - return Status::OK(); - } - - PLATFORM_CHECK(avgpool2d, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); - } - } + if (!isNCHW) { + input = new NDArray( + input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + output = new NDArray( + output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] } + + ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); + + if (isSameMode) + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + + const int bS = input->sizeAt(0); + const int iC = input->sizeAt(1); + const int oC = output->sizeAt(1); + + auto poolingMode = PoolingType::AVG_POOL; + + dnnl_memory_desc_t empty; + dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty); + dnnl::memory::desc user_src_md(empty), user_dst_md(empty); + dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; + dnnl::algorithm algorithm; + mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, + true, + bS, iC, iH, iW, oC, oH, oW, input, nullptr, output, + algorithm, + &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr, + &user_dst_md, + pool_strides, pool_kernel, pool_padding, pool_padding_r); + auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md, + pool_dst_md, + pool_strides, pool_kernel, pool_padding, pool_padding_r); + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); + auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); + auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); + auto pool_src_memory = user_src_memory; + dnnl::stream stream(engine); + if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { + pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); + reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); + } + auto pool_dst_memory = user_dst_memory; + if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { + pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); + } + pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, + {DNNL_ARG_DST, pool_dst_memory}}); + if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { + reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory); + } + stream.wait(); + + //streams[0].submitAndWait(); + + if (!isNCHW) { + delete input; + delete output; + } + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(avgpool2d, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(avgpool2d_bp, ENGINE_CPU) { + auto input = INPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE( + 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_VARIABLE( + 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int extraParam0 = INT_ARG(9); + int isNCHW = + block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf()); + REQUIRE_TRUE(dH != 0 && dW != 0, 0, + "AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW); + + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + std::string expectedGradOShape = ShapeUtils::shapeAsString( + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1})); + std::string expectedGradIShape = ShapeUtils::shapeAsString( + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1})); + REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, + "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", + expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, + "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", + expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str()); + + + if (!isNCHW) { + input = new NDArray(input->permute( + {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + gradI = new NDArray(gradI->permute( + {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + gradO = new NDArray(gradO->permute( + {0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] + } + + if (isSameMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + + auto poolingMode = PoolingType::AVG_POOL; + + dnnl_memory_desc_t empty; + dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty); + dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty); + dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; + dnnl::algorithm algorithm; + mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, + true, + bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm, + &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md, + &user_diff_src_md, &user_dst_md, + pool_strides, pool_kernel, pool_padding, pool_padding_r); + auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, + input->buffer() != nullptr ? pool_src_md : pool_diff_src_md, + pool_dst_md, pool_strides, pool_kernel, pool_padding, + pool_padding_r); + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); + auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, + pool_kernel, pool_padding, pool_padding_r); + auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc); + auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer()); + auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer()); + auto poolB_src_memory = userB_src_memory; + dnnl::stream stream(engine); + if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { + poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine); + } + auto poolB_dst_memory = userB_dst_memory; + if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) { + poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine); + reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory); + } + pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory}, + {DNNL_ARG_DIFF_SRC, poolB_src_memory}}); + if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { + reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory); + } + stream.wait(); + + if (!isNCHW) { + delete input; + delete gradI; + delete gradO; + } + + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(avgpool2d_bp, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); +} + + +} +} } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d_bp.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d_bp.cpp deleted file mode 100644 index af1fd04fd..000000000 --- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d_bp.cpp +++ /dev/null @@ -1,149 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author saudet -// @author raver119@gmail.com -// - -#include -#include -#include - -#include -#include "mkldnnUtils.h" -#include - -using namespace dnnl; - -namespace nd4j { - namespace ops { - namespace platforms { - PLATFORM_IMPL(avgpool2d_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE( - 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto gradO = INPUT_VARIABLE( - 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_VARIABLE( - 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int extraParam0 = INT_ARG(9); - int isNCHW = - block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - - REQUIRE_TRUE(input->rankOf() == 4, 0, - "AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, - "AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, - indIiH, indWiC, indWoC, indWkH, indOoH); - - std::string expectedGradOShape = ShapeUtils::shapeAsString( - ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1})); - std::string expectedGradIShape = ShapeUtils::shapeAsString( - ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1})); - REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, - "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", - expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, - "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", - expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str()); - - - if (!isNCHW) { - input = new NDArray(input->permute( - {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradI = new NDArray(gradI->permute( - {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradO = new NDArray(gradO->permute( - {0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] - } - - if (isSameMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - auto poolingMode = PoolingType::AVG_POOL; - - dnnl_memory_desc_t empty; - dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty); - dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; - dnnl::algorithm algorithm; - mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, - true, - bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm, - &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md, - &user_diff_src_md, &user_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, - input->buffer() != nullptr ? pool_src_md : pool_diff_src_md, - pool_dst_md, pool_strides, pool_kernel, pool_padding, - pool_padding_r); - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); - auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, - pool_kernel, pool_padding, pool_padding_r); - auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc); - auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer()); - auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer()); - auto poolB_src_memory = userB_src_memory; - dnnl::stream stream(engine); - if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { - poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine); - } - auto poolB_dst_memory = userB_dst_memory; - if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) { - poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine); - reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory); - } - pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory}, - {DNNL_ARG_DIFF_SRC, poolB_src_memory}}); - if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { - reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory); - } - stream.wait(); - - if (!isNCHW) { - delete input; - delete gradI; - delete gradO; - } - - - return Status::OK(); - } - - PLATFORM_CHECK(avgpool2d_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); - } - } - } -} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp index ba1711032..559edf2cd 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp @@ -34,24 +34,23 @@ namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////// -static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, const NDArray *weights, +static void conv2dMKLDNN(nd4j::graph::Context &block, const NDArray *input, const NDArray *weights, const NDArray *bias, NDArray *output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, - indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); dnnl_memory_desc_t empty; - dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md( - empty); - dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md( - empty); + dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty); + dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(empty); + dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; + mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr, bias, output, @@ -61,13 +60,12 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con &user_bias_md, &user_dst_md, conv_strides, conv_padding, conv_padding_r, conv_dilation); - auto conv_desc = bias != nullptr - ? convolution_forward::desc(prop_kind::forward, + auto conv_desc = bias != nullptr ? convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r) - : convolution_forward::desc(prop_kind::forward, + : convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, @@ -112,6 +110,135 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con stream.wait(); } +////////////////////////////////////////////////////////////////////// +static void conv2dBpMKLDNN(nd4j::graph::Context &block, + const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO, + NDArray *gradI, NDArray *gradW, NDArray *gradB, + const int kH, const int kW, const int sH,const int sW, int pH, int pW, const int dH, const int dW, + const int paddingMode, const int isNCHW) { + + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); + + dnnl_memory_desc_t empty; + dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty); + dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty); + + dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; + + mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, + bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW, + gradB, gradO, + &conv_src_md, &conv_diff_src_md, &conv_weights_md, + &conv_diff_weights_md, &conv_bias_md, &conv_dst_md, + &user_src_md, &user_diff_src_md, &user_weights_md, + &user_diff_weights_md, &user_bias_md, &user_dst_md, + conv_strides, conv_padding, conv_padding_r, conv_dilation); + auto conv_desc = gradB != nullptr + ? convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r) + : convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); + + auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine( LaunchContext::defaultContext()->engine())); + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + dnnl::stream stream(engine); + + if (gradW != nullptr) { + auto convW_desc = gradB != nullptr ? convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r) + : convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); + + + auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine, conv_prim_desc); + + auto userW_src_memory = dnnl::memory(user_src_md, engine, const_cast(input)->buffer()); + auto userW_weights_memory = dnnl::memory(user_diff_weights_md, engine, gradW->buffer()); + auto userW_dst_memory = dnnl::memory(user_dst_md, engine,const_cast(gradO)->buffer()); + + auto convW_src_memory = userW_src_memory; + + if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) { + convW_src_memory = dnnl::memory(convW_prim_desc.src_desc(), engine); + reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory,convW_src_memory); + } + + auto convW_weights_memory = userW_weights_memory; + if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) { + convW_weights_memory = dnnl::memory(convW_prim_desc.diff_weights_desc(), engine); + } + + auto convW_dst_memory = userW_dst_memory; + if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) { + convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), engine); + reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory, convW_dst_memory); + } + + if (gradB != nullptr) { + auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine, gradB->buffer()); + + convolution_backward_weights(convW_prim_desc).execute(stream, + {{DNNL_ARG_SRC, convW_src_memory}, + {DNNL_ARG_DIFF_DST, convW_dst_memory}, + {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}, + {DNNL_ARG_DIFF_BIAS, convW_bias_memory}}); + } + else { + convolution_backward_weights(convW_prim_desc).execute(stream, + {{DNNL_ARG_SRC, convW_src_memory}, + {DNNL_ARG_DIFF_DST, convW_dst_memory}, + {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}}); + } + + if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) { + reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory, + userW_weights_memory); + } + + stream.wait(); + } + + if (gradI != nullptr) { + + auto convI_desc = convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); + + + auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine, conv_prim_desc); + auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer()); + auto userI_weights_memory = dnnl::memory(user_weights_md, engine,const_cast(weights)->buffer()); + auto userI_dst_memory = dnnl::memory(user_dst_md, engine, const_cast(gradO)->buffer()); + + auto convI_src_memory = userI_src_memory; + if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) { + convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), engine); + } + + auto convI_weights_memory = userI_weights_memory; + if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) { + convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), engine); + reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory, convI_weights_memory); + } + + auto convI_dst_memory = userI_dst_memory; + if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) { + convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), engine); + reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory, convI_dst_memory); + } + + convolution_backward_data(convI_prim_desc).execute(stream, + {{DNNL_ARG_DIFF_DST, convI_dst_memory}, + {DNNL_ARG_WEIGHTS, convI_weights_memory}, + {DNNL_ARG_DIFF_SRC, convI_src_memory}}); + + if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) { + reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory, userI_src_memory); + } + + stream.wait(); + } +} + ////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(conv2d, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) @@ -132,7 +259,7 @@ PLATFORM_IMPL(conv2d, ENGINE_CPU) { int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0)); // filter(kernel) height int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1)); // filter(kernel) width - conv2d_mkldnn(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); + conv2dMKLDNN(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); return Status::OK(); } @@ -152,6 +279,7 @@ PLATFORM_CHECK(conv2d, ENGINE_CPU) { ////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] @@ -172,158 +300,11 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) { int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - REQUIRE_TRUE(input->rankOf() == 4, 0, - "CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", - input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0, - "CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !", - weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 4, 0, - "CUSTOM CONV2D_BP OP: rank of output's gradients (next epsilon) array must be equal to 4, but got %i instead !", - gradO->rankOf()); + REQUIRE_TRUE(input->rankOf() == 4, 0,"CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !",input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 4, 0,"CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !",weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 4, 0,"CUSTOM CONV2D_BP OP: rank of output's gradients (next epsilon) array must be equal to 4, but got %i instead !",gradO->rankOf()); - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, - indIiH, indWiC, indWoC, indWkH, indOoH); - - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); - - dnnl_memory_desc_t empty; - dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), - conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), - user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty); - dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; - mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, - bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW, - gradB, gradO, - &conv_src_md, &conv_diff_src_md, &conv_weights_md, - &conv_diff_weights_md, &conv_bias_md, &conv_dst_md, - &user_src_md, &user_diff_src_md, &user_weights_md, - &user_diff_weights_md, &user_bias_md, &user_dst_md, - conv_strides, conv_padding, conv_padding_r, conv_dilation); - auto conv_desc = gradB != nullptr - ? convolution_forward::desc(prop_kind::forward, - algorithm::convolution_auto, conv_src_md, - conv_weights_md, conv_bias_md, - conv_dst_md, conv_strides, conv_dilation, conv_padding, - conv_padding_r) - : convolution_forward::desc(prop_kind::forward, - algorithm::convolution_auto, conv_src_md, - conv_weights_md, - conv_dst_md, conv_strides, conv_dilation, conv_padding, - conv_padding_r); - auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine( - LaunchContext::defaultContext()->engine())); - if (gradW != nullptr) { - auto convW_desc = gradB != nullptr - ? convolution_backward_weights::desc( - algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md, - conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r) - : convolution_backward_weights::desc( - algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, - conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); - auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine, - conv_prim_desc); - auto userW_src_memory = dnnl::memory(user_src_md, engine, - const_cast(input)->buffer()); - auto userW_weights_memory = dnnl::memory(user_diff_weights_md, engine, gradW->buffer()); - auto userW_dst_memory = dnnl::memory(user_dst_md, engine, - const_cast(gradO)->buffer()); - - auto convW_src_memory = userW_src_memory; - if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) { - convW_src_memory = dnnl::memory(convW_prim_desc.src_desc(), engine); - reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory, - convW_src_memory); - } - - auto convW_weights_memory = userW_weights_memory; - if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) { - convW_weights_memory = dnnl::memory(convW_prim_desc.diff_weights_desc(), engine); - } - - auto convW_dst_memory = userW_dst_memory; - if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) { - convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), engine); - reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory, - convW_dst_memory); - } - - if (gradB != nullptr) { - auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine, - gradB->buffer()); - convolution_backward_weights(convW_prim_desc).execute(stream, - {{DNNL_ARG_SRC, convW_src_memory}, - {DNNL_ARG_DIFF_DST, convW_dst_memory}, - {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}, - {DNNL_ARG_DIFF_BIAS, convW_bias_memory}}); - } else { - convolution_backward_weights(convW_prim_desc).execute(stream, - {{DNNL_ARG_SRC, convW_src_memory}, - {DNNL_ARG_DIFF_DST, convW_dst_memory}, - {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}}); - } - - if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) { - reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory, - userW_weights_memory); - } - - stream.wait(); - } - - if (gradI != nullptr) { - auto convI_desc = - convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md, - conv_weights_md, conv_dst_md, conv_strides, conv_dilation, - conv_padding, conv_padding_r); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); - auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine, - conv_prim_desc); - auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer()); - auto userI_weights_memory = dnnl::memory(user_weights_md, engine, - const_cast(weights)->buffer()); - auto userI_dst_memory = dnnl::memory(user_dst_md, engine, - const_cast(gradO)->buffer()); - - auto convI_src_memory = userI_src_memory; - if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) { - convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), engine); - } - - auto convI_weights_memory = userI_weights_memory; - if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) { - convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), engine); - reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory, - convI_weights_memory); - } - - auto convI_dst_memory = userI_dst_memory; - if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) { - convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), engine); - reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory, - convI_dst_memory); - } - - convolution_backward_data(convI_prim_desc).execute(stream, - {{DNNL_ARG_DIFF_DST, convI_dst_memory}, - {DNNL_ARG_WEIGHTS, convI_weights_memory}, - {DNNL_ARG_DIFF_SRC, convI_src_memory}}); - - if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) { - reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory, - userI_src_memory); - } - - stream.wait(); - }; + conv2dBpMKLDNN(block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp index 0a79df793..747d84c36 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp @@ -34,62 +34,23 @@ namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////// -PLATFORM_IMPL(conv3dnew, ENGINE_CPU) { - auto input = INPUT_VARIABLE( - 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto output = OUTPUT_VARIABLE( - 0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) - - REQUIRE_TRUE(input->rankOf() == 5, 0, - "CUSTOM CONV3D OP: rank of input array must be equal to 5, but got %i instead !", - input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 5, 0, - "CUSTOM CONV3D OP: rank of weights array must be equal to 5, but got %i instead !", - weights->rankOf()); - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(weights->sizeAt(2));// filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID - int isNCDHW = - block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW +static void conv3dMKLDNN(nd4j::graph::Context &block, + const NDArray *input, const NDArray *weights, const NDArray *bias, + NDArray *output, + const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, int pD, int pH, int pW, const int dD, const int dH, const int dW, + const int paddingMode, const int isNCDHW) { int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, - indIOioC, indIOioD, indWiC, indWoC, indWkD); - - std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC}); - REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, - "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", - expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, - "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", - oC, bias->rankOf(), bias->lengthOf()); - - if (isSameMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); dnnl_memory_desc_t empty; - dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md( - empty); - dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md( - empty); + dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md( empty); + dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md( empty); + dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; - mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode, + + mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW, bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, weights, nullptr, bias, output, @@ -98,151 +59,73 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CPU) { &user_src_md, nullptr, &user_weights_md, nullptr, &user_bias_md, &user_dst_md, conv_strides, conv_padding, conv_padding_r, conv_dilation); - auto conv_desc = bias != nullptr - ? convolution_forward::desc(prop_kind::forward, - algorithm::convolution_auto, conv_src_md, - conv_weights_md, conv_bias_md, - conv_dst_md, conv_strides, conv_dilation, conv_padding, - conv_padding_r) - : convolution_forward::desc(prop_kind::forward, - algorithm::convolution_auto, conv_src_md, - conv_weights_md, - conv_dst_md, conv_strides, conv_dilation, conv_padding, - conv_padding_r); + auto conv_desc = bias != nullptr ? convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r) + : convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); dnnl::stream stream(engine); + auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine); auto user_src_memory = dnnl::memory(user_src_md, engine, const_cast(input)->buffer()); - auto user_weights_memory = dnnl::memory(user_weights_md, engine, - const_cast(weights)->buffer()); + auto user_weights_memory = dnnl::memory(user_weights_md, engine, const_cast(weights)->buffer()); auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); + auto conv_src_memory = user_src_memory; if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) { conv_src_memory = dnnl::memory(conv_prim_desc.src_desc(), engine); reorder(user_src_memory, conv_src_memory).execute(stream, user_src_memory, conv_src_memory); } + auto conv_weights_memory = user_weights_memory; if (conv_prim_desc.weights_desc() != user_weights_memory.get_desc()) { conv_weights_memory = dnnl::memory(conv_prim_desc.weights_desc(), engine); - reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory, - conv_weights_memory); + reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory, conv_weights_memory); } + auto conv_dst_memory = user_dst_memory; if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) { conv_dst_memory = dnnl::memory(conv_prim_desc.dst_desc(), engine); } + if (bias != nullptr) { - auto conv_bias_memory = dnnl::memory(conv_prim_desc.bias_desc(), engine, bias->buffer()); + auto conv_bias_memory = dnnl::memory(conv_prim_desc.bias_desc(), engine, bias->getBuffer()); convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory}, {DNNL_ARG_WEIGHTS, conv_weights_memory}, {DNNL_ARG_BIAS, conv_bias_memory}, {DNNL_ARG_DST, conv_dst_memory}}); - } else { + } + else { convolution_forward(conv_prim_desc).execute(stream, {{DNNL_ARG_SRC, conv_src_memory}, {DNNL_ARG_WEIGHTS, conv_weights_memory}, {DNNL_ARG_DST, conv_dst_memory}}); } - if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) { + + if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) reorder(conv_dst_memory, user_dst_memory).execute(stream, conv_dst_memory, user_dst_memory); - } + stream.wait(); - - return Status::OK(); -} - -PLATFORM_CHECK(conv3dnew, ENGINE_CPU) { - // we don't want to use mkldnn if cpu doesn't support avx/avx2 - if (::optimalLevel() < 2) - return false; - - auto input = INPUT_VARIABLE( - 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto output = OUTPUT_VARIABLE( - 0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) - - return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, weights, bias, output}); } ////////////////////////////////////////////////////////////////////// -PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE( - 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE( - 1); // [kD, kH, kW, iC, oC] always - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE( - 2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - - auto gradI = OUTPUT_VARIABLE( - 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_VARIABLE( - 1); // [kD, kH, kW, iC, oC] always - auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] - - REQUIRE_TRUE(input->rankOf() == 5, 0, - "CUSTOM CONV3D_BP OP: rank of input array must be equal to 5, but got %i instead !", - input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 5, 0, - "CUSTOM CONV3D_BP OP: rank of weights array must be equal to 5, but got %i instead !", - weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 5, 0, - "CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !", - gradO->rankOf()); - - int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) depth - int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) height - int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(weights->sizeAt(2));// filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID - int isNDHWC = - block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW +static void conv3dBpMKLDNN(nd4j::graph::Context &block, + const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO, + NDArray *gradI, NDArray *gradW, NDArray *gradB, + const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, int pD, int pH, int pW, const int dD, const int dH, const int dW, + const int paddingMode, const int isNCDHW) { int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNDHWC, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, - indIOioC, indIOioD, indWiC, indWoC, indWkD); - - if(isSameMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - int trueoD, trueoH, trueoW; // true output depth/height/width - ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, - dW, iD, iH, iW, isSameMode); - - std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx( - {bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2})); - std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC}); - REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, - "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", - expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, - "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", - expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str()); - if (bias) - REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, - "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", - oC, bias->rankOf(), bias->lengthOf()); - + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); dnnl_memory_desc_t empty; - dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), - conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), - user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty); + dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty); + dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty); + dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; - mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode, - isNDHWC, + + mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, + isNCDHW, bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, weights, gradW, gradB, gradO, &conv_src_md, &conv_diff_src_md, &conv_weights_md, @@ -250,43 +133,30 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) { &user_src_md, &user_diff_src_md, &user_weights_md, &user_diff_weights_md, &user_bias_md, &user_dst_md, conv_strides, conv_padding, conv_padding_r, conv_dilation); - auto conv_desc = gradB != nullptr - ? convolution_forward::desc(prop_kind::forward, - algorithm::convolution_auto, conv_src_md, - conv_weights_md, conv_bias_md, - conv_dst_md, conv_strides, conv_dilation, conv_padding, - conv_padding_r) - : convolution_forward::desc(prop_kind::forward, - algorithm::convolution_auto, conv_src_md, - conv_weights_md, - conv_dst_md, conv_strides, conv_dilation, conv_padding, - conv_padding_r); - auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine( - LaunchContext::defaultContext()->engine())); - if (gradW != nullptr) { - auto convW_desc = gradB != nullptr - ? convolution_backward_weights::desc( - algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md, - conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r) - : convolution_backward_weights::desc( - algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, - conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); - auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine, - conv_prim_desc); - auto userW_src_memory = dnnl::memory(user_src_md, engine, - const_cast(input)->buffer()); + auto conv_desc = gradB != nullptr ? convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r) + : convolution_forward::desc(prop_kind::forward, algorithm::convolution_auto, conv_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); + + auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine())); + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + dnnl::stream stream(engine); + + if (gradW != nullptr) { + + auto convW_desc = gradB != nullptr ? convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r) + : convolution_backward_weights::desc(algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine, conv_prim_desc); + + auto userW_src_memory = dnnl::memory(user_src_md, engine, const_cast(input)->buffer()); auto userW_weights_memory = dnnl::memory(user_diff_weights_md, engine, gradW->buffer()); - auto userW_dst_memory = dnnl::memory(user_dst_md, engine, - const_cast(gradO)->buffer()); + auto userW_dst_memory = dnnl::memory(user_dst_md, engine, const_cast(gradO)->buffer()); auto convW_src_memory = userW_src_memory; if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) { convW_src_memory = dnnl::memory(convW_prim_desc.src_desc(), engine); - reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory, - convW_src_memory); + reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory, convW_src_memory); } auto convW_weights_memory = userW_weights_memory; @@ -297,65 +167,53 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) { auto convW_dst_memory = userW_dst_memory; if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) { convW_dst_memory = dnnl::memory(convW_prim_desc.diff_dst_desc(), engine); - reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory, - convW_dst_memory); + reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory, convW_dst_memory); } if (gradB != nullptr) { - auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine, - gradB->buffer()); + + auto convW_bias_memory = dnnl::memory(convW_prim_desc.diff_bias_desc(), engine, gradB->buffer()); + convolution_backward_weights(convW_prim_desc).execute(stream, {{DNNL_ARG_SRC, convW_src_memory}, {DNNL_ARG_DIFF_DST, convW_dst_memory}, {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}, {DNNL_ARG_DIFF_BIAS, convW_bias_memory}}); - } else { + } + else { convolution_backward_weights(convW_prim_desc).execute(stream, {{DNNL_ARG_SRC, convW_src_memory}, {DNNL_ARG_DIFF_DST, convW_dst_memory}, {DNNL_ARG_DIFF_WEIGHTS, convW_weights_memory}}); } - if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) { - reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory, - userW_weights_memory); - } + if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) + reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory, userW_weights_memory); stream.wait(); } if (gradI != nullptr) { - auto convI_desc = convolution_backward_data::desc(algorithm::convolution_auto, - conv_diff_src_md, conv_weights_md, - conv_dst_md, conv_strides, conv_dilation, conv_padding, - conv_padding_r); + auto convI_desc = convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md, conv_weights_md, conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r); - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); - auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine, - conv_prim_desc); + auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine, conv_prim_desc); auto userI_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer()); - auto userI_weights_memory = dnnl::memory(user_weights_md, engine, - const_cast(weights)->buffer()); - auto userI_dst_memory = dnnl::memory(user_dst_md, engine, - const_cast(gradO)->buffer()); + auto userI_weights_memory = dnnl::memory(user_weights_md, engine, const_cast(weights)->buffer()); + auto userI_dst_memory = dnnl::memory(user_dst_md, engine, const_cast(gradO)->buffer()); auto convI_src_memory = userI_src_memory; - if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) { + if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) convI_src_memory = dnnl::memory(convI_prim_desc.diff_src_desc(), engine); - } auto convI_weights_memory = userI_weights_memory; if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) { convI_weights_memory = dnnl::memory(convI_prim_desc.weights_desc(), engine); - reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory, - convI_weights_memory); + reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory, convI_weights_memory); } auto convI_dst_memory = userI_dst_memory; if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) { convI_dst_memory = dnnl::memory(convI_prim_desc.diff_dst_desc(), engine); - reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory, - convI_dst_memory); + reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory, convI_dst_memory); } convolution_backward_data(convI_prim_desc).execute(stream, @@ -363,30 +221,128 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) { {DNNL_ARG_WEIGHTS, convI_weights_memory}, {DNNL_ARG_DIFF_SRC, convI_src_memory}}); - if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) { - reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory, - userI_src_memory); - } - - stream.wait(); + if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) + reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory, userI_src_memory); } +} + +////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(conv3dnew, ENGINE_CPU) { + + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) + + REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM CONV3D MKLDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf()); + + int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) depth + int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) height + int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(weights->sizeAt(2));// filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 0-SAME, 1-VALID + int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + + std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC}); + REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); + + if (paddingMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); + + conv3dMKLDNN(block, input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW); + + return Status::OK(); +} + +PLATFORM_CHECK(conv3dnew, ENGINE_CPU) { + // we don't want to use mkldnn if cpu doesn't support avx/avx2 + if (::optimalLevel() < 2) + return false; + + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW) + + return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, weights, bias, output}); +} + +////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon + auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always + auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] + + REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); + REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf()); + REQUIRE_TRUE(gradO->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !", gradO->rankOf()); + + int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0));// filter(kernel) depth + int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1));// filter(kernel) height + int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast(weights->sizeAt(2));// filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW + + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + + if(paddingMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); + + int trueoD, trueoH, trueoW; // true output depth/height/width + ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); + + std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx( {bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2})); + std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC}); + REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); + + conv3dBpMKLDNN(block, input, weights, bias, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW); return Status::OK(); } PLATFORM_CHECK(conv3dnew_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE( - 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto weights = INPUT_VARIABLE( - 1); // [kD, kH, kW, iC, oC] always - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE( - 2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_VARIABLE( - 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - auto gradW = OUTPUT_VARIABLE( - 1); // [kD, kH, kW, iC, oC] always + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next + + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon + auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] return block.isUseMKLDNN() && diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp index 6db569eec..d95052c5a 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp @@ -177,7 +177,7 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N } ////////////////////////////////////////////////////////////////////////// -static void deconv2dBackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, +static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int paddingMode) { @@ -492,7 +492,7 @@ PLATFORM_IMPL(deconv2d_bp, ENGINE_CPU) { gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] } - deconv2dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode); + deconv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode); delete weights; delete gradW; diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp index f3b745d09..fc7a1e9e3 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp @@ -421,7 +421,7 @@ PLATFORM_CHECK(depthwise_conv2d, ENGINE_CPU) { return block.isUseMKLDNN() && mC == 1 && ( (xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) || - (xType==DataType::HALF && wType==DataType::HALF && bType==DataType::HALF && zType==DataType::HALF) || + (xType==DataType::BFLOAT16 && wType==DataType::BFLOAT16 && bType==DataType::BFLOAT16 && zType==DataType::BFLOAT16) || ((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType) ); } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp index 975cf7fe1..69aee8fad 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp @@ -29,117 +29,258 @@ using namespace dnnl; -namespace nd4j { - namespace ops { - namespace platforms { - PLATFORM_IMPL(maxpool2d, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); +namespace nd4j { +namespace ops { +namespace platforms { - REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead", - input->rankOf()); +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(maxpool2d, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - auto argI = *(block.getIArguments()); - auto output = OUTPUT_VARIABLE(0); + REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead", + input->rankOf()); - const auto kH = INT_ARG(0); - const auto kW = INT_ARG(1); - const auto sH = INT_ARG(2); - const auto sW = INT_ARG(3); - int pH = INT_ARG(4); - int pW = INT_ARG(5); - const auto dH = INT_ARG(6); - const auto dW = INT_ARG(7); - const auto isSameMode = static_cast(INT_ARG(8)); + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + auto argI = *(block.getIArguments()); + auto output = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}", - dH, dW); + const auto kH = INT_ARG(0); + const auto kW = INT_ARG(1); + const auto sH = INT_ARG(2); + const auto sW = INT_ARG(3); + int pH = INT_ARG(4); + int pW = INT_ARG(5); + const auto dH = INT_ARG(6); + const auto dW = INT_ARG(7); + const auto isSameMode = static_cast(INT_ARG(8)); - int oH = 0; - int oW = 0; + REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}", + dH, dW); - int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + int oH = 0; + int oW = 0; - const int iH = static_cast(isNCHW ? input->sizeAt(2) : input->sizeAt(1)); - const int iW = static_cast(isNCHW ? input->sizeAt(3) : input->sizeAt(2)); + int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - if (!isNCHW) { - input = new NDArray( - input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - output = new NDArray( - output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] - } + const int iH = static_cast(isNCHW ? input->sizeAt(2) : input->sizeAt(1)); + const int iW = static_cast(isNCHW ? input->sizeAt(3) : input->sizeAt(2)); - ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - if (isSameMode) - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - const int bS = input->sizeAt(0); - const int iC = input->sizeAt(1); - const int oC = output->sizeAt(1); - - auto poolingMode = PoolingType::MAX_POOL; - int extraParam0 = 1; - - dnnl_memory_desc_t empty; - dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_dst_md(empty); - dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; - dnnl::algorithm algorithm; - - mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, - true, - bS, iC, iH, iW, oC, oH, oW, input, nullptr, output, - algorithm, - &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr, - &user_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - - auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md, - pool_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); - auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); - auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); - - auto pool_src_memory = user_src_memory; - dnnl::stream stream(engine); - if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { - pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); - reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); - } - - auto pool_dst_memory = user_dst_memory; - if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); - } - - pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, - {DNNL_ARG_DST, pool_dst_memory}}); - - if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory); - } - - stream.wait(); - - if (!isNCHW) { - delete input; - delete output; - } - - return Status::OK(); - } - - PLATFORM_CHECK(maxpool2d, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); - } - } + if (!isNCHW) { + input = new NDArray( + input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + output = new NDArray( + output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] } + + ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); + + if (isSameMode) + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + + const int bS = input->sizeAt(0); + const int iC = input->sizeAt(1); + const int oC = output->sizeAt(1); + + auto poolingMode = PoolingType::MAX_POOL; + int extraParam0 = 1; + + dnnl_memory_desc_t empty; + dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty); + dnnl::memory::desc user_src_md(empty), user_dst_md(empty); + dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; + dnnl::algorithm algorithm; + + mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, + true, + bS, iC, iH, iW, oC, oH, oW, input, nullptr, output, + algorithm, + &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr, + &user_dst_md, + pool_strides, pool_kernel, pool_padding, pool_padding_r); + + auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md, + pool_dst_md, + pool_strides, pool_kernel, pool_padding, pool_padding_r); + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); + auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); + auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); + + auto pool_src_memory = user_src_memory; + dnnl::stream stream(engine); + if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { + pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); + reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); + } + + auto pool_dst_memory = user_dst_memory; + if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { + pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); + } + + pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, + {DNNL_ARG_DST, pool_dst_memory}}); + + if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { + reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory); + } + + stream.wait(); + + if (!isNCHW) { + delete input; + delete output; + } + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(maxpool2d, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(maxpool2d_bp, ENGINE_CPU) { + + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + int kH = INT_ARG(0); // filter(kernel) height + int kW = INT_ARG(1); // filter(kernel) width + int sH = INT_ARG(2); // strides height + int sW = INT_ARG(3); // strides width + int pH = INT_ARG(4); // paddings height + int pW = INT_ARG(5); // paddings width + int dH = INT_ARG(6); // dilations height + int dW = INT_ARG(7); // dilations width + int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int extraParam0 = INT_ARG(9); + int isNCHW = + block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE(input->rankOf() == 4, 0, + "AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf()); + REQUIRE_TRUE(dH != 0 && dW != 0, 0, + "AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW); + + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, + indIiH, indWiC, indWoC, indWkH, indOoH); + + std::string expectedGradOShape = ShapeUtils::shapeAsString( + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1})); + std::string expectedGradIShape = ShapeUtils::shapeAsString( + ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1})); + REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, + "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", + expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, + "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", + expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str()); + + + if (!isNCHW) { + input = new NDArray(input->permute( + {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + gradI = new NDArray(gradI->permute( + {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + gradO = new NDArray(gradO->permute( + {0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] + } + + if (isSameMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + + auto poolingMode = PoolingType::MAX_POOL; + + dnnl_memory_desc_t empty; + dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty); + dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty); + dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; + dnnl::algorithm algorithm; + + mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, + true, + bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm, + &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md, + &user_diff_src_md, &user_dst_md, + pool_strides, pool_kernel, pool_padding, pool_padding_r); + + // input is sometimes null, so we can't rely on pool_src_md being valid + auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, + input->buffer() != nullptr ? pool_src_md : pool_diff_src_md, + pool_dst_md, pool_strides, pool_kernel, pool_padding, + pool_padding_r); + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + dnnl::stream stream(engine); + auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); + + auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, + pool_strides, pool_kernel, pool_padding, pool_padding_r); + + auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc); + auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer()); + auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer()); + + auto poolB_src_memory = userB_src_memory; + if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { + poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine); + } + + auto poolB_dst_memory = userB_dst_memory; + if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) { + poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine); + reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory); + } + + auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); + auto pool_src_memory = user_src_memory; + if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { + pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); + reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); + } + + auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); + auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine); + + pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, + {DNNL_ARG_DST, pool_dst_memory}, + {DNNL_ARG_WORKSPACE, pool_workspace_memory}}); + // probably wrong, fix that + pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory}, + {DNNL_ARG_WORKSPACE, pool_workspace_memory}, + {DNNL_ARG_DIFF_SRC, poolB_src_memory}}); + + if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { + reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory); + } + + stream.wait(); + + if (!isNCHW) { + delete input; + delete gradI; + delete gradO; + } + + return Status::OK(); +} + +PLATFORM_CHECK(maxpool2d_bp, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); +} + +} +} } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d_bp.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d_bp.cpp deleted file mode 100644 index 686bdc7fb..000000000 --- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d_bp.cpp +++ /dev/null @@ -1,174 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author saudet -// @author raver119@gmail.com -// - -#include -#include -#include - -#include -#include "mkldnnUtils.h" -#include - -using namespace dnnl; - -namespace nd4j { - namespace ops { - namespace platforms { - PLATFORM_IMPL(maxpool2d_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE( - 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto gradO = INPUT_VARIABLE( - 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_VARIABLE( - 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - - int kH = INT_ARG(0); // filter(kernel) height - int kW = INT_ARG(1); // filter(kernel) width - int sH = INT_ARG(2); // strides height - int sW = INT_ARG(3); // strides width - int pH = INT_ARG(4); // paddings height - int pW = INT_ARG(5); // paddings width - int dH = INT_ARG(6); // dilations height - int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int extraParam0 = INT_ARG(9); - int isNCHW = - block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - - REQUIRE_TRUE(input->rankOf() == 4, 0, - "AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, - "AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, - indIiH, indWiC, indWoC, indWkH, indOoH); - - std::string expectedGradOShape = ShapeUtils::shapeAsString( - ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1})); - std::string expectedGradIShape = ShapeUtils::shapeAsString( - ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1})); - REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, - "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", - expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, - "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", - expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str()); - - - if (!isNCHW) { - input = new NDArray(input->permute( - {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradI = new NDArray(gradI->permute( - {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradO = new NDArray(gradO->permute( - {0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] - } - - if (isSameMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - auto poolingMode = PoolingType::MAX_POOL; - - dnnl_memory_desc_t empty; - dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty); - dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; - dnnl::algorithm algorithm; - - mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, - true, - bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm, - &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md, - &user_diff_src_md, &user_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - - // input is sometimes null, so we can't rely on pool_src_md being valid - auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, - input->buffer() != nullptr ? pool_src_md : pool_diff_src_md, - pool_dst_md, pool_strides, pool_kernel, pool_padding, - pool_padding_r); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); - auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); - - auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - - auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc); - auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer()); - auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer()); - - auto poolB_src_memory = userB_src_memory; - if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { - poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine); - } - - auto poolB_dst_memory = userB_dst_memory; - if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) { - poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine); - reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory); - } - - auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); - auto pool_src_memory = user_src_memory; - if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { - pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); - reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); - } - - auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); - auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine); - - pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, - {DNNL_ARG_DST, pool_dst_memory}, - {DNNL_ARG_WORKSPACE, pool_workspace_memory}}); - // probably wrong, fix that - pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory}, - {DNNL_ARG_WORKSPACE, pool_workspace_memory}, - {DNNL_ARG_DIFF_SRC, poolB_src_memory}}); - - if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { - reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory); - } - - stream.wait(); - - if (!isNCHW) { - delete input; - delete gradI; - delete gradO; - } - - return Status::OK(); - } - - PLATFORM_CHECK(maxpool2d_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); - } - } - } -} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp index 604bdcb6b..a37422c55 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp @@ -28,124 +28,273 @@ using namespace dnnl; -namespace nd4j { - namespace ops { - namespace platforms { - PLATFORM_IMPL(maxpool3dnew, ENGINE_CPU) { - auto input = INPUT_VARIABLE( - 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto output = OUTPUT_VARIABLE( - 0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) +namespace nd4j { +namespace ops { +namespace platforms { - int kD = INT_ARG(0); // filter(kernel) depth - int kH = INT_ARG(1); // filter(kernel) height - int kW = INT_ARG(2); // filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID - // int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(maxpool3dnew, ENGINE_CPU) { + auto input = INPUT_VARIABLE( + 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto output = OUTPUT_VARIABLE( + 0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) - REQUIRE_TRUE(input->rankOf() == 5, 0, - "MAXPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !", - input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, - "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); + int kD = INT_ARG(0); // filter(kernel) depth + int kH = INT_ARG(1); // filter(kernel) height + int kW = INT_ARG(2); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID + // int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases + int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, - indIOioC, indIOioD, indWiC, indWoC, indWkD); + REQUIRE_TRUE(input->rankOf() == 5, 0, + "MAXPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !", + input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx( - {bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2})); - REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, - "MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", - expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str()); - // REQUIRE_TRUE(iD >= kD && iH >= kH && iW >= kW, 0, "MAXPOOL3D OP: the input depth/height/width must be greater or equal to kernel(filter) depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", iD,iH,iW, kD,kH,kW); - // REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW); + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWiC, indWoC, indWkD); - if (!isNCDHW) { - input = new NDArray( - input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - output = new NDArray( - output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] - } + std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2})); + REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, + "MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", + expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str()); + // REQUIRE_TRUE(iD >= kD && iH >= kH && iW >= kW, 0, "MAXPOOL3D OP: the input depth/height/width must be greater or equal to kernel(filter) depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", iD,iH,iW, kD,kH,kW); + // REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW); - if (isSameMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, - dW); - - - auto poolingMode = PoolingType::MAX_POOL; - auto extraParam0 = 1; - - dnnl_memory_desc_t empty; - dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_dst_md(empty); - dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; - dnnl::algorithm algorithm; - - mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, - extraParam0, true, - bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, output, - algorithm, - &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr, - &user_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - - auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md, - pool_dst_md, pool_strides, pool_kernel, pool_padding, - pool_padding_r); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); - auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); - auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); - auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); - - auto pool_src_memory = user_src_memory; - if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { - pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); - reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); - } - - auto pool_dst_memory = user_dst_memory; - if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); - } - - pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, - {DNNL_ARG_DST, pool_dst_memory}}); - - if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory); - } - - stream.wait(); - - - if (!isNCDHW) { - delete input; - delete output; - } - - return Status::OK(); - } - - PLATFORM_CHECK(maxpool3dnew, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); - } - } + if (!isNCDHW) { + input = new NDArray( + input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] + output = new NDArray( + output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] } + + if (isSameMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, + dW); + + + auto poolingMode = PoolingType::MAX_POOL; + auto extraParam0 = 1; + + dnnl_memory_desc_t empty; + dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty); + dnnl::memory::desc user_src_md(empty), user_dst_md(empty); + dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; + dnnl::algorithm algorithm; + + mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, + extraParam0, true, + bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, output, + algorithm, + &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr, + &user_dst_md, + pool_strides, pool_kernel, pool_padding, pool_padding_r); + + auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md, + pool_dst_md, pool_strides, pool_kernel, pool_padding, + pool_padding_r); + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + dnnl::stream stream(engine); + auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); + auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); + auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); + + auto pool_src_memory = user_src_memory; + if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { + pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); + reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); + } + + auto pool_dst_memory = user_dst_memory; + if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { + pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); + } + + pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, + {DNNL_ARG_DST, pool_dst_memory}}); + + if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { + reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory); + } + + stream.wait(); + + + if (!isNCDHW) { + delete input; + delete output; + } + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(maxpool3dnew, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon + + const int kD = INT_ARG(0); // filter(kernel) depth + const int kH = INT_ARG(1); // filter(kernel) height + const int kW = INT_ARG(2); // filter(kernel) width + const int sD = INT_ARG(3); // strides depth + const int sH = INT_ARG(4); // strides height + const int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + const int dD = INT_ARG(9); // dilations depth + const int dH = INT_ARG(10); // dilations height + const int dW = INT_ARG(11); // dilations width + const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID + // int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases + int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW + + REQUIRE_TRUE(input->rankOf() == 5, 0, + "MAXPOOL3D_BP op: input should have rank of 5, but got %i instead", input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, + "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, + indIOioC, indIOioD, indWiC, indWoC, indWkD); + + std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2})); + std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx( + {bS, iC, iD, iH, iW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2})); + REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, + "MAXPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", + expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, + "MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", + expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str()); + + if (!isNCDHW) { + input = new NDArray(input->permute( + {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] + gradI = new NDArray(gradI->permute( + {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] + gradO = new NDArray(gradO->permute( + {0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] + } + + if (isSameMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, + dW); + + + auto poolingMode = PoolingType::MAX_POOL; + auto extraParam0 = 1; + + dnnl_memory_desc_t empty; + dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty); + dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty); + dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; + dnnl::algorithm algorithm; + + mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, + extraParam0, true, + bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, gradO, + algorithm, + &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md, + &user_diff_src_md, &user_dst_md, + pool_strides, pool_kernel, pool_padding, pool_padding_r); + + // input is sometimes null, so we can't rely on pool_src_md being valid + if (input->buffer() == nullptr) { + pool_src_md = pool_diff_src_md; + user_src_md = user_diff_src_md; + } + auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r); + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + dnnl::stream stream(engine); + auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); + + auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r); + + auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc); + auto userB_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer()); + auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer()); + + auto poolB_src_memory = userB_src_memory; + if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { + poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine); + } + + auto poolB_dst_memory = userB_dst_memory; + if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) { + poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine); + reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory); + } + + + auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); + + auto pool_src_memory = user_src_memory; + if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { + pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); + reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); + } + + auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); + auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine); + + pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, + {DNNL_ARG_DST, pool_dst_memory}, + {DNNL_ARG_WORKSPACE, pool_workspace_memory}}); + pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory}, + {DNNL_ARG_WORKSPACE, pool_workspace_memory}, + {DNNL_ARG_DIFF_SRC, poolB_src_memory}}); + + + if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { + reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory); + } + + stream.wait(); + + if (!isNCDHW) { + delete input; + delete gradI; + delete gradO; + } + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(maxpool3dnew_bp, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); +} + +} +} } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling_3d_bp.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling_3d_bp.cpp deleted file mode 100644 index b684df1bb..000000000 --- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling_3d_bp.cpp +++ /dev/null @@ -1,181 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#include -#include -#include - -#include -#include "mkldnnUtils.h" -#include - -using namespace dnnl; - -namespace nd4j { - namespace ops { - namespace platforms { - PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE( - 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto gradO = INPUT_VARIABLE( - 1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_VARIABLE( - 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - - const int kD = INT_ARG(0); // filter(kernel) depth - const int kH = INT_ARG(1); // filter(kernel) height - const int kW = INT_ARG(2); // filter(kernel) width - const int sD = INT_ARG(3); // strides depth - const int sH = INT_ARG(4); // strides height - const int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - const int dD = INT_ARG(9); // dilations depth - const int dH = INT_ARG(10); // dilations height - const int dW = INT_ARG(11); // dilations width - const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID - // int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW - - REQUIRE_TRUE(input->rankOf() == 5, 0, - "MAXPOOL3D_BP op: input should have rank of 5, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, - "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, - indIOioC, indIOioD, indWiC, indWoC, indWkD); - - std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx( - {bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2})); - std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx( - {bS, iC, iD, iH, iW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2})); - REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, - "MAXPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", - expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, - "MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", - expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str()); - - if (!isNCDHW) { - input = new NDArray(input->permute( - {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - gradI = new NDArray(gradI->permute( - {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - gradO = new NDArray(gradO->permute( - {0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] - } - - if (isSameMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, - dW); - - - auto poolingMode = PoolingType::MAX_POOL; - auto extraParam0 = 1; - - dnnl_memory_desc_t empty; - dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty); - dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; - dnnl::algorithm algorithm; - - mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, - extraParam0, true, - bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, gradO, - algorithm, - &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md, - &user_diff_src_md, &user_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - - // input is sometimes null, so we can't rely on pool_src_md being valid - if (input->buffer() == nullptr) { - pool_src_md = pool_diff_src_md; - user_src_md = user_diff_src_md; - } - auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); - auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); - - auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r); - - auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc); - auto userB_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer()); - auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer()); - - auto poolB_src_memory = userB_src_memory; - if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { - poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine); - } - - auto poolB_dst_memory = userB_dst_memory; - if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) { - poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine); - reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory); - } - - - auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); - - auto pool_src_memory = user_src_memory; - if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { - pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); - reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); - } - - auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); - auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine); - - pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, - {DNNL_ARG_DST, pool_dst_memory}, - {DNNL_ARG_WORKSPACE, pool_workspace_memory}}); - pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory}, - {DNNL_ARG_WORKSPACE, pool_workspace_memory}, - {DNNL_ARG_DIFF_SRC, poolB_src_memory}}); - - - if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { - reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory); - } - - stream.wait(); - - if (!isNCDHW) { - delete input; - delete gradI; - delete gradO; - } - - return Status::OK(); - } - - PLATFORM_CHECK(maxpool3dnew_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); - } - } - } -} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp index 96bbffcf8..0b81de76d 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp @@ -23,383 +23,388 @@ using namespace dnnl; -namespace nd4j { - namespace mkldnnUtils { - void getMKLDNNMemoryDescPool2d( - int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW, - int bS, int iC, int iH, int iW, int oC, int oH, int oW, - const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm, - dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, - dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) { - dnnl::memory::dims pool_src_tz = { bS, iC, iH, iW }; - dnnl::memory::dims pool_dst_tz = { bS, oC, oH, oW }; +namespace nd4j { +namespace mkldnnUtils { - pool_strides = { sH, sW }; - pool_kernel = { kH, kW }; - pool_padding = { pH, pW }; - pool_padding_r = { (oH - 1) * sH - iH + kH - pH, - (oW - 1) * sW - iW + kW - pW }; +////////////////////////////////////////////////////////////////////////// +void getMKLDNNMemoryDescPool2d( + int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW, + int bS, int iC, int iH, int iW, int oC, int oH, int oW, + const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm, + dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, + dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) { + dnnl::memory::dims pool_src_tz = { bS, iC, iH, iW }; + dnnl::memory::dims pool_dst_tz = { bS, oC, oH, oW }; - algorithm = poolingMode == 0 ? algorithm::pooling_max - : extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding - : algorithm::pooling_avg_include_padding; - auto type = dnnl::memory::data_type::f32; - auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any" + pool_strides = { sH, sW }; + pool_kernel = { kH, kW }; + pool_padding = { pH, pW }; + pool_padding_r = { (oH - 1) * sH - iH + kH - pH, + (oW - 1) * sW - iW + kW - pW }; - if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) { - *pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); - *user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format); - user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" - user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0]; - user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3]; - user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1]; - user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCHW ? 3 : 2]; - } + algorithm = poolingMode == 0 ? algorithm::pooling_max + : extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding + : algorithm::pooling_avg_include_padding; + auto type = dnnl::memory::data_type::f32; + auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any" - if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) { - *pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); - *user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format); - user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" - user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0]; - user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3]; - user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1]; - user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCHW ? 3 : 2]; - } - - if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) { - *pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format); - *user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format); - user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" - user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0]; - user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3]; - user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1]; - user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCHW ? 3 : 2]; - } - }; - - - void getMKLDNNMemoryDescPool3d( - int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW, - int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, - const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm, - dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, - dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) { - dnnl::memory::dims pool_src_tz = { bS, iC, iD, iH, iW }; - dnnl::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW }; - - pool_strides = { sD, sH, sW }; - pool_kernel = { kD, kH, kW }; - pool_padding = { pD, pH, pW }; - pool_padding_r = { (oD - 1) * sD - iD + kD - pD, - (oH - 1) * sH - iH + kH - pH, - (oW - 1) * sW - iW + kW - pW }; - - algorithm = poolingMode == 0 ? algorithm::pooling_max - : extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding - : algorithm::pooling_avg_include_padding; - auto type = dnnl::memory::data_type::f32; - auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; - auto supposed_to_be_any_format = dnnl::memory::format_tag::nCdhw8c; // doesn't work with "any" - - if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) { - *pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); - *user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format); - user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" - user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0]; - user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4]; - user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1]; - user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCDHW ? 3 : 2]; - user_src_md->data.format_desc.blocking.strides[4] = src->stridesOf()[isNCDHW ? 4 : 3]; - } - - if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) { - *pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); - *user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format); - user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" - user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0]; - user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4]; - user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1]; - user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCDHW ? 3 : 2]; - user_diff_src_md->data.format_desc.blocking.strides[4] = diff_src->stridesOf()[isNCDHW ? 4 : 3]; - } - - if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) { - *pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format); - *user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format); - user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" - user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0]; - user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4]; - user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1]; - user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCDHW ? 3 : 2]; - user_dst_md->data.format_desc.blocking.strides[4] = dst->stridesOf()[isNCDHW ? 4 : 3]; - } - }; - - - - void getMKLDNNMemoryDescConv2d( - int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW, - int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src, - const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, - dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, - dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md, - dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md, - dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) { - dnnl::memory::dims conv_src_tz = { bS, iC, iH, iW }; - dnnl::memory::dims conv_weights_tz = { oC, iC, kH, kW }; - dnnl::memory::dims conv_bias_tz = { oC }; - dnnl::memory::dims conv_dst_tz = { bS, oC, oH, oW }; - - const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d - - conv_strides = { sH, sW }; - conv_padding = { pH, pW }; - conv_padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; - conv_dilation = { dH-1, dW-1}; - - auto type = dnnl::memory::data_type::f32; - auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - auto formatw = dnnl::memory::format_tag::hwio; - - if (src != nullptr && conv_src_md != nullptr) { - *conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any); - *user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format); - user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" - user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0]; - user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3]; - user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1]; - user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCHW ? 3 : 2]; - } - - if (diff_src != nullptr && conv_diff_src_md != nullptr) { - *conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any); - *user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format); - user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" - user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0]; - user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3]; - user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1]; - user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCHW ? 3 : 2]; - } - - if (weights != nullptr && conv_weights_md != nullptr) { - *conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any); - *user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw); - user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio" - user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[3]; - user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[2]; - user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0]; - user_weights_md->data.format_desc.blocking.strides[3] = weights->stridesOf()[1]; - } - - if (diff_weights != nullptr && conv_diff_weights_md != nullptr) { - *conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any); - *user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw); - user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio" - user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[3]; - user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[2]; - user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0]; - user_diff_weights_md->data.format_desc.blocking.strides[3] = diff_weights->stridesOf()[1]; - } - - if (bias != nullptr && conv_bias_md != nullptr) { - *conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any); - *user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x); - } - - if (dst != nullptr && conv_dst_md != nullptr) { - *conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any); - *user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format); - user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" - user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0]; - user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3]; - user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1]; - user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCHW ? 3 : 2]; - } - } - - void getMKLDNNMemoryDescConv3d( - int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool paddingMode, bool isNCDHW, - int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src, - const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, - dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, - dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md, - dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md, - dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) { - dnnl::memory::dims conv_src_tz = { bS, iC, iD, iH, iW }; - dnnl::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW }; - dnnl::memory::dims conv_bias_tz = { oC }; - dnnl::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW }; - - conv_strides = { sD, sH, sW }; - conv_padding = { pD, pH, pW }; - conv_padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW }; - conv_dilation = { dD-1, dH-1, dW-1}; - - auto type = dnnl::memory::data_type::f32; - auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; - auto formatw = dnnl::memory::format_tag::dhwio; - - if (src != nullptr && conv_src_md != nullptr) { - *conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any); - *user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format); - user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" - user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0]; - user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4]; - user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1]; - user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCDHW ? 3 : 2]; - user_src_md->data.format_desc.blocking.strides[4] = src->stridesOf()[isNCDHW ? 4 : 3]; - } - - if (diff_src != nullptr && conv_diff_src_md != nullptr) { - *conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any); - *user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format); - user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" - user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0]; - user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4]; - user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1]; - user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCDHW ? 3 : 2]; - user_diff_src_md->data.format_desc.blocking.strides[4] = diff_src->stridesOf()[isNCDHW ? 4 : 3]; - } - - if (weights != nullptr && conv_weights_md != nullptr) { - *conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any); - *user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw); - user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio" - user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[4]; - user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[3]; - user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0]; - user_weights_md->data.format_desc.blocking.strides[3] = weights->stridesOf()[1]; - user_weights_md->data.format_desc.blocking.strides[4] = weights->stridesOf()[2]; - } - - if (diff_weights != nullptr && conv_diff_weights_md != nullptr) { - *conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any); - *user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw); - user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio" - user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[4]; - user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[3]; - user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0]; - user_diff_weights_md->data.format_desc.blocking.strides[3] = diff_weights->stridesOf()[1]; - user_diff_weights_md->data.format_desc.blocking.strides[4] = diff_weights->stridesOf()[2]; - } - - if (bias != nullptr && conv_bias_md != nullptr) { - *conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any); - *user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x); - } - - if (dst != nullptr && conv_dst_md != nullptr) { - *conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any); - *user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format); - user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" - user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0]; - user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4]; - user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1]; - user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCDHW ? 3 : 2]; - user_dst_md->data.format_desc.blocking.strides[4] = dst->stridesOf()[isNCDHW ? 4 : 3]; - } - }; - - - // void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst, - // dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md, - // dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) { - // const Nd4jLong* shape = src->getShapeInfo(); - // Nd4jLong rank = shape[0]; - // Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one - // Nd4jLong dim2 = axis >= 2 ? 1 : 2; - // Nd4jLong dim3 = axis >= 3 ? 2 : 3; - // dnnl::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1}; - - // auto type = dnnl::memory::data_type::f32; - // auto format = dnnl::memory::format_tag::nchw; - // auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any" - - // if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) { - // *batchnorm_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); - // *user_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format); - // user_src_md->data.format_kind = dnnl_blocked; // overrides format - // user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0]; - // user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1]; - // user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1; - // user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1; - // } - - // if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) { - // *batchnorm_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); - // *user_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format); - // user_diff_src_md->data.format_kind = dnnl_blocked; // overrides format - // user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0]; - // user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1]; - // user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1; - // user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1; - // } - - // if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) { - // *batchnorm_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); - // *user_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format); - // user_dst_md->data.format_kind = dnnl_blocked; // overrides format - // user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0]; - // user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1]; - // user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1; - // user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1; - // } - // }; - - - void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst, - dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) { - const Nd4jLong* shape = src->getShapeInfo(); - long rank = shape[0]; - long dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one - long dim2 = axis >= 2 ? 1 : 2; - long dim3 = axis >= 3 ? 2 : 3; - dnnl::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1}; - - auto type = dnnl::memory::data_type::f32; - auto format = axis == 1 ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - auto supposed_to_be_any_format = format; // doesn't work with "any" - - if (src != nullptr && src->getBuffer() != nullptr && lrn_src_md != nullptr) { - *lrn_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); - *user_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format); - user_src_md->data.format_kind = dnnl_blocked; - user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0]; - user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1]; - user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1; - user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1; - } - - if (diff_src != nullptr && diff_src->getBuffer() != nullptr && lrn_diff_src_md != nullptr) { - *lrn_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); - *user_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format); - user_diff_src_md->data.format_kind = dnnl_blocked; - user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0]; - user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1]; - user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1; - user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1; - } - - if (dst != nullptr && dst->getBuffer() != nullptr && lrn_dst_md != nullptr) { - *lrn_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); - *user_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, format); - user_dst_md->data.format_kind = dnnl_blocked; - user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0]; - user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1]; - user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1; - user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1; - } - } - - dnnl::engine& getEngine(void *ptr) { - auto eng = reinterpret_cast(ptr); - return *eng; - } + if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) { + *pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); + *user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format); + user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" + user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0]; + user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3]; + user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1]; + user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCHW ? 3 : 2]; } + + if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) { + *pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); + *user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format); + user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" + user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0]; + user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3]; + user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1]; + user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCHW ? 3 : 2]; + } + + if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) { + *pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format); + *user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format); + user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" + user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0]; + user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3]; + user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1]; + user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCHW ? 3 : 2]; + } +}; + +////////////////////////////////////////////////////////////////////////// +void getMKLDNNMemoryDescPool3d( + int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW, + int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, + const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::algorithm& algorithm, + dnnl::memory::desc* pool_src_md, dnnl::memory::desc* pool_diff_src_md, dnnl::memory::desc* pool_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, + dnnl::memory::dims& pool_strides, dnnl::memory::dims& pool_kernel, dnnl::memory::dims& pool_padding, dnnl::memory::dims& pool_padding_r) { + dnnl::memory::dims pool_src_tz = { bS, iC, iD, iH, iW }; + dnnl::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW }; + + pool_strides = { sD, sH, sW }; + pool_kernel = { kD, kH, kW }; + pool_padding = { pD, pH, pW }; + pool_padding_r = { (oD - 1) * sD - iD + kD - pD, + (oH - 1) * sH - iH + kH - pH, + (oW - 1) * sW - iW + kW - pW }; + + algorithm = poolingMode == 0 ? algorithm::pooling_max + : extraParam0 == 0 ? algorithm::pooling_avg_exclude_padding + : algorithm::pooling_avg_include_padding; + auto type = dnnl::memory::data_type::f32; + auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; + auto supposed_to_be_any_format = dnnl::memory::format_tag::nCdhw8c; // doesn't work with "any" + + if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) { + *pool_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); + *user_src_md = dnnl::memory::desc({ pool_src_tz }, type, format); + user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" + user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0]; + user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4]; + user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1]; + user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCDHW ? 3 : 2]; + user_src_md->data.format_desc.blocking.strides[4] = src->stridesOf()[isNCDHW ? 4 : 3]; + } + + if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) { + *pool_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); + *user_diff_src_md = dnnl::memory::desc({ pool_src_tz }, type, format); + user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" + user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0]; + user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4]; + user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1]; + user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCDHW ? 3 : 2]; + user_diff_src_md->data.format_desc.blocking.strides[4] = diff_src->stridesOf()[isNCDHW ? 4 : 3]; + } + + if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) { + *pool_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format); + *user_dst_md = dnnl::memory::desc({ pool_dst_tz }, type, format); + user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" + user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0]; + user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4]; + user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1]; + user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCDHW ? 3 : 2]; + user_dst_md->data.format_desc.blocking.strides[4] = dst->stridesOf()[isNCDHW ? 4 : 3]; + } +}; + +////////////////////////////////////////////////////////////////////////// +void getMKLDNNMemoryDescConv2d( + int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW, + int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src, + const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, + dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, + dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md, + dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md, + dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) { + dnnl::memory::dims conv_src_tz = { bS, iC, iH, iW }; + dnnl::memory::dims conv_weights_tz = { oC, iC, kH, kW }; + dnnl::memory::dims conv_bias_tz = { oC }; + dnnl::memory::dims conv_dst_tz = { bS, oC, oH, oW }; + + const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d + + conv_strides = { sH, sW }; + conv_padding = { pH, pW }; + conv_padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; + conv_dilation = { dH-1, dW-1}; + + auto type = dnnl::memory::data_type::f32; + auto format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + auto formatw = dnnl::memory::format_tag::hwio; + + if (src != nullptr && conv_src_md != nullptr) { + *conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any); + *user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format); + user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" + user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCHW ? 0 : 0]; + user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCHW ? 1 : 3]; + user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCHW ? 2 : 1]; + user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCHW ? 3 : 2]; + } + + if (diff_src != nullptr && conv_diff_src_md != nullptr) { + *conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any); + *user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format); + user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" + user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCHW ? 0 : 0]; + user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCHW ? 1 : 3]; + user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCHW ? 2 : 1]; + user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCHW ? 3 : 2]; + } + + if (weights != nullptr && conv_weights_md != nullptr) { + *conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any); + *user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw); + user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio" + user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[3]; + user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[2]; + user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0]; + user_weights_md->data.format_desc.blocking.strides[3] = weights->stridesOf()[1]; + } + + if (diff_weights != nullptr && conv_diff_weights_md != nullptr) { + *conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any); + *user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw); + user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = hwio" + user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[3]; + user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[2]; + user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0]; + user_diff_weights_md->data.format_desc.blocking.strides[3] = diff_weights->stridesOf()[1]; + } + + if (bias != nullptr && conv_bias_md != nullptr) { + *conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any); + *user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x); + } + + if (dst != nullptr && conv_dst_md != nullptr) { + *conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any); + *user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format); + user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCHW ? nchw : nhwc" + user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCHW ? 0 : 0]; + user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCHW ? 1 : 3]; + user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCHW ? 2 : 1]; + user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCHW ? 3 : 2]; + } +} + +////////////////////////////////////////////////////////////////////////// +void getMKLDNNMemoryDescConv3d( + int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool paddingMode, bool isNCDHW, + int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src, + const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, + dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, + dnnl::memory::desc* conv_diff_weights_md, dnnl::memory::desc* conv_bias_md, dnnl::memory::desc* conv_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_weights_md, + dnnl::memory::desc* user_diff_weights_md, dnnl::memory::desc* user_bias_md, dnnl::memory::desc* user_dst_md, + dnnl::memory::dims& conv_strides, dnnl::memory::dims& conv_padding, dnnl::memory::dims& conv_padding_r, dnnl::memory::dims& conv_dilation) { + dnnl::memory::dims conv_src_tz = { bS, iC, iD, iH, iW }; + dnnl::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW }; + dnnl::memory::dims conv_bias_tz = { oC }; + dnnl::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW }; + + conv_strides = { sD, sH, sW }; + conv_padding = { pD, pH, pW }; + conv_padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW }; + conv_dilation = { dD-1, dH-1, dW-1}; + + auto type = dnnl::memory::data_type::f32; + auto format = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; + auto formatw = dnnl::memory::format_tag::dhwio; + + if (src != nullptr && conv_src_md != nullptr) { + *conv_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any); + *user_src_md = dnnl::memory::desc({ conv_src_tz }, type, format); + user_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" + user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[isNCDHW ? 0 : 0]; + user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[isNCDHW ? 1 : 4]; + user_src_md->data.format_desc.blocking.strides[2] = src->stridesOf()[isNCDHW ? 2 : 1]; + user_src_md->data.format_desc.blocking.strides[3] = src->stridesOf()[isNCDHW ? 3 : 2]; + user_src_md->data.format_desc.blocking.strides[4] = src->stridesOf()[isNCDHW ? 4 : 3]; + } + + if (diff_src != nullptr && conv_diff_src_md != nullptr) { + *conv_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, dnnl::memory::format_tag::any); + *user_diff_src_md = dnnl::memory::desc({ conv_src_tz }, type, format); + user_diff_src_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" + user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[isNCDHW ? 0 : 0]; + user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[isNCDHW ? 1 : 4]; + user_diff_src_md->data.format_desc.blocking.strides[2] = diff_src->stridesOf()[isNCDHW ? 2 : 1]; + user_diff_src_md->data.format_desc.blocking.strides[3] = diff_src->stridesOf()[isNCDHW ? 3 : 2]; + user_diff_src_md->data.format_desc.blocking.strides[4] = diff_src->stridesOf()[isNCDHW ? 4 : 3]; + } + + if (weights != nullptr && conv_weights_md != nullptr) { + *conv_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any); + *user_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw); + user_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio" + user_weights_md->data.format_desc.blocking.strides[0] = weights->stridesOf()[4]; + user_weights_md->data.format_desc.blocking.strides[1] = weights->stridesOf()[3]; + user_weights_md->data.format_desc.blocking.strides[2] = weights->stridesOf()[0]; + user_weights_md->data.format_desc.blocking.strides[3] = weights->stridesOf()[1]; + user_weights_md->data.format_desc.blocking.strides[4] = weights->stridesOf()[2]; + } + + if (diff_weights != nullptr && conv_diff_weights_md != nullptr) { + *conv_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, dnnl::memory::format_tag::any); + *user_diff_weights_md = dnnl::memory::desc({ conv_weights_tz }, type, formatw); + user_diff_weights_md->data.format_kind = dnnl_blocked; // overrides "formatw = dhwio" + user_diff_weights_md->data.format_desc.blocking.strides[0] = diff_weights->stridesOf()[4]; + user_diff_weights_md->data.format_desc.blocking.strides[1] = diff_weights->stridesOf()[3]; + user_diff_weights_md->data.format_desc.blocking.strides[2] = diff_weights->stridesOf()[0]; + user_diff_weights_md->data.format_desc.blocking.strides[3] = diff_weights->stridesOf()[1]; + user_diff_weights_md->data.format_desc.blocking.strides[4] = diff_weights->stridesOf()[2]; + } + + if (bias != nullptr && conv_bias_md != nullptr) { + *conv_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::any); + *user_bias_md = dnnl::memory::desc({ conv_bias_tz }, type, dnnl::memory::format_tag::x); + } + + if (dst != nullptr && conv_dst_md != nullptr) { + *conv_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, dnnl::memory::format_tag::any); + *user_dst_md = dnnl::memory::desc({ conv_dst_tz }, type, format); + user_dst_md->data.format_kind = dnnl_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" + user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[isNCDHW ? 0 : 0]; + user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[isNCDHW ? 1 : 4]; + user_dst_md->data.format_desc.blocking.strides[2] = dst->stridesOf()[isNCDHW ? 2 : 1]; + user_dst_md->data.format_desc.blocking.strides[3] = dst->stridesOf()[isNCDHW ? 3 : 2]; + user_dst_md->data.format_desc.blocking.strides[4] = dst->stridesOf()[isNCDHW ? 4 : 3]; + } +}; + + +// void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst, +// dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md, +// dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) { +// const Nd4jLong* shape = src->getShapeInfo(); +// Nd4jLong rank = shape[0]; +// Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one +// Nd4jLong dim2 = axis >= 2 ? 1 : 2; +// Nd4jLong dim3 = axis >= 3 ? 2 : 3; +// dnnl::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1}; + +// auto type = dnnl::memory::data_type::f32; +// auto format = dnnl::memory::format_tag::nchw; +// auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any" + +// if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) { +// *batchnorm_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); +// *user_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format); +// user_src_md->data.format_kind = dnnl_blocked; // overrides format +// user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0]; +// user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1]; +// user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1; +// user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1; +// } + +// if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) { +// *batchnorm_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); +// *user_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format); +// user_diff_src_md->data.format_kind = dnnl_blocked; // overrides format +// user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0]; +// user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1]; +// user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1; +// user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1; +// } + +// if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) { +// *batchnorm_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); +// *user_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format); +// user_dst_md->data.format_kind = dnnl_blocked; // overrides format +// user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0]; +// user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1]; +// user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1; +// user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1; +// } +// }; + +////////////////////////////////////////////////////////////////////////// +void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst, + dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) { + const Nd4jLong* shape = src->getShapeInfo(); + long rank = shape[0]; + long dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one + long dim2 = axis >= 2 ? 1 : 2; + long dim3 = axis >= 3 ? 2 : 3; + dnnl::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1}; + + auto type = dnnl::memory::data_type::f32; + auto format = axis == 1 ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + auto supposed_to_be_any_format = format; // doesn't work with "any" + + if (src != nullptr && src->getBuffer() != nullptr && lrn_src_md != nullptr) { + *lrn_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); + *user_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format); + user_src_md->data.format_kind = dnnl_blocked; + user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0]; + user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1]; + user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1; + user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1; + } + + if (diff_src != nullptr && diff_src->getBuffer() != nullptr && lrn_diff_src_md != nullptr) { + *lrn_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); + *user_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format); + user_diff_src_md->data.format_kind = dnnl_blocked; + user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0]; + user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1]; + user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1; + user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1; + } + + if (dst != nullptr && dst->getBuffer() != nullptr && lrn_dst_md != nullptr) { + *lrn_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); + *user_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, format); + user_dst_md->data.format_kind = dnnl_blocked; + user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0]; + user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1]; + user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1; + user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1; + } +} + +////////////////////////////////////////////////////////////////////////// +dnnl::engine& getEngine(void *ptr) { + auto eng = reinterpret_cast(ptr); + return *eng; +} + + +} } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp index 9ed9f0ee6..9aafe869e 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp @@ -318,36 +318,6 @@ TEST_F(ConvolutionTests1, conv2d_8) { delete results; } -TYPED_TEST(TypedConvolutionTests1, TestAvgFF_TF) { - - auto input = NDArrayFactory::create('c', {4, 10, 10, 3}, {9.37125111f, 2.20166993f, 2.91434479f, 5.43639755f, -2.10573769f, 4.08528662f, 5.86908436f, -4.46203756f, 2.21057916f, 5.35849190f, 0.01394637f, 4.40566349f, 7.07982206f, -0.09633455f, 2.42429352f, 3.97301817f, -1.89553940f, 1.99690318f, 6.33141708f, 0.55401880f, 1.70707977f, - 5.55204201f, -0.03513752f, 1.60011971f, 2.62700319f, -2.74582434f, 3.06697464f, 1.06277943f, -1.16075921f, -0.78095782f, 9.72352791f, -1.22686064f, 1.99644792f, 7.35571337f, 1.40607321f, 0.11390255f, 9.53334427f, 2.28303599f, -1.66728830f, 6.16678810f, -0.04532295f, -1.97708666f, 9.74906158f, 1.46223176f, -1.46734393f, 4.30761862f, - -1.23790228f, 1.24823606f, 6.13938427f, -3.83689475f, -1.19625473f, 7.91535568f, 6.05868721f, -3.22946382f, 8.81633949f, -0.19967777f, 0.66053957f, 2.30919123f, 0.74543846f, -0.39347672f, 11.11058044f, 0.53720862f, 1.52645731f, 5.70012379f, -1.15213466f, 1.16451406f, 7.00526333f, 1.57362783f, -2.44384766f, 5.54213285f, -1.98828590f, - -0.70483637f, 7.88281822f, -3.59875536f, 0.80745387f, 13.41578484f, -1.55507684f, -0.65855008f, 9.32583523f, -0.14544789f, 0.73436141f, 3.61176538f, -1.71268058f, -2.58490300f, 9.09280205f, -3.27405524f, -2.04569697f, 4.44761324f, -0.62955856f, -2.61917663f, 8.04890442f, 0.54579324f, 0.85929775f, 9.82259560f, -1.93825579f, 0.77703512f, - 4.67090321f, -4.79267597f, -2.38906908f, 9.31265545f, 0.96026313f, -1.14109385f, 11.54231834f, -0.01417295f, -0.39500344f, 8.49191666f, 0.55300158f, 2.79490185f, 6.92466164f, 1.72254205f, 2.82222271f, 8.83112717f, 2.95033407f, 2.18054962f, 6.73509789f, -2.22272944f, 0.51127720f, -1.04563558f, 2.15747333f, -2.30959272f, 9.55441570f, - 1.50396204f, 1.77370787f, 7.38146257f, -1.79076433f, 3.20961165f, 7.18864202f, 2.91217351f, 0.43018937f, 7.11078024f, -1.17386127f, -0.16817921f, 6.12327290f, -2.82205725f, 3.30696845f, 13.51291752f, -1.30856836f, -2.38332748f, 11.09487438f, -1.47190213f, -0.53050828f, 4.38285351f, -5.07309771f, 1.50714362f, 5.72274446f, -2.85825086f, - -0.89673209f, 3.73791552f, -0.67708802f, -4.13149452f, -0.00671843f, -0.26566532f, 0.32961160f, 7.14501762f, -1.41608179f, -4.96590328f, 12.26205540f, -0.65158135f, -0.88641000f, 6.95777559f, -0.79058206f, -0.10260171f, 7.87169170f, 1.35921454f, 1.11759663f, 5.46187401f, -2.57214499f, 2.48484039f, 4.04043484f, -2.07137156f, -1.42709637f, - 9.25487137f, -0.12605135f, -2.66949964f, 2.89412403f, 0.74451172f, -2.96250391f, 3.99258423f, 0.27084303f, 0.32213116f, 5.42332172f, -0.44414216f, 1.70881832f, 6.69346905f, 0.53058422f, -4.73146200f, 4.22051668f, 2.24834967f, 0.66996074f, 4.30173683f, 0.11849818f, -4.07520294f, 8.27318478f, -2.54398274f, -2.86705542f, 10.11775303f, - -0.99382895f, 0.65881538f, 7.93556786f, -1.27934420f, -1.69343162f, 9.68042564f, -1.02609646f, -1.18189347f, 5.75370646f, -1.67888868f, -4.48871994f, 4.79537392f, -0.79212248f, -0.19855022f, 6.15060997f, -0.01081491f, 3.64454579f, 10.82562447f, 1.58859253f, -2.65847278f, 8.60093212f, -1.59196103f, 0.07635692f, 11.76175690f, -1.17453325f, - 0.10122013f, 6.86458445f, -2.18891335f, -2.74004745f, 8.07066154f, 0.71818852f, -2.03035975f, 6.31053686f, 0.51509416f, 1.39789927f, 9.43515587f, 2.04256630f, 0.13985133f, 4.65010691f, 2.40911126f, -0.36255789f, -3.06867862f, -0.45225358f, -1.56778407f, 6.05917358f, -1.09891272f, 1.77184200f, 6.46248102f, 0.96042323f, -0.24346280f, - 4.63436460f, -4.69907761f, 1.25187206f, 11.46173859f, -2.21917558f, 1.28007793f, 6.92173195f, 2.11268163f, -3.47389889f, 5.08722782f, -3.03950930f, -4.17154264f, 11.30568314f, 0.80361372f, 2.53214502f, 7.18707085f, -4.49114513f, 2.85449266f, 10.14906883f, -0.31974933f, -0.84472644f, -0.52459574f, 0.12921631f, -1.81390119f, 2.76170087f, 1.03982210f, 2.91744232f, -0.29048753f, 5.87453508f, -1.53684759f, 1.85800636f, -0.91404629f, 1.28954852f, 5.11354685f, -2.47475505f, -1.33179152f, 2.58552408f, 1.37316465f, -3.32339454f, 1.54122913f, 3.24953628f, -0.29758382f, 2.82391763f, -1.51142192f, -1.22699404f, 6.75745535f, 0.65452754f, -3.29385471f, 2.06008053f, 2.53172946f, -4.23532820f, -1.53909743f, -0.07010663f, -1.42173731f, 7.29031610f, -0.18448229f, 4.59496164f, 6.73027277f, 0.73441899f, 0.14426160f, 4.14915276f, -2.97010231f, 6.05851364f, 4.95218086f, -2.39145470f, 2.40494704f, 2.10288811f, 0.53503096f, 1.44511235f, 6.66344261f, -3.05803776f, 7.21418667f, 3.30303526f, -0.24163735f, 3.47409391f, 3.64520788f, 2.15189481f, -3.11243272f, 3.62310791f, 0.37379482f, 0.40865007f, -0.83132005f, -4.78246069f, 2.07030797f, 6.51765442f, 3.16178989f, 5.06180477f, 3.78434467f, -0.96689719f, 0.35965276f, 5.89967585f, 1.40294051f, 1.11952639f, 10.59778214f, 0.26739889f, -1.61297631f, 6.24801159f, -0.93914318f, -0.57812452f, 9.92604542f, -0.73025000f, -3.38530874f, 2.45646000f, -2.47949195f, 0.51638460f, 10.65636063f, 1.97816694f, -3.00407791f, 2.66914415f, -0.81951088f, -0.23316640f, 2.40737987f, -2.70007610f, 1.51531935f, 4.08860207f, -0.27552786f, -1.31721711f, 7.11568260f, -3.33498216f, -4.02545023f, 7.22675610f, -0.81690705f, -2.52689576f, 1.04016697f, -0.79291463f, -0.34875512f, 10.00498390f, -4.24167728f, 1.46162593f, 11.82569408f, -1.70359993f, -0.30161047f, 16.44085884f, -0.82253462f, -0.09435523f, 6.13080597f, -0.20259480f, 0.68308711f, 6.15663004f, -6.61776876f, 0.33295766f, 2.55449438f, -0.17819691f, -1.14892209f, 5.56776142f, 1.99279118f, 1.33035934f, 4.45823956f, 3.34916544f, -2.59905386f, 6.16164446f, -2.03881931f, -2.45273542f, 12.46793365f, -2.22743297f, 2.83738565f, 8.48628139f, -1.39347959f, -1.30867767f, 11.08041477f, -4.00363779f, 2.09183025f, 11.30395889f, -2.20504737f, 1.37426853f, 8.98735619f, 1.04676604f, -0.72757077f, 8.28050232f, -6.70741081f, -0.65798020f, 5.68592072f, -0.60760021f, 0.35854483f, 6.26852131f, 1.94100165f, 1.32112014f, 0.80987954f, -1.74617672f, -0.25434083f, 7.16045523f, 1.58884013f, -2.64847064f, 13.14820385f, 1.21393633f, -2.47258949f, 9.41650105f, -0.79384226f, 2.48954105f, 10.95629311f, 0.47723705f, 4.02126694f, 8.02593136f, -2.20726371f, -1.18794477f, 1.50836647f, 0.93118095f, -1.73513174f, 8.85493565f, -2.99670315f, -0.79055870f, 2.39473820f, 2.05046916f, -2.38055134f, 11.82299423f, 0.15609655f, 0.68744308f, 5.66401434f, -0.69281673f, 2.09855556f, 7.74626589f, -0.34283102f, 1.00542057f, 9.95838642f, 0.80161905f, 2.33455157f, 9.80057335f, -0.93561798f, 2.56991577f, 8.29711342f, 0.94213426f, 0.44209945f, 11.70259857f, 0.92710167f, 2.60957146f, 0.24971688f, -0.86529571f, 3.78628922f, 6.80884457f, -0.68178189f, 2.21103406f, 3.18895817f, 0.60283208f, -2.92716241f, 6.72060776f, -1.06625068f, 2.56543374f, 9.97404480f, 3.58080721f, -0.94936347f, 10.16736984f, -1.38464379f, 1.18191063f, 6.66179037f, -3.56115270f, 0.32329530f, 10.90870762f, 2.20638227f, 0.19653285f, 7.34650040f, -3.63859272f, -1.03027737f, 5.98829985f, -3.66606474f, -3.89746714f, 8.63469028f, 1.22569811f, 1.63240814f, 3.74385309f, 0.58243257f, -0.56981975f, 3.69260955f, 1.00979900f, -1.44030499f, 8.57058144f, -1.10648811f, 1.20474911f, 5.43133020f, -2.14822555f, -0.07928789f, 11.25825310f, 0.19645604f, -5.49546146f, 10.41917038f, -0.68178523f, -2.99639869f, 6.50054455f, 0.46488351f, -5.42328453f, 9.09500027f, -2.82107449f, 0.05601966f, 15.34610748f, -0.06820253f, 3.86699796f, 10.73316956f, -3.04795432f, -0.14702171f, 5.64813185f, 1.44028485f, -2.47596145f, 0.07280898f, -3.03187990f, -1.35183525f, 9.35835648f, 2.72966957f, 1.88199532f, 10.36187744f, -0.22834805f, -3.26738238f, 6.92025137f, -2.34061313f, 4.77379704f, 5.28559113f, -2.96323752f, -1.76186585f, 5.94436455f, 0.38647744f, -5.73869514f, 6.76849556f, 1.40892124f, -1.19068217f, 5.37919092f, -6.65328646f, 3.62782669f, 12.34744644f, 2.44762444f, -4.19242620f, 6.14906216f, 0.08121119f, 0.61355996f, 2.69666457f, -1.88962626f, -0.55314136f, 1.84937525f, 1.56048691f, 1.17460012f, 3.75674725f, 1.06198275f, -5.74625874f, 5.41645575f, -1.28946674f, -1.51689398f, 4.32400894f, -0.05222082f, -4.83948946f, 1.80747867f, 1.63144708f, -2.73887825f, 1.63975775f, -2.02163982f, -0.16210437f, 2.93518686f, 1.14427686f, -2.83246303f, 4.79283667f, 2.69697428f, -3.12678456f, -1.19225168f, -2.37022972f, -3.09429741f, 1.94225383f, -1.13747168f, -2.55048585f, 5.40242243f, 1.12777328f, 3.43713188f, 3.62658787f, -2.16878843f, 0.30164462f, 2.97407579f, -0.07275413f, -1.31149673f, 4.70066261f, -2.01323795f, 4.85255766f, 4.59128904f, 1.68084168f, 1.60336494f, 6.58138466f, -1.04759812f, 2.69906545f, 3.55769277f, -0.74327278f, 2.65819693f, 5.39528131f, 2.11248922f, -1.06446671f, 5.24546766f, -2.43146014f, 4.58907509f, 0.06521678f, -2.24503994f, 2.45722699f, 6.94863081f, 0.35258654f, 2.83396196f, 9.92525196f, -1.12225175f, -0.34365177f, 7.19116688f, -4.39813757f, 0.46517885f, 13.22028065f, -2.57483673f, -6.37226963f, 7.58046293f, -2.74600363f, 0.42231262f, 8.04881668f, 0.17289802f, -0.53447008f, 16.55157471f, -5.63614368f, 0.39288223f, 3.37079263f, 1.26484549f, -0.12820500f, 8.46440125f, -4.39304399f, 2.97676420f, 0.65650189f, 0.83158541f, -1.11556435f, 6.32885838f, -0.36087769f, 2.80724382f, 9.90292645f, 1.15936041f, 0.20947981f, 6.91249275f, -2.67404819f, 2.93782163f, 6.65656614f, -2.30828357f, 2.98214006f, 6.80611229f, -4.93821478f, -7.66555262f, 7.59763002f, -0.54159302f, 3.87403512f, 12.42607784f, 2.59284401f, -0.23375344f, 8.95293331f, -0.71807784f, 0.61873478f, 8.66713524f, 1.24289191f, -2.37835455f, 2.08071637f, -0.88315344f, -3.41891551f, 6.85245323f, 1.73007369f, 1.02169311f, 7.69170332f, -2.85411978f, 2.69790673f, 8.12906551f, -1.19351399f, -2.26442742f, 12.26104450f, -0.75579089f, -1.73274946f, 10.68729019f, 2.20655656f, -0.90522075f, 12.42165184f, -1.67929137f, 2.44851565f, 9.31565762f, -0.06645700f, 1.52762020f, 6.18427515f, -1.68882596f, 3.70261097f, 3.02252960f, -3.44125366f, -1.31575799f, 2.84617424f, -0.96849400f, -4.52356243f, 9.95027161f, 0.19966406f, -0.78874779f, 8.18595028f, -4.08300209f, 1.75126517f, 0.96418417f, -4.04913044f, -0.95200396f, 12.03637886f, -0.03041124f, 0.41642749f, 8.88267422f, -3.24985337f, -2.24919462f, 7.32566118f, 0.16964148f, -2.74123430f, 7.05264473f, -3.30191112f, 0.17163286f, 4.81851053f, -1.64463484f, -0.85933101f, 7.29276276f, 2.34066939f, -2.14860010f, 3.46148157f, -0.01782012f, 1.51504040f, 4.79304934f, 1.85281146f, -1.70663762f, 6.93470192f, -4.15440845f, -1.25983095f, 10.52491760f, 0.42930329f, -1.85146868f, 11.70042324f, -0.41704914f, 3.83796859f, 9.21148491f, -2.79719448f, 0.79470479f, 6.26926661f, -5.85230207f, 3.95105338f, 7.84790897f, -1.38680744f, -1.78099084f, 11.95235348f, -2.99841452f, -1.34507811f, 6.15714645f, -1.07552516f, -2.81228638f, 1.66234732f, -4.55166149f, -1.92601109f, 8.64634514f, -0.48158705f, 3.31595659f, 7.67371941f, 2.56964207f, 0.12107098f, 4.56467867f, -0.93541539f, 1.39432955f, 11.99714088f, 1.05353570f, -2.13099813f, 3.67617917f, 3.45895386f, 1.37365830f, 8.74344158f, -4.17585802f, 1.43908918f, 6.28764772f, 3.97346330f, -0.69144285f, 9.07983303f, -0.41635889f, -0.14965028f, 8.85469818f, 1.11306190f, 2.59440994f, 5.38982344f, -1.07948279f, 1.37252975f, 10.26984596f, -0.09318046f, 2.73104119f, 12.45902252f, -1.55446684f, -2.76124811f, 12.19395065f, -0.51846564f, 1.02764034f, 11.42673588f, -0.95940983f, -0.04781032f, 8.78379822f, -4.88957930f, 0.32534006f, 11.97696400f, -3.35108662f, 1.95104563f, 4.46915388f, -2.32061648f, 3.45230985f, 8.29983711f, 2.81034684f, -2.35529327f, 6.07801294f, -0.98105043f, -0.05359888f, 2.52291036f, -0.01986909f, -2.35321999f, 10.51954269f, 2.11145401f, 3.53506470f, 7.29093266f, 0.03721160f, -1.13496494f, 7.43886709f, -5.84201956f, 2.50796294f, 12.14647675f, 2.77490377f, -2.18896222f, 6.05641937f, 5.32617044f, 1.04221284f, 10.79106712f, -2.95749092f, -2.75414610f, 11.30037117f, -3.40654182f, -2.24673963f, 7.49126101f, 0.70811015f, -6.18003702f, 13.83951187f, -1.01204085f, 1.36298490f, -1.04451632f, 2.42435336f, -0.02346706f, -0.85528886f, 1.04731262f, 0.22192979f, 4.15708160f, 0.34933877f, 0.04814529f, 2.24107265f, 0.49676740f, -1.47752666f, 0.45040059f, -0.70471478f, -1.19759345f, 0.21711677f, 0.88461423f, -2.76830935f, 5.52066898f, 1.97664857f, -1.75381601f, 3.45877838f, 1.52617192f, -1.61350942f, 0.85337949f, 1.97610760f, -3.40310287f, 3.40319014f, -3.38691044f, -0.71319139f, 1.65463758f, -0.60680127f, -1.80700517f, 8.02592373f, 2.59627104f, 2.65895891f, 5.93043184f, -4.48425817f, 3.92670918f, 4.19496679f, -2.28286791f, 6.41634607f, 5.72330523f, 1.16269672f, -0.28753027f, 2.46342492f, 0.36693189f, 0.26712441f, 6.37652683f, -2.50139046f, 2.43923736f, 5.56310415f, 0.98065847f, 1.04267502f, 4.16403675f, -0.04966142f, 4.40897894f, 3.72905660f, -3.46129870f, 3.59962773f, 1.34830284f, -1.76661730f, 0.47943926f, 5.29946661f, -1.12711561f, 1.26970029f, 15.17655945f, -1.50971997f, 5.81345224f, 8.48562050f, -4.36049604f, 2.48144460f, 8.23780441f, -3.46030426f, -0.84656560f, 5.94946814f, 1.12747943f, -2.65683913f, 8.69085693f, 1.31309867f, -2.79958344f, 8.76840591f, -1.56444156f, 1.62710834f, 2.41177034f, -0.72804940f, 5.70619011f, 4.67169666f, -0.86167198f, -1.83803177f, 2.96346045f, 2.82692933f, -2.81557131f, 7.11113358f, -1.90071094f, 2.54244423f, 11.19284058f, -0.06298946f, -1.71517313f, 12.98388577f, 0.84510714f, 3.00816894f, 2.57200313f, 0.03899818f, -1.49330592f, 9.60099125f, -3.59513044f, -1.30045319f, 7.09241819f, -0.65233821f, -2.33627677f, 8.81366920f, 0.84154201f, 1.03312039f, 9.85289097f, 0.19351870f, 1.78496623f, 7.34631205f, -2.16530800f, -0.65016162f, 2.46842360f, 0.24016285f, -1.24308395f, 4.78175163f, -0.97682536f, 2.20942235f, 6.68382788f, 3.76786447f, -1.44454038f, 6.26453733f, -3.23575711f, -2.30137897f, 9.53092670f, -5.55222607f, 3.25999236f, 9.37559509f, 1.86339056f, -0.23551451f, 10.23400211f, 3.93031883f, -0.52629089f, 7.85724449f, -2.91549587f, 4.46612740f, 5.66530371f, -2.70820427f, 4.81359577f, 10.31247330f, 1.92230141f, 2.53931546f, 0.74986327f, 1.70303428f, 0.48063779f, 5.31099129f, -0.78976244f, 3.75864220f, 4.23051405f, 2.34042454f, -7.98193836f, 9.83987141f, -1.46722627f, 3.54497814f, 10.36455154f, -4.51249075f, 0.77715248f, 7.78694630f, -4.59989023f, -2.49585629f, 9.90296268f, 1.38535416f, 1.17441154f, 10.10452843f, -0.98628229f, 0.60194463f, 9.12639141f, -3.90754628f, 2.88526392f, 7.24123430f, -0.15283313f, -0.75728363f, -1.15116858f, -2.53791571f, 0.77229571f, 6.44114161f, 0.02646767f, 4.95463037f, 7.21066380f, 1.79384065f, 0.73250306f, 8.04447937f, 0.32576546f, -0.79447043f, 10.12717724f, 2.33392906f, 1.30716443f, 12.36073112f, -0.36694977f, -1.20438910f, 7.03105593f, 0.59557682f, 0.69267452f, 10.18113136f, 2.49944925f, -0.42229167f, 8.83143330f, -1.18805945f, -2.87509322f, 4.53596449f, 4.09732771f, -3.39088297f, -1.02536607f, 0.82119560f, -3.47302604f, 9.29991817f, 0.21001509f, 4.97036457f, 9.50018406f, 1.04420102f, 1.96560478f, 10.74769592f, -6.22709799f, 3.11690164f, 5.06759691f, -1.23724771f, -3.05831861f, 8.12925529f, -1.93435478f, -1.10151744f, 9.32263088f, -0.04249470f, -5.98547363f, 10.49398136f, 0.26400441f, -0.78915191f, 13.28219604f, 2.99276900f, 0.74853164f, 2.49364305f, -3.43529654f, 4.05278301f, 2.13498688f, -2.35444307f, -0.79900265f, 4.66968822f, -0.31095147f, 3.60674143f, 12.37222099f, -0.07855003f, -3.30292702f, 12.15215874f, 0.60886210f, 2.87075138f, 7.75271845f, 0.38044083f, 3.34402204f, 6.40583277f, -0.87888050f, 0.67438459f, 6.91080809f, 1.98332930f, -0.08303714f, 8.08630371f, -0.16772588f, -2.74058914f, 7.17253590f, -2.69122696f, 1.48173678f, 8.99470139f, -1.43302310f, -0.88651133f, 2.66944790f, -0.29186964f, 2.00838661f, 5.09587479f, -0.76676071f, -2.88322186f, 8.31110573f, -0.14550979f, -1.37726915f, 10.28355122f, -1.60575438f, -0.04118848f, 9.97510815f, 0.14440438f, -3.24632120f, 9.00034523f, 4.14319563f, -1.31023729f, 7.16950464f, -0.70428526f, 2.01559544f, 7.26155043f, 2.40816474f, 2.09847403f, 7.31264496f, -0.75401551f, 2.13392544f, 7.03648758f, 1.04036045f, -1.15636516f, 1.09634531f, -0.06340861f, -0.58107805f, -0.65623116f, 1.18972754f, -0.80717683f, 1.40118241f, -0.61932516f, -3.60596156f, 1.59904599f, -2.23774099f, -1.13721037f, 3.89620137f, -0.09115922f, -7.51356888f, 2.36975193f, -1.42520905f, -2.34173775f, 3.33830214f, -2.74016523f, -3.04115510f, 6.00119495f, -1.36084354f, -2.45065260f, 4.56992292f, -3.02825928f, -3.74182844f, 5.11069250f, -0.91531068f, -2.31385994f, 1.83399653f, 3.39370203f, -3.60886002f}); - auto exp = NDArrayFactory::create('c', {4, 4, 4, 3}, {7.97172260f, 0.06878620f, 2.27749538f, 7.29276514f, -0.14074677f, 0.65480286f, 5.70313978f, -0.06546132f, 0.35443667f, 3.70382833f, -0.84020567f, 0.63826996f, 8.60301399f, -0.38236514f, 1.55177069f, 7.37542057f, -0.99374938f, -0.29971302f, 8.84352493f, -0.67121059f, 0.43132120f, 4.78175592f, -1.25070143f, -1.91523600f, 6.03855371f, -0.00292124f, -1.11214364f, 7.90158176f, -0.57949901f, -0.96735370f, 7.81192017f, -0.53255427f, -0.48009714f, 3.16953635f, 0.08353355f, -1.54299748f, 3.74821687f, 1.69396687f, 0.72724354f, 5.42915201f, -1.13686812f, -0.71793109f, 5.78376389f, -0.72239977f, -0.60055625f, 2.53636408f, 0.56777251f, -2.07892323f, 6.08064651f, 0.68620735f, 2.54017019f, 5.65828180f, -0.68255502f, 1.47283304f, 6.10842514f, -0.39655915f, 0.28380761f, 1.96707797f, -1.98206317f, 0.94027776f, 4.71811438f, 0.32104525f, -0.92409706f, 8.34588146f, -1.05581069f, -0.55217457f, 9.58440876f, -0.96549922f, 0.45820439f, 5.65453672f, -2.50953507f, -0.71441835f, 8.03059578f, -0.21281289f, 0.92125505f, 9.26900673f, -0.35963219f, -0.70039093f, 8.59924412f, -1.22358346f, 0.81318003f, 3.85920119f, -0.01305223f, -1.09234154f, 6.33158875f, 1.28094780f, -1.48926139f, 4.94969177f, -0.77126902f, -1.97033751f, 5.64381838f, -0.16285487f, -1.31277227f, 2.39893222f, -1.32902908f, -1.39609122f, 6.47572327f, -0.45267010f, 1.55727172f, 6.70965624f, -1.68735468f, -0.05672536f, 7.25092363f, -0.64613032f, 0.67050058f, 3.60789680f, -2.05948973f, 2.22687531f, 8.15202713f, -0.70148355f, 1.28314006f, 8.14842319f, -1.88807654f, -1.04808438f, 8.45500565f, -0.76425624f, 0.94542569f, 4.56179953f, -0.28786001f, -2.04502511f, 8.46278095f, -0.31019822f, 0.07339200f, 9.34214592f, -0.61948007f, 0.52481830f, 8.32515621f, -1.52418160f, 0.49678251f, 5.11082315f, -1.09908783f, -0.52969611f, 5.27806664f, 0.88632923f, 0.66754371f, 4.75839233f, 0.48928693f, -0.68036932f, 6.56925392f, -0.02949905f, -2.99189186f, 4.46320581f, -0.64534980f, -0.29516968f, 8.60809517f, -1.13120568f, 3.41720533f, 5.84243155f, -1.24109328f, 0.89566326f, 5.99578333f, -0.42496428f, 2.07076764f, 3.17812920f, -0.81566459f, -0.14363396f, 6.55184317f, 0.39633346f, -0.43852386f, 8.70214558f, -2.24613595f, 0.30708700f, 8.73882294f, -0.53545928f, 1.54409575f, 4.49452257f, -0.16509305f, 0.19028664f, 8.24897003f, 0.44750381f, 2.15448594f, 8.97640514f, -0.77728152f, 0.57272542f, 9.03467560f, 0.47173575f, -1.10807717f, 3.30056310f, -0.43268481f, -0.41470885f, 3.53798294f, -0.08546703f, -2.16840744f, 6.18733406f, -0.17871059f, -2.59837723f, 5.94218683f, -1.02990067f, -0.49760687f, 3.76938033f, 0.86383581f, -1.91504073f}); - - nd4j::ops::avgpool2d op; - auto result = op.execute({&input}, {}, {3,3, 3,3, 0,0, 1,1,1, 0,1}, {}); - ASSERT_EQ(Status::OK(), result->status()); - - auto z = result->at(0); - - // z->printIndexedBuffer("z"); - // exp.printIndexedBuffer("e"); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete result; -} - ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, sconv2d_1) { float _expB[] = {10025.0f, 10350.0f, 10675.0f, 11000.0f, 11325.0f, 11650.0f, 13275.0f, 13600.0f, 13925.0f, 14250.0f, 14575.0f, 14900.0f, 16525.0f, 16850.0f, 17175.0f, 17500.0f, 17825.0f, 18150.0f, 19775.0f, 20100.0f, 20425.0f, 20750.0f, 21075.0f, 21400.0f, 23025.0f, 23350.0f, 23675.0f, 24000.0f, 24325.0f, 24650.0f, 26275.0f, 26600.0f, 26925.0f, 27250.0f, 27575.0f, 27900.0f, 38775.0f, 40350.0f, 41925.0f, 43500.0f, 45075.0f, 46650.0f, 54525.0f, 56100.0f, 57675.0f, 59250.0f, 60825.0f, 62400.0f, 70275.0f, 71850.0f, 73425.0f, 75000.0f, 76575.0f, 78150.0f, 86025.0f, 87600.0f, 89175.0f, 90750.0f, 92325.0f, 93900.0f, 101775.0f, 103350.0f, 104925.0f, 106500.0f, 108075.0f, 109650.0f, 117525.0f, 119100.0f, 120675.0f, 122250.0f, 123825.0f, 125400.0f, 67525.0f, 70350.0f, 73175.0f, 76000.0f, 78825.0f, 81650.0f, 95775.0f, 98600.0f, 101425.0f, 104250.0f, 107075.0f, 109900.0f, 124025.0f, 126850.0f, 129675.0f, 132500.0f, 135325.0f, 138150.0f, 152275.0f, 155100.0f, 157925.0f, 160750.0f, 163575.0f, 166400.0f, 180525.0f, 183350.0f, 186175.0f, 189000.0f, 191825.0f, 194650.0f, 208775.0f, 211600.0f, 214425.0f, 217250.0f, 220075.0f, 222900.0f, 119400.0f, 120350.0f, 121300.0f, 122250.0f, 123200.0f, 124150.0f, 128900.0f, 129850.0f, 130800.0f, 131750.0f, 132700.0f, 133650.0f, 138400.0f, 139350.0f, 140300.0f, 141250.0f, 142200.0f, 143150.0f, 147900.0f, 148850.0f, 149800.0f, 150750.0f, 151700.0f, 152650.0f, 157400.0f, 158350.0f, 159300.0f, 160250.0f, 161200.0f, 162150.0f, 166900.0f, 167850.0f, 168800.0f, 169750.0f, 170700.0f, 171650.0f, 273150.0f, 275350.0f, 277550.0f, 279750.0f, 281950.0f, 284150.0f, 295150.0f, 297350.0f, 299550.0f, 301750.0f, 303950.0f, 306150.0f, 317150.0f, 319350.0f, 321550.0f, 323750.0f, 325950.0f, 328150.0f, 339150.0f, 341350.0f, 343550.0f, 345750.0f, 347950.0f, 350150.0f, 361150.0f, 363350.0f, 365550.0f, 367750.0f, 369950.0f, 372150.0f, 383150.0f, 385350.0f, 387550.0f, 389750.0f, 391950.0f, 394150.0f, 426900.0f, 430350.0f, 433800.0f, 437250.0f, 440700.0f, 444150.0f, 461400.0f, 464850.0f, 468300.0f, 471750.0f, 475200.0f, 478650.0f, 495900.0f, 499350.0f, 502800.0f, 506250.0f, 509700.0f, 513150.0f, 530400.0f, 533850.0f, 537300.0f, 540750.0f, 544200.0f, 547650.0f, 564900.0f, 568350.0f, 571800.0f, 575250.0f, 578700.0f, 582150.0f, 599400.0f, 602850.0f, 606300.0f, 609750.0f, 613200.0f, 616650.0f, 75025.0f, 75350.0f, 75675.0f, 76000.0f, 76325.0f, 76650.0f, 78275.0f, 78600.0f, 78925.0f, 79250.0f, 79575.0f, 79900.0f, 81525.0f, 81850.0f, 82175.0f, 82500.0f, 82825.0f, 83150.0f, 84775.0f, 85100.0f, 85425.0f, 85750.0f, 86075.0f, 86400.0f, 88025.0f, 88350.0f, 88675.0f, 89000.0f, 89325.0f, 89650.0f, 91275.0f, 91600.0f, 91925.0f, 92250.0f, 92575.0f, 92900.0f, 353775.0f, 355350.0f, 356925.0f, 358500.0f, 360075.0f, 361650.0f, 369525.0f, 371100.0f, 372675.0f, 374250.0f, 375825.0f, 377400.0f, 385275.0f, 386850.0f, 388425.0f, 390000.0f, 391575.0f, 393150.0f, 401025.0f, 402600.0f, 404175.0f, 405750.0f, 407325.0f, 408900.0f, 416775.0f, 418350.0f, 419925.0f, 421500.0f, 423075.0f, 424650.0f, 432525.0f, 434100.0f, 435675.0f, 437250.0f, 438825.0f, 440400.0f, 632525.0f, 635350.0f, 638175.0f, 641000.0f, 643825.0f, 646650.0f, 660775.0f, 663600.0f, 666425.0f, 669250.0f, 672075.0f, 674900.0f, 689025.0f, 691850.0f, 694675.0f, 697500.0f, 700325.0f, 703150.0f, 717275.0f, 720100.0f, 722925.0f, 725750.0f, 728575.0f, 731400.0f, 745525.0f, 748350.0f, 751175.0f, 754000.0f, 756825.0f, 759650.0f, 773775.0f, 776600.0f, 779425.0f, 782250.0f, 785075.0f, 787900.0f, 309400.0f, 310350.0f, 311300.0f, 312250.0f, 313200.0f, 314150.0f, 318900.0f, 319850.0f, 320800.0f, 321750.0f, 322700.0f, 323650.0f, 328400.0f, 329350.0f, 330300.0f, 331250.0f, 332200.0f, 333150.0f, 337900.0f, 338850.0f, 339800.0f, 340750.0f, 341700.0f, 342650.0f, 347400.0f, 348350.0f, 349300.0f, 350250.0f, 351200.0f, 352150.0f, 356900.0f, 357850.0f, 358800.0f, 359750.0f, 360700.0f, 361650.0f, 713150.0f, 715350.0f, 717550.0f, 719750.0f, 721950.0f, 724150.0f, 735150.0f, 737350.0f, 739550.0f, 741750.0f, 743950.0f, 746150.0f, 757150.0f, 759350.0f, 761550.0f, 763750.0f, 765950.0f, 768150.0f, 779150.0f, 781350.0f, 783550.0f, 785750.0f, 787950.0f, 790150.0f, 801150.0f, 803350.0f, 805550.0f, 807750.0f, 809950.0f, 812150.0f, 823150.0f, 825350.0f, 827550.0f, 829750.0f, 831950.0f, 834150.0f, 1116900.0f, 1120350.0f, 1123800.0f, 1127250.0f, 1130700.0f, 1134150.0f, 1151400.0f, 1154850.0f, 1158300.0f, 1161750.0f, 1165200.0f, 1168650.0f, 1185900.0f, 1189350.0f, 1192800.0f, 1196250.0f, 1199700.0f, 1203150.0f, 1220400.0f, 1223850.0f, 1227300.0f, 1230750.0f, 1234200.0f, 1237650.0f, 1254900.0f, 1258350.0f, 1261800.0f, 1265250.0f, 1268700.0f, 1272150.0f, 1289400.0f, 1292850.0f, 1296300.0f, 1299750.0f, 1303200.0f, 1306650.0f,}; diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp index a16d9cfbd..989d316de 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp @@ -970,7 +970,6 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_6) { x.linspace(1); - nd4j::ops::maxpool2d op; auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 1, 1}); @@ -991,7 +990,6 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_7) { x.linspace(1); - nd4j::ops::maxpool2d op; auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1}); @@ -1012,7 +1010,6 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_8) { x.linspace(1); - nd4j::ops::maxpool2d op; auto result = op.execute({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 0}); @@ -1467,11 +1464,12 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test1) { auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.6f, 0.f, 2.7f, 2.8f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.9f, 3.f, 0.f, 3.1f, 3.2f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.3f, 3.4f, 0.f, 3.5f, 3.6f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 0.f, 3.9f, 4.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.1f, 4.2f, 0.f, 4.3f, 4.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.5f, 4.6f, 0.f, 4.7f, 4.8f}); + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.6f, 0.f, 2.7f, 2.8f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.9f, 3.f, 0.f, 3.1f, 3.2f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.3f, 3.4f, 0.f, 3.5f, 3.6f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 0.f, 3.9f, 4.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.1f, 4.2f, 0.f, 4.3f, 4.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.5f, 4.6f, 0.f, 4.7f, 4.8f}); + input.linspace(1.); gradO.linspace(0.1, 0.1); diff --git a/libnd4j/tests_cpu/layers_tests/CuDnnTests.cu b/libnd4j/tests_cpu/layers_tests/CuDnnTests.cu index 8809ad894..02e1040aa 100644 --- a/libnd4j/tests_cpu/layers_tests/CuDnnTests.cu +++ b/libnd4j/tests_cpu/layers_tests/CuDnnTests.cu @@ -57,6 +57,17 @@ TEST_F(CuDnnTests, helpers_includer) { nd4j::ops::platforms::PLATFORM_depthwise_conv2d_ENGINE_CUDA depthwise_conv2d; nd4j::ops::platforms::PLATFORM_depthwise_conv2d_bp_ENGINE_CUDA depthwise_conv2d_bp; nd4j::ops::platforms::PLATFORM_batchnorm_ENGINE_CUDA batchnorm; + nd4j::ops::platforms::PLATFORM_batchnorm_bp_ENGINE_CUDA batchnorm_bp; + nd4j::ops::platforms::PLATFORM_avgpool2d_ENGINE_CUDA avgpool2d; + nd4j::ops::platforms::PLATFORM_avgpool2d_bp_ENGINE_CUDA avgpool2d_bp; + nd4j::ops::platforms::PLATFORM_maxpool2d_ENGINE_CUDA maxpool2d; + nd4j::ops::platforms::PLATFORM_maxpool2d_bp_ENGINE_CUDA maxpool2d_bp; + nd4j::ops::platforms::PLATFORM_avgpool3dnew_ENGINE_CUDA avgpool3dnew; + nd4j::ops::platforms::PLATFORM_avgpool3dnew_bp_ENGINE_CUDA avgpool3dnew_bp; + nd4j::ops::platforms::PLATFORM_maxpool3dnew_ENGINE_CUDA maxpool3dnew; + nd4j::ops::platforms::PLATFORM_maxpool3dnew_bp_ENGINE_CUDA maxpool3dnew_bp; + + printer({&conv2d}); printer({&conv2d_bp}); @@ -65,6 +76,15 @@ TEST_F(CuDnnTests, helpers_includer) { printer({&depthwise_conv2d}); printer({&depthwise_conv2d_bp}); printer({&batchnorm}); + printer({&batchnorm_bp}); + printer({&avgpool2d}); + printer({&avgpool2d_bp}); + printer({&maxpool2d}); + printer({&maxpool2d_bp}); + printer({&avgpool3dnew}); + printer({&avgpool3dnew_bp}); + printer({&maxpool3dnew}); + printer({&maxpool3dnew_bp}); #endif } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index ee569a07c..18f58c2a1 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -25,6 +25,7 @@ #include #include #include +#include using namespace nd4j; @@ -2247,3 +2248,525 @@ TEST_F(DeclarableOpsTests13, batchnorm_test9) { delete results; } + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test1) { + + NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.1, 1.2, 1.3, 1.4}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,3,4}, {-0.000056, -0.000056, -0.000056, -0.000056, -0.000034, -0.000034, -0.000034, -0.000034, -0.000011, -0.000011, -0.000011, -0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000034, 0.000034, 0.000034, 0.000034, 0.000056, 0.000056, 0.000056, 0.000056}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {6.148104, 6.148104, 6.148105, 6.148105}, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {3.6, 4.5, 5.4, 6.3}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + variance.assign(0.46666667); + gamma.assign(1.2); + beta.assign(1.); // has no effect on gradient calculations + gradO.linspace(-0.9, 0.15); + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} + + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test2) { + + NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {3}, {1.05, 1.1, 1.15}, nd4j::DataType::FLOAT32); + NDArray variance('c', {3}, {0.5, 0.6, 0.7}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {3}, {1.2, 1.3, 1.4}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {3}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,3,4}, {-0.601415, -0.521226, -0.441037, -0.360849, -0.456306, -0.395465, -0.334624, -0.273784, 0.396631, 0.343747, + 0.290863, 0.237978, 0.360849, 0.441037, 0.521226, 0.601415, 0.273784, 0.334625, 0.395465, 0.456306, -0.237978, + -0.290863, -0.343746, -0.396631}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {3}, {5.81236 , 7.048771, 12.155388}, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {3}, {1.8, 6.6, 11.4}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + // beta.assign(1.); // has no effect on gradient calculations + gradO.linspace(-0.9, 0.15); + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test3) { + + NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4}, nd4j::DataType::FLOAT32); + NDArray variance('c', {2,1,4}, {0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {2,1,4}, {1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {2,1,4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,3,4}, {-0.577002, -0.744041, -0.850999, -0.922373, -0.000000, -0.000000, -0.000000, -0.000000, 0.577002, + 0.744041, 0.850999, 0.922373, -0.386037, -0.350205, -0.312047, -0.271737, -0.000000, -0.000000, + -0.000000, -0.000000, 0.386037, 0.350205, 0.312047, 0.271736}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {2,1,4}, {1.378844, 0.910144, 0.573706, 0.335408, 2.640487, 2.954985, 3.289431, 3.64234 }, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {2,1,4}, {-0.9 , -0.45, 0. , 0.45, 4.5 , 4.95, 5.4 , 5.85}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + // beta.assign(1.); // has no effect on gradient calculations + gradO.linspace(-0.9, 0.15); + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,0,2}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test4) { + + NDArray input ('c', {2,4}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,4}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,4}, {0.162923, -0.289673, 0.354174, -0.386151, -0.162923, 0.289673, -0.354174, 0.386151}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {1.442483, 0.950200, 0.569207, 0.314641}, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {-1.2, -0.9, -0.6, -0.3}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test5) { + +#if defined(HAVE_CUDNN) +return; +#endif + NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,4,2,2}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,4,2,2}, {-0.737512, -0.659880, -0.582247, -0.504614, 0.561404, 0.502309, 0.443214, 0.384118, -1.168243, + -1.045270, -0.922297, -0.799324, 1.899026, 1.699128, 1.499231, 1.299333, 0.504614, 0.582247, 0.659880, 0.737512, -0.384118, + -0.443214, -0.502308, -0.561404, 0.799324, 0.922297, 1.045270, 1.168243, -1.299334, -1.499231, -1.699129, -1.899026}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {11.073181, 12.585667, 17.708657, 24.313186}, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {4.2, 9. , 13.8, 18.6}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test6) { + +#if defined(HAVE_CUDNN) +return; +#endif + + NDArray input ('c', {2,2,2,4}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,2,2,4}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,2,2,4}, {-4.989124, 2.540357, -1.515022, 0.791769, -3.563660, 1.814540, -1.082159, 0.565549, -2.138196, 1.088724, -0.649295, + 0.339329, -0.712732, 0.362908, -0.216432, 0.113110, 0.712732, -0.362908, 0.216432, -0.113110, 2.138195, -1.088724, 0.649295, + -0.339330, 3.563660,-1.814540, 1.082159, -0.565549, 4.989125, -2.540356, 1.515022, -0.791770}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {20.364472, 17.856588, 16.949714, 15.903684}, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {9.6, 10.8, 12. , 13.2}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test7) { + +#if defined(HAVE_CUDNN) +return; +#endif + + NDArray input ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,2,2,2,4}, {-119.435059, 78.159744, -58.732986, 46.630123, -103.510391, 67.738441, -50.901920, 40.412773, -87.585716, 57.317142, + -43.070854, 34.195419, -71.661041, 46.895844, -35.239792, 27.978071, -55.736359, 36.474548, -27.408726, 21.760721, -39.811687, 26.053242, -19.577662, + 15.543370, -23.887009, 15.631950, -11.746595, 9.326023, -7.962326, 5.210644, -3.915531, 3.108671, 7.962341, -5.210655, 3.915535, -3.108677, 23.887032, + -15.631958, 11.746601, -9.326031, 39.811691, -26.053246, 19.577671, -15.543377, 55.736382, -36.474548, 27.408726, -21.760731, 71.661064, -46.895851, 35.239788, + -27.978077, 87.585732, -57.317154, 43.070866, -34.195431, 103.510384, -67.738464, 50.901920, -40.412777, 119.435097, -78.159744, 58.732998, -46.630131}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {282.38734 , 244.542027, 224.140995, 207.548793}, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {57.6, 60. , 62.4, 64.8}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); + + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,4}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + // dLdI->printBuffer(); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test8) { + +#if defined(HAVE_CUDNN) +return; +#endif + + NDArray input ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,4,2,2,2}, {-34.373802, -32.611046, -30.848286, -29.085529, -27.322769, -25.560009, -23.797251, -22.034491, 36.146996, 34.293301, + 32.439610, 30.585917, 28.732227, 26.878534, 25.024841, 23.171150, -42.876553, -40.677757, -38.478958, -36.280159, -34.081367, -31.882565, -29.683767, + -27.484968, 50.674446, 48.075760, 45.477066, 42.878380, 40.279686, 37.681000, 35.082310, 32.483616, 22.034489, 23.797249, 25.560009, 27.322765, 29.085526, + 30.848286, 32.611046, 34.373802, -23.171146, -25.024837, -26.878536, -28.732231, -30.585918, -32.439613, -34.293297, -36.146996, 27.484982, 29.683773, + 31.882572, 34.081364, 36.280178, 38.478970, 40.677776, 42.876560, -32.483627, -35.082329, -37.681023, -40.279701, -42.878403, -45.477081, -48.075775, -50.674484}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {134.490365, 179.785003, 248.933114, 330.087248}, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {32.4, 51.6, 70.8, 90.}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + // dLdI->printBuffer(); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test9) { + + NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,4,2,2}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,4,2,2}, {0.032378, 0.028967, 0.025558, 0.022147, -0.035056, -0.031364, -0.027669, -0.024006, 0.037742, 0.033766, 0.029791, 0.025818, + -0.040429, -0.036172, -0.031913, -0.027656, -0.022155, -0.025564, -0.028974, -0.032359, 0.023982, 0.027677, 0.031373, 0.035063, + -0.025822, -0.029794, -0.033770, -0.037747, 0.027653, 0.031913, 0.036168, 0.040426}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {9.685875, 9.685880, 9.685887, 9.685891}, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {4.2, 9. , 13.8, 18.6}, nd4j::DataType::FLOAT32); + + input.linspace(1,0.01); + gradO.linspace(-0.9, 0.15); + + // calculate mean and variance of input + PointersManager manager(input.getContext(), "DeclarableOpsTests13.batchnorm_bp_test9"); + std::vector dimensions = {0,2,3}; + int* dims = reinterpret_cast(manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int))); + input.reduceAlongDimension(nd4j::reduce::Mean, mean, dimensions); + NDArray::prepareSpecialUse({&variance}, {&input}); + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), dimensions); + NativeOpExecutioner::execSummaryStats(input.getContext(), 0,input.getBuffer(), input.getShapeInfo(),input.getSpecialBuffer(), input.getSpecialShapeInfo(),nullptr,variance.getBuffer(), variance.getShapeInfo(),variance.getSpecialBuffer(), variance.getSpecialShapeInfo(), dims, dimensions.size(),packX.platformShapeInfo(), packX.platformOffsets(),false); + manager.synchronize(); + NDArray::registerSpecialUse({&variance}, {&input}); + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test10) { + + NDArray input ('c', {2,2,2,4}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,2,2,4}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,2,2,4}, {0.032634, -0.035423, 0.038110, -0.040864, 0.023302, -0.025294, 0.027213, -0.029205, 0.013996, -0.015192, 0.016343, + -0.017519, 0.004664, -0.005062, 0.005445, -0.005833, -0.004668, 0.005067, -0.005452, 0.005824, -0.013974, 0.015171, + -0.016325, 0.017508, -0.023309, 0.025301, -0.027221, 0.029197, -0.032639, 0.035428, -0.038118, 0.040878}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {10.991656, 10.991631, 10.991643, 10.991632}, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {9.6, 10.8, 12., 13.2}, nd4j::DataType::FLOAT32); + + input.linspace(1,0.01); + gradO.linspace(-0.9, 0.15); + + // calculate mean and variance of input + PointersManager manager(input.getContext(), "DeclarableOpsTests13.batchnorm_bp_test9"); + std::vector dimensions = {0,1,2}; + int* dims = reinterpret_cast(manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int))); + input.reduceAlongDimension(nd4j::reduce::Mean, mean, dimensions); + NDArray::prepareSpecialUse({&variance}, {&input}); + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), dimensions); + NativeOpExecutioner::execSummaryStats(input.getContext(), 0,input.getBuffer(), input.getShapeInfo(),input.getSpecialBuffer(), input.getSpecialShapeInfo(),nullptr,variance.getBuffer(), variance.getShapeInfo(),variance.getSpecialBuffer(), variance.getSpecialShapeInfo(), dims, dimensions.size(),packX.platformShapeInfo(), packX.platformOffsets(),false); + manager.synchronize(); + NDArray::registerSpecialUse({&variance}, {&input}); + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test11) { + + NDArray input ('c', {2,3,4,5}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {1,3,4,5}, nd4j::DataType::FLOAT32); + NDArray variance('c', {1,3,4,5}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {1,3,4,5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {1,3,4,5}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,3,4,5}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,3,4,5}, {0.004981, 0.004818, 0.004652, 0.004483, 0.004319, 0.004153, 0.003985, 0.003832, 0.003661, 0.003505, 0.003340, 0.003171, 0.003001, 0.002837, + 0.002670, 0.002505, 0.002337, 0.002167, 0.002003, 0.001835, 0.001666, 0.001499, 0.001327, 0.001162, 0.000996, 0.000830, 0.000664, 0.000498, + 0.000332, 0.000166, -0.0, -0.000166, -0.000333, -0.000500, -0.000668, -0.000835, -0.001003, -0.001168, -0.001337, -0.001502, -0.001670, + -0.001838, -0.002003, -0.002172, -0.002330, -0.002499, -0.002669, -0.002832, -0.003002, -0.003162, -0.003332, -0.003495, -0.003665, -0.003821, + -0.004001, -0.004163, -0.004324, -0.004516, -0.004678, -0.004851, -0.004981, -0.004818, -0.004652, -0.004483, -0.004319, -0.004151, -0.003985, + -0.003836, -0.003661, -0.003505, -0.003338, -0.003171, -0.003004, -0.002837, -0.002670, -0.002503, -0.002337, -0.002170, -0.002003, -0.001835, + -0.001664, -0.001499, -0.001328, -0.001162, -0.000996, -0.000829, -0.000664, -0.000498, -0.000332, -0.000166, 0.0, 0.000166, 0.000334, + 0.000500, 0.000668, 0.000834, 0.001003, 0.001170, 0.001337, 0.001502, 0.001669, 0.001838, 0.002005, 0.002172, 0.002330, 0.002496, 0.002669, + 0.002836, 0.003002, 0.003162, 0.003328, 0.003495, 0.003670, 0.003828, 0.003992, 0.004158, 0.004324, 0.004522, 0.004689, 0.004843}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {1,3,4,5}, {8.999503, 8.999502, 8.999502, 8.999503, 8.999502, 8.999503, 8.999503, 8.999499, 8.999501, 8.999498, 8.999498, 8.999498, 8.999498, 8.999498, 8.999498, + 8.999498, 8.999498, 8.999498, 8.999498, 8.999499, 8.999501, 8.999500, 8.999503, 8.999503, 8.999503, 8.999504, 8.999503, 8.999503, 8.999504, 8.999503, + 8.999504, 8.999504, 8.999499, 8.999500, 8.999497, 8.999498, 8.999496, 8.999496, 8.999496, 8.999498, 8.999498, 8.999496, 8.999496, 8.999496, 8.999501, + 8.999501, 8.999499, 8.999499, 8.999499, 8.999501, 8.999501, 8.999501, 8.999499, 8.999500, 8.999501, 8.999501, 8.999501, 8.999495, 8.999495, 8.999497}, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {1,3,4,5}, {7.2, 7.5, 7.8, 8.1, 8.4, 8.7, 9.0, 9.3, 9.6, 9.9, 10.2, 10.5, 10.8, 11.1, 11.4, 11.7, 12.0, 12.3, 12.6, 12.9, 13.2, 13.5, 13.8, 14.1, 14.4, 14.7, 15.0, + 15.3, 15.6, 15.9, 16.2, 16.5, 16.8, 17.1, 17.4, 17.7, 18.0, 18.3, 18.6, 18.9, 19.2, 19.5, 19.8, 20.1, 20.4, 20.7, 21.0, 21.3, 21.6, 21.9, 22.2, 22.5, + 22.8, 23.1, 23.4, 23.7, 24.0, 24.3, 24.6, 24.9}, nd4j::DataType::FLOAT32); + + input.linspace(1,0.01); + gradO.linspace(-0.9, 0.15); + gamma.linspace(-3, 0.1); + + // calculate mean and variance of input + PointersManager manager(input.getContext(), "DeclarableOpsTests13.batchnorm_bp_test9"); + std::vector dimensions = {0}; + int* dims = reinterpret_cast(manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int))); + input.reduceAlongDimension(nd4j::reduce::Mean, mean, dimensions, true); + NDArray::prepareSpecialUse({&variance}, {&input}); + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), dimensions); + NativeOpExecutioner::execSummaryStats(input.getContext(), 0,input.getBuffer(), input.getShapeInfo(),input.getSpecialBuffer(), input.getSpecialShapeInfo(),nullptr,variance.getBuffer(), variance.getShapeInfo(),variance.getSpecialBuffer(), variance.getSpecialShapeInfo(), dims, dimensions.size(),packX.platformShapeInfo(), packX.platformOffsets(),false); + manager.synchronize(); + NDArray::registerSpecialUse({&variance}, {&input}); + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1, 1,2,3}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index 75db5989c..84dd5d732 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -72,71 +72,6 @@ TEST_F(DeclarableOpsTests15, Test_Half_assign_1) { ASSERT_EQ(10, x.sumNumber().e(0)); } -TEST_F(DeclarableOpsTests15, test_avgpooling_edge_1) { - int inOutH = 5;// 35; - int inOutW = 5;// 35; - int inOutC = 10;// 192; - - auto x = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC}); - x.linspace(1.0); - - nd4j::ops::avgpool2d op; - auto result = op.execute({&x}, {}, {3,3, 1,1, 0,0, 1,1, 1, 0, 1}); - ASSERT_EQ(Status::OK(), result->status()); - - auto z = result->at(0); - - int totalPadHeight = (inOutH - 1) * 1 + 3 - inOutH; - int padTop = totalPadHeight / 2; - int padBottom = totalPadHeight - totalPadHeight / 2; - - int k = 3; - - auto m = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC}); - auto c = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC}); - - for (int h = 0; h < inOutH; h++) { - for (int w = 0; w < inOutW; w++) { - int hFrom = h - padTop; - int wFrom = w - padBottom; - - int hTo = hFrom + k; - int wTo = wFrom + k; - - hFrom = nd4j::math::nd4j_max(0, hFrom); - wFrom = nd4j::math::nd4j_max(0, wFrom); - - hTo = nd4j::math::nd4j_min(inOutH, hTo); - wTo = nd4j::math::nd4j_min(inOutW, wTo); - - int idxOut[4]; - int idxIn[4]; - for (int ch = 0; ch < inOutC; ch++) { - idxOut[1] = h; - idxOut[2] = w; - idxOut[3] = ch; - idxIn[3] = ch; - - for (int kh = hFrom; kh < hTo; kh++) { - for (int kw = wFrom; kw < wTo; kw++) { - idxIn[1] = kh; - idxIn[2] = kw; - - auto inVal = x.e(0, kh, kw, ch); - m.p(0, h, w, ch, inVal + m.e(0, h, w, ch)); - c.p(0, h, w, ch, 1 + c.e(0, h, w, ch)); - } - } - } - } - } - m /= c; - - ASSERT_EQ(m, *z); - - delete result; -} - TEST_F(DeclarableOpsTests15, Test_standarize_1) { auto x = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); auto e = NDArrayFactory::create('c', {5}, {0.f, 0.f, 0.f, 0.f, 0.f}); @@ -1097,7 +1032,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_2) { ASSERT_EQ(Status::OK(), result->status()); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + delete result; } @@ -1106,7 +1041,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_3) { // rank 2 NDArray rgbs('c', { 3, 4 }, { -9.4, 9.9, 9.7, 9.0, 1.14, 1.01, 1.11, 9.6, 1.05, 10.0, 1.03, 10.22 }, nd4j::DataType::FLOAT32); NDArray expected('c', { 3, 4 }, { -2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992 }, nd4j::DataType::FLOAT32); - + nd4j::ops::rgb_to_yuv op; auto result = op.execute({ &rgbs }, {}, { 0 }); auto output = result->at(0); @@ -1170,7 +1105,7 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_7) { // rank 3 NDArray rgbs('f', { 2, 2, 3 }, { 1.7750e+01f,-7.1062e+01f, -1.0019e+02f, -2.3406e+01f,5.2094e+01f,9.5438e+01f, -6.7461e+00f,3.8562e+01f, 6.5078e+00f, 3.3562e+01f,-5.8844e+01f,2.2750e+01f }, nd4j::DataType::FLOAT32); NDArray expected('f', { 2,2,3 }, { 36.628319,38.600643, -40.624989,18.231001, -14.822637,-2.479566, -8.965780, 2.223851, -16.561626,- 96.205162,-52.255379, -36.527435 }, nd4j::DataType::FLOAT32); - + nd4j::ops::rgb_to_yuv op; auto result = op.execute({ &rgbs }, {}, {}); auto output = result->at(0); @@ -1210,7 +1145,7 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_2) { ASSERT_EQ(Status::OK(), result->status()); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + delete result; } @@ -1484,7 +1419,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test7) { auto Y = NDArrayFactory::create(2.f); NDArray x('c', { 2, 2, 2 }, nd4j::DataType::FLOAT32); NDArray dLdzC('c', { 2, 2, 2 }, nd4j::DataType::FLOAT32); - + dLdzC.linspace(0.1, 0.1); x = 4.f; diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp index dacfac127..2ef86710a 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp @@ -883,22 +883,6 @@ TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_6) { delete result; } -TEST_F(DeclarableOpsTests3, Test_AvgPool_1) { - auto x= NDArrayFactory::create('c', {2, 10, 10, 3}); - x.linspace(1); - - nd4j::ops::avgpool2d op; - // kY kX sY sX pY pX dY dX M P - auto result = op.execute({&x}, {}, {3, 3, 3, 3, 0, 0, 1, 1, 1, 0, 1}); - // 0 1 2 3 4 5 6 7 8 9 10 - auto z = result->at(0); - - // z->printShapeInfo("z shape"); - // z->printIndexedBuffer("z buffr"); - - delete result; -} - TEST_F(DeclarableOpsTests3, Test_ReverseDivide_1) { auto x= NDArrayFactory::create('c', {1, 3}, {2, 2, 2}); auto y= NDArrayFactory::create('c', {1, 3}, {4, 6, 8}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp index 1155c72de..9460a053f 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp @@ -56,7 +56,8 @@ public: typedef ::testing::Types TestingTypes; TYPED_TEST_CASE(TypedDeclarableOpsTests4, TestingTypes); -TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_1) { +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_1) { auto x = NDArrayFactory::create('c', {2, 4, 4, 2}); auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, {6.f, 7.f, 10.f, 11.f, 22.f, 23.f, 26.f, 27.f, 38.f, 39.f, 42.f, 43.f, 54.f, 55.f, 58.f, 59.f}); @@ -75,8 +76,8 @@ TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_1) { delete result; } - -TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_2) { +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_2) { auto x = NDArrayFactory::create('c', {2, 4, 4, 2}); auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, {6.f, 7.f, 10.f, 11.f, 22.f, 23.f, 26.f, 27.f, 38.f, 39.f, 42.f, 43.f, 54.f, 55.f, 58.f, 59.f}); @@ -96,7 +97,8 @@ TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_2) { delete result; } -TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_5) { +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_3) { auto x = NDArrayFactory::create('c', {2, 5, 5, 2}); auto exp = NDArrayFactory::create('c', {2, 3, 3, 2}, {7.f, 8.f, 11.f, 12.f, 14.f, 15.f, 27.f, 28.f, 31.f, 32.f, 34.f, 35.f, 42.f, 43.f, 46.f, 47.f, 49.f, 50.f, 57.f, 58.f, 61.f, 62.f, 64.f, 65.f, 77.f, 78.f, 81.f, 82.f, 84.f, 85.f, 92.f, 93.f, 96.f, 97.f, 99.f, 100.f,}); @@ -116,7 +118,8 @@ TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_5) { delete result; } -TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_6) { +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_4) { auto x = NDArrayFactory::create('c', {2, 5, 5, 2}); auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, {7.f, 8.f, 11.f, 12.f, 27.f, 28.f, 31.f, 32.f, 57.f, 58.f, 61.f, 62.f, 77.f, 78.f, 81.f, 82.f}); @@ -135,8 +138,8 @@ TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_6) { delete result; } - -TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_8) { +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_5) { auto x = NDArrayFactory::create('c', {2, 2, 5, 5}); auto exp = NDArrayFactory::create('c', {2, 2, 3, 3}, {1.f, 2.5f, 4.5f, 8.5f, 10.f, 12.f, 18.5f, 20.f, 22.f, 26.f, 27.5f, 29.5f, 33.5f, 35.f, 37.f, 43.5f, 45.f, 47.f, 51.f, 52.5f, 54.5f, 58.5f, 60.f, 62.f, 68.5f, 70.f, 72.f, 76.f, 77.5f, 79.5f, 83.5f, 85.f, 87.f, 93.5f, 95.f, 97.f}); @@ -156,8 +159,8 @@ TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_8) { delete result; } - -TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_9) { +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_6) { auto x = NDArrayFactory::create('c', {2, 2, 5, 5}); auto exp = NDArrayFactory::create('c', {2, 2, 3, 3}, {0.25f, 1.25f, 2.25f, 4.25f, 10.f, 12.f, 9.25f, 20.f, 22.f, 6.5f, 13.75f, 14.75, 16.75f, 35.f, 37.f, 21.75f, 45.f, 47.f, 12.75f, 26.25f, 27.25f, 29.25f, 60.f, 62.f, 34.25f, 70.f, 72.f, 19.f, 38.75f, 39.75f, 41.75f, 85.f, 87.f, 46.75f, 95.f, 97.f}); @@ -177,8 +180,8 @@ TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_9) { delete result; } - -TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_10) { +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_7) { auto x = NDArrayFactory::create('c', {2, 2, 5, 5}); auto exp = NDArrayFactory::create('c', {2, 2, 3, 3}, {4.f, 6.f, 7.5f, 14.f, 16.f, 17.5f, 21.5f, 23.5f, 25.f, 29.f, 31.f, 32.5f, 39.f, 41.f, 42.5f, 46.5f, 48.5f, 50.f, 54.f, 56.f, 57.5f, 64.f, 66.f, 67.5f, 71.5f, 73.5f, 75.f, 79.f, 81.f, 82.5f, 89.f, 91.f, 92.5f, 96.5f, 98.5f, 100.f}); @@ -198,8 +201,8 @@ TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_10) { delete result; } - -TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_11) { +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_8) { auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); auto exp = NDArrayFactory::create('c', {1, 1, 2, 2}, {3.f, 4.f, 6.f, 7.f}); @@ -219,7 +222,8 @@ TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_11) { delete result; } -TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_12) { +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_9) { auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); auto exp = NDArrayFactory::create('c', {1, 1, 3, 3}, {3.f, 4.f, 4.5f, 6.f, 7.f, 7.5f, 7.5f, 8.5f, 9.f}); @@ -242,7 +246,139 @@ TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_12) { delete result; } +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_10) { + auto input = NDArrayFactory::create('c', {4, 10, 10, 3}, {9.37125111f, 2.20166993f, 2.91434479f, 5.43639755f, -2.10573769f, 4.08528662f, 5.86908436f, -4.46203756f, 2.21057916f, 5.35849190f, 0.01394637f, 4.40566349f, 7.07982206f, -0.09633455f, 2.42429352f, 3.97301817f, -1.89553940f, 1.99690318f, 6.33141708f, 0.55401880f, 1.70707977f, + 5.55204201f, -0.03513752f, 1.60011971f, 2.62700319f, -2.74582434f, 3.06697464f, 1.06277943f, -1.16075921f, -0.78095782f, 9.72352791f, -1.22686064f, 1.99644792f, 7.35571337f, 1.40607321f, 0.11390255f, 9.53334427f, 2.28303599f, -1.66728830f, 6.16678810f, -0.04532295f, -1.97708666f, 9.74906158f, 1.46223176f, -1.46734393f, 4.30761862f, + -1.23790228f, 1.24823606f, 6.13938427f, -3.83689475f, -1.19625473f, 7.91535568f, 6.05868721f, -3.22946382f, 8.81633949f, -0.19967777f, 0.66053957f, 2.30919123f, 0.74543846f, -0.39347672f, 11.11058044f, 0.53720862f, 1.52645731f, 5.70012379f, -1.15213466f, 1.16451406f, 7.00526333f, 1.57362783f, -2.44384766f, 5.54213285f, -1.98828590f, + -0.70483637f, 7.88281822f, -3.59875536f, 0.80745387f, 13.41578484f, -1.55507684f, -0.65855008f, 9.32583523f, -0.14544789f, 0.73436141f, 3.61176538f, -1.71268058f, -2.58490300f, 9.09280205f, -3.27405524f, -2.04569697f, 4.44761324f, -0.62955856f, -2.61917663f, 8.04890442f, 0.54579324f, 0.85929775f, 9.82259560f, -1.93825579f, 0.77703512f, + 4.67090321f, -4.79267597f, -2.38906908f, 9.31265545f, 0.96026313f, -1.14109385f, 11.54231834f, -0.01417295f, -0.39500344f, 8.49191666f, 0.55300158f, 2.79490185f, 6.92466164f, 1.72254205f, 2.82222271f, 8.83112717f, 2.95033407f, 2.18054962f, 6.73509789f, -2.22272944f, 0.51127720f, -1.04563558f, 2.15747333f, -2.30959272f, 9.55441570f, + 1.50396204f, 1.77370787f, 7.38146257f, -1.79076433f, 3.20961165f, 7.18864202f, 2.91217351f, 0.43018937f, 7.11078024f, -1.17386127f, -0.16817921f, 6.12327290f, -2.82205725f, 3.30696845f, 13.51291752f, -1.30856836f, -2.38332748f, 11.09487438f, -1.47190213f, -0.53050828f, 4.38285351f, -5.07309771f, 1.50714362f, 5.72274446f, -2.85825086f, + -0.89673209f, 3.73791552f, -0.67708802f, -4.13149452f, -0.00671843f, -0.26566532f, 0.32961160f, 7.14501762f, -1.41608179f, -4.96590328f, 12.26205540f, -0.65158135f, -0.88641000f, 6.95777559f, -0.79058206f, -0.10260171f, 7.87169170f, 1.35921454f, 1.11759663f, 5.46187401f, -2.57214499f, 2.48484039f, 4.04043484f, -2.07137156f, -1.42709637f, + 9.25487137f, -0.12605135f, -2.66949964f, 2.89412403f, 0.74451172f, -2.96250391f, 3.99258423f, 0.27084303f, 0.32213116f, 5.42332172f, -0.44414216f, 1.70881832f, 6.69346905f, 0.53058422f, -4.73146200f, 4.22051668f, 2.24834967f, 0.66996074f, 4.30173683f, 0.11849818f, -4.07520294f, 8.27318478f, -2.54398274f, -2.86705542f, 10.11775303f, + -0.99382895f, 0.65881538f, 7.93556786f, -1.27934420f, -1.69343162f, 9.68042564f, -1.02609646f, -1.18189347f, 5.75370646f, -1.67888868f, -4.48871994f, 4.79537392f, -0.79212248f, -0.19855022f, 6.15060997f, -0.01081491f, 3.64454579f, 10.82562447f, 1.58859253f, -2.65847278f, 8.60093212f, -1.59196103f, 0.07635692f, 11.76175690f, -1.17453325f, + 0.10122013f, 6.86458445f, -2.18891335f, -2.74004745f, 8.07066154f, 0.71818852f, -2.03035975f, 6.31053686f, 0.51509416f, 1.39789927f, 9.43515587f, 2.04256630f, 0.13985133f, 4.65010691f, 2.40911126f, -0.36255789f, -3.06867862f, -0.45225358f, -1.56778407f, 6.05917358f, -1.09891272f, 1.77184200f, 6.46248102f, 0.96042323f, -0.24346280f, + 4.63436460f, -4.69907761f, 1.25187206f, 11.46173859f, -2.21917558f, 1.28007793f, 6.92173195f, 2.11268163f, -3.47389889f, 5.08722782f, -3.03950930f, -4.17154264f, 11.30568314f, 0.80361372f, 2.53214502f, 7.18707085f, -4.49114513f, 2.85449266f, 10.14906883f, -0.31974933f, -0.84472644f, -0.52459574f, 0.12921631f, -1.81390119f, 2.76170087f, 1.03982210f, 2.91744232f, -0.29048753f, 5.87453508f, -1.53684759f, 1.85800636f, -0.91404629f, 1.28954852f, 5.11354685f, -2.47475505f, -1.33179152f, 2.58552408f, 1.37316465f, -3.32339454f, 1.54122913f, 3.24953628f, -0.29758382f, 2.82391763f, -1.51142192f, -1.22699404f, 6.75745535f, 0.65452754f, -3.29385471f, 2.06008053f, 2.53172946f, -4.23532820f, -1.53909743f, -0.07010663f, -1.42173731f, 7.29031610f, -0.18448229f, 4.59496164f, 6.73027277f, 0.73441899f, 0.14426160f, 4.14915276f, -2.97010231f, 6.05851364f, 4.95218086f, -2.39145470f, 2.40494704f, 2.10288811f, 0.53503096f, 1.44511235f, 6.66344261f, -3.05803776f, 7.21418667f, 3.30303526f, -0.24163735f, 3.47409391f, 3.64520788f, 2.15189481f, -3.11243272f, 3.62310791f, 0.37379482f, 0.40865007f, -0.83132005f, -4.78246069f, 2.07030797f, 6.51765442f, 3.16178989f, 5.06180477f, 3.78434467f, -0.96689719f, 0.35965276f, 5.89967585f, 1.40294051f, 1.11952639f, 10.59778214f, 0.26739889f, -1.61297631f, 6.24801159f, -0.93914318f, -0.57812452f, 9.92604542f, -0.73025000f, -3.38530874f, 2.45646000f, -2.47949195f, 0.51638460f, 10.65636063f, 1.97816694f, -3.00407791f, 2.66914415f, -0.81951088f, -0.23316640f, 2.40737987f, -2.70007610f, 1.51531935f, 4.08860207f, -0.27552786f, -1.31721711f, 7.11568260f, -3.33498216f, -4.02545023f, 7.22675610f, -0.81690705f, -2.52689576f, 1.04016697f, -0.79291463f, -0.34875512f, 10.00498390f, -4.24167728f, 1.46162593f, 11.82569408f, -1.70359993f, -0.30161047f, 16.44085884f, -0.82253462f, -0.09435523f, 6.13080597f, -0.20259480f, 0.68308711f, 6.15663004f, -6.61776876f, 0.33295766f, 2.55449438f, -0.17819691f, -1.14892209f, 5.56776142f, 1.99279118f, 1.33035934f, 4.45823956f, 3.34916544f, -2.59905386f, 6.16164446f, -2.03881931f, -2.45273542f, 12.46793365f, -2.22743297f, 2.83738565f, 8.48628139f, -1.39347959f, -1.30867767f, 11.08041477f, -4.00363779f, 2.09183025f, 11.30395889f, -2.20504737f, 1.37426853f, 8.98735619f, 1.04676604f, -0.72757077f, 8.28050232f, -6.70741081f, -0.65798020f, 5.68592072f, -0.60760021f, 0.35854483f, 6.26852131f, 1.94100165f, 1.32112014f, 0.80987954f, -1.74617672f, -0.25434083f, 7.16045523f, 1.58884013f, -2.64847064f, 13.14820385f, 1.21393633f, -2.47258949f, 9.41650105f, -0.79384226f, 2.48954105f, 10.95629311f, 0.47723705f, 4.02126694f, 8.02593136f, -2.20726371f, -1.18794477f, 1.50836647f, 0.93118095f, -1.73513174f, 8.85493565f, -2.99670315f, -0.79055870f, 2.39473820f, 2.05046916f, -2.38055134f, 11.82299423f, 0.15609655f, 0.68744308f, 5.66401434f, -0.69281673f, 2.09855556f, 7.74626589f, -0.34283102f, 1.00542057f, 9.95838642f, 0.80161905f, 2.33455157f, 9.80057335f, -0.93561798f, 2.56991577f, 8.29711342f, 0.94213426f, 0.44209945f, 11.70259857f, 0.92710167f, 2.60957146f, 0.24971688f, -0.86529571f, 3.78628922f, 6.80884457f, -0.68178189f, 2.21103406f, 3.18895817f, 0.60283208f, -2.92716241f, 6.72060776f, -1.06625068f, 2.56543374f, 9.97404480f, 3.58080721f, -0.94936347f, 10.16736984f, -1.38464379f, 1.18191063f, 6.66179037f, -3.56115270f, 0.32329530f, 10.90870762f, 2.20638227f, 0.19653285f, 7.34650040f, -3.63859272f, -1.03027737f, 5.98829985f, -3.66606474f, -3.89746714f, 8.63469028f, 1.22569811f, 1.63240814f, 3.74385309f, 0.58243257f, -0.56981975f, 3.69260955f, 1.00979900f, -1.44030499f, 8.57058144f, -1.10648811f, 1.20474911f, 5.43133020f, -2.14822555f, -0.07928789f, 11.25825310f, 0.19645604f, -5.49546146f, 10.41917038f, -0.68178523f, -2.99639869f, 6.50054455f, 0.46488351f, -5.42328453f, 9.09500027f, -2.82107449f, 0.05601966f, 15.34610748f, -0.06820253f, 3.86699796f, 10.73316956f, -3.04795432f, -0.14702171f, 5.64813185f, 1.44028485f, -2.47596145f, 0.07280898f, -3.03187990f, -1.35183525f, 9.35835648f, 2.72966957f, 1.88199532f, 10.36187744f, -0.22834805f, -3.26738238f, 6.92025137f, -2.34061313f, 4.77379704f, 5.28559113f, -2.96323752f, -1.76186585f, 5.94436455f, 0.38647744f, -5.73869514f, 6.76849556f, 1.40892124f, -1.19068217f, 5.37919092f, -6.65328646f, 3.62782669f, 12.34744644f, 2.44762444f, -4.19242620f, 6.14906216f, 0.08121119f, 0.61355996f, 2.69666457f, -1.88962626f, -0.55314136f, 1.84937525f, 1.56048691f, 1.17460012f, 3.75674725f, 1.06198275f, -5.74625874f, 5.41645575f, -1.28946674f, -1.51689398f, 4.32400894f, -0.05222082f, -4.83948946f, 1.80747867f, 1.63144708f, -2.73887825f, 1.63975775f, -2.02163982f, -0.16210437f, 2.93518686f, 1.14427686f, -2.83246303f, 4.79283667f, 2.69697428f, -3.12678456f, -1.19225168f, -2.37022972f, -3.09429741f, 1.94225383f, -1.13747168f, -2.55048585f, 5.40242243f, 1.12777328f, 3.43713188f, 3.62658787f, -2.16878843f, 0.30164462f, 2.97407579f, -0.07275413f, -1.31149673f, 4.70066261f, -2.01323795f, 4.85255766f, 4.59128904f, 1.68084168f, 1.60336494f, 6.58138466f, -1.04759812f, 2.69906545f, 3.55769277f, -0.74327278f, 2.65819693f, 5.39528131f, 2.11248922f, -1.06446671f, 5.24546766f, -2.43146014f, 4.58907509f, 0.06521678f, -2.24503994f, 2.45722699f, 6.94863081f, 0.35258654f, 2.83396196f, 9.92525196f, -1.12225175f, -0.34365177f, 7.19116688f, -4.39813757f, 0.46517885f, 13.22028065f, -2.57483673f, -6.37226963f, 7.58046293f, -2.74600363f, 0.42231262f, 8.04881668f, 0.17289802f, -0.53447008f, 16.55157471f, -5.63614368f, 0.39288223f, 3.37079263f, 1.26484549f, -0.12820500f, 8.46440125f, -4.39304399f, 2.97676420f, 0.65650189f, 0.83158541f, -1.11556435f, 6.32885838f, -0.36087769f, 2.80724382f, 9.90292645f, 1.15936041f, 0.20947981f, 6.91249275f, -2.67404819f, 2.93782163f, 6.65656614f, -2.30828357f, 2.98214006f, 6.80611229f, -4.93821478f, -7.66555262f, 7.59763002f, -0.54159302f, 3.87403512f, 12.42607784f, 2.59284401f, -0.23375344f, 8.95293331f, -0.71807784f, 0.61873478f, 8.66713524f, 1.24289191f, -2.37835455f, 2.08071637f, -0.88315344f, -3.41891551f, 6.85245323f, 1.73007369f, 1.02169311f, 7.69170332f, -2.85411978f, 2.69790673f, 8.12906551f, -1.19351399f, -2.26442742f, 12.26104450f, -0.75579089f, -1.73274946f, 10.68729019f, 2.20655656f, -0.90522075f, 12.42165184f, -1.67929137f, 2.44851565f, 9.31565762f, -0.06645700f, 1.52762020f, 6.18427515f, -1.68882596f, 3.70261097f, 3.02252960f, -3.44125366f, -1.31575799f, 2.84617424f, -0.96849400f, -4.52356243f, 9.95027161f, 0.19966406f, -0.78874779f, 8.18595028f, -4.08300209f, 1.75126517f, 0.96418417f, -4.04913044f, -0.95200396f, 12.03637886f, -0.03041124f, 0.41642749f, 8.88267422f, -3.24985337f, -2.24919462f, 7.32566118f, 0.16964148f, -2.74123430f, 7.05264473f, -3.30191112f, 0.17163286f, 4.81851053f, -1.64463484f, -0.85933101f, 7.29276276f, 2.34066939f, -2.14860010f, 3.46148157f, -0.01782012f, 1.51504040f, 4.79304934f, 1.85281146f, -1.70663762f, 6.93470192f, -4.15440845f, -1.25983095f, 10.52491760f, 0.42930329f, -1.85146868f, 11.70042324f, -0.41704914f, 3.83796859f, 9.21148491f, -2.79719448f, 0.79470479f, 6.26926661f, -5.85230207f, 3.95105338f, 7.84790897f, -1.38680744f, -1.78099084f, 11.95235348f, -2.99841452f, -1.34507811f, 6.15714645f, -1.07552516f, -2.81228638f, 1.66234732f, -4.55166149f, -1.92601109f, 8.64634514f, -0.48158705f, 3.31595659f, 7.67371941f, 2.56964207f, 0.12107098f, 4.56467867f, -0.93541539f, 1.39432955f, 11.99714088f, 1.05353570f, -2.13099813f, 3.67617917f, 3.45895386f, 1.37365830f, 8.74344158f, -4.17585802f, 1.43908918f, 6.28764772f, 3.97346330f, -0.69144285f, 9.07983303f, -0.41635889f, -0.14965028f, 8.85469818f, 1.11306190f, 2.59440994f, 5.38982344f, -1.07948279f, 1.37252975f, 10.26984596f, -0.09318046f, 2.73104119f, 12.45902252f, -1.55446684f, -2.76124811f, 12.19395065f, -0.51846564f, 1.02764034f, 11.42673588f, -0.95940983f, -0.04781032f, 8.78379822f, -4.88957930f, 0.32534006f, 11.97696400f, -3.35108662f, 1.95104563f, 4.46915388f, -2.32061648f, 3.45230985f, 8.29983711f, 2.81034684f, -2.35529327f, 6.07801294f, -0.98105043f, -0.05359888f, 2.52291036f, -0.01986909f, -2.35321999f, 10.51954269f, 2.11145401f, 3.53506470f, 7.29093266f, 0.03721160f, -1.13496494f, 7.43886709f, -5.84201956f, 2.50796294f, 12.14647675f, 2.77490377f, -2.18896222f, 6.05641937f, 5.32617044f, 1.04221284f, 10.79106712f, -2.95749092f, -2.75414610f, 11.30037117f, -3.40654182f, -2.24673963f, 7.49126101f, 0.70811015f, -6.18003702f, 13.83951187f, -1.01204085f, 1.36298490f, -1.04451632f, 2.42435336f, -0.02346706f, -0.85528886f, 1.04731262f, 0.22192979f, 4.15708160f, 0.34933877f, 0.04814529f, 2.24107265f, 0.49676740f, -1.47752666f, 0.45040059f, -0.70471478f, -1.19759345f, 0.21711677f, 0.88461423f, -2.76830935f, 5.52066898f, 1.97664857f, -1.75381601f, 3.45877838f, 1.52617192f, -1.61350942f, 0.85337949f, 1.97610760f, -3.40310287f, 3.40319014f, -3.38691044f, -0.71319139f, 1.65463758f, -0.60680127f, -1.80700517f, 8.02592373f, 2.59627104f, 2.65895891f, 5.93043184f, -4.48425817f, 3.92670918f, 4.19496679f, -2.28286791f, 6.41634607f, 5.72330523f, 1.16269672f, -0.28753027f, 2.46342492f, 0.36693189f, 0.26712441f, 6.37652683f, -2.50139046f, 2.43923736f, 5.56310415f, 0.98065847f, 1.04267502f, 4.16403675f, -0.04966142f, 4.40897894f, 3.72905660f, -3.46129870f, 3.59962773f, 1.34830284f, -1.76661730f, 0.47943926f, 5.29946661f, -1.12711561f, 1.26970029f, 15.17655945f, -1.50971997f, 5.81345224f, 8.48562050f, -4.36049604f, 2.48144460f, 8.23780441f, -3.46030426f, -0.84656560f, 5.94946814f, 1.12747943f, -2.65683913f, 8.69085693f, 1.31309867f, -2.79958344f, 8.76840591f, -1.56444156f, 1.62710834f, 2.41177034f, -0.72804940f, 5.70619011f, 4.67169666f, -0.86167198f, -1.83803177f, 2.96346045f, 2.82692933f, -2.81557131f, 7.11113358f, -1.90071094f, 2.54244423f, 11.19284058f, -0.06298946f, -1.71517313f, 12.98388577f, 0.84510714f, 3.00816894f, 2.57200313f, 0.03899818f, -1.49330592f, 9.60099125f, -3.59513044f, -1.30045319f, 7.09241819f, -0.65233821f, -2.33627677f, 8.81366920f, 0.84154201f, 1.03312039f, 9.85289097f, 0.19351870f, 1.78496623f, 7.34631205f, -2.16530800f, -0.65016162f, 2.46842360f, 0.24016285f, -1.24308395f, 4.78175163f, -0.97682536f, 2.20942235f, 6.68382788f, 3.76786447f, -1.44454038f, 6.26453733f, -3.23575711f, -2.30137897f, 9.53092670f, -5.55222607f, 3.25999236f, 9.37559509f, 1.86339056f, -0.23551451f, 10.23400211f, 3.93031883f, -0.52629089f, 7.85724449f, -2.91549587f, 4.46612740f, 5.66530371f, -2.70820427f, 4.81359577f, 10.31247330f, 1.92230141f, 2.53931546f, 0.74986327f, 1.70303428f, 0.48063779f, 5.31099129f, -0.78976244f, 3.75864220f, 4.23051405f, 2.34042454f, -7.98193836f, 9.83987141f, -1.46722627f, 3.54497814f, 10.36455154f, -4.51249075f, 0.77715248f, 7.78694630f, -4.59989023f, -2.49585629f, 9.90296268f, 1.38535416f, 1.17441154f, 10.10452843f, -0.98628229f, 0.60194463f, 9.12639141f, -3.90754628f, 2.88526392f, 7.24123430f, -0.15283313f, -0.75728363f, -1.15116858f, -2.53791571f, 0.77229571f, 6.44114161f, 0.02646767f, 4.95463037f, 7.21066380f, 1.79384065f, 0.73250306f, 8.04447937f, 0.32576546f, -0.79447043f, 10.12717724f, 2.33392906f, 1.30716443f, 12.36073112f, -0.36694977f, -1.20438910f, 7.03105593f, 0.59557682f, 0.69267452f, 10.18113136f, 2.49944925f, -0.42229167f, 8.83143330f, -1.18805945f, -2.87509322f, 4.53596449f, 4.09732771f, -3.39088297f, -1.02536607f, 0.82119560f, -3.47302604f, 9.29991817f, 0.21001509f, 4.97036457f, 9.50018406f, 1.04420102f, 1.96560478f, 10.74769592f, -6.22709799f, 3.11690164f, 5.06759691f, -1.23724771f, -3.05831861f, 8.12925529f, -1.93435478f, -1.10151744f, 9.32263088f, -0.04249470f, -5.98547363f, 10.49398136f, 0.26400441f, -0.78915191f, 13.28219604f, 2.99276900f, 0.74853164f, 2.49364305f, -3.43529654f, 4.05278301f, 2.13498688f, -2.35444307f, -0.79900265f, 4.66968822f, -0.31095147f, 3.60674143f, 12.37222099f, -0.07855003f, -3.30292702f, 12.15215874f, 0.60886210f, 2.87075138f, 7.75271845f, 0.38044083f, 3.34402204f, 6.40583277f, -0.87888050f, 0.67438459f, 6.91080809f, 1.98332930f, -0.08303714f, 8.08630371f, -0.16772588f, -2.74058914f, 7.17253590f, -2.69122696f, 1.48173678f, 8.99470139f, -1.43302310f, -0.88651133f, 2.66944790f, -0.29186964f, 2.00838661f, 5.09587479f, -0.76676071f, -2.88322186f, 8.31110573f, -0.14550979f, -1.37726915f, 10.28355122f, -1.60575438f, -0.04118848f, 9.97510815f, 0.14440438f, -3.24632120f, 9.00034523f, 4.14319563f, -1.31023729f, 7.16950464f, -0.70428526f, 2.01559544f, 7.26155043f, 2.40816474f, 2.09847403f, 7.31264496f, -0.75401551f, 2.13392544f, 7.03648758f, 1.04036045f, -1.15636516f, 1.09634531f, -0.06340861f, -0.58107805f, -0.65623116f, 1.18972754f, -0.80717683f, 1.40118241f, -0.61932516f, -3.60596156f, 1.59904599f, -2.23774099f, -1.13721037f, 3.89620137f, -0.09115922f, -7.51356888f, 2.36975193f, -1.42520905f, -2.34173775f, 3.33830214f, -2.74016523f, -3.04115510f, 6.00119495f, -1.36084354f, -2.45065260f, 4.56992292f, -3.02825928f, -3.74182844f, 5.11069250f, -0.91531068f, -2.31385994f, 1.83399653f, 3.39370203f, -3.60886002f}); + auto exp = NDArrayFactory::create('c', {4, 4, 4, 3}, {7.97172260f, 0.06878620f, 2.27749538f, 7.29276514f, -0.14074677f, 0.65480286f, 5.70313978f, -0.06546132f, 0.35443667f, 3.70382833f, -0.84020567f, 0.63826996f, 8.60301399f, -0.38236514f, 1.55177069f, 7.37542057f, -0.99374938f, -0.29971302f, 8.84352493f, -0.67121059f, 0.43132120f, 4.78175592f, -1.25070143f, -1.91523600f, 6.03855371f, -0.00292124f, -1.11214364f, 7.90158176f, -0.57949901f, -0.96735370f, 7.81192017f, -0.53255427f, -0.48009714f, 3.16953635f, 0.08353355f, -1.54299748f, 3.74821687f, 1.69396687f, 0.72724354f, 5.42915201f, -1.13686812f, -0.71793109f, 5.78376389f, -0.72239977f, -0.60055625f, 2.53636408f, 0.56777251f, -2.07892323f, 6.08064651f, 0.68620735f, 2.54017019f, 5.65828180f, -0.68255502f, 1.47283304f, 6.10842514f, -0.39655915f, 0.28380761f, 1.96707797f, -1.98206317f, 0.94027776f, 4.71811438f, 0.32104525f, -0.92409706f, 8.34588146f, -1.05581069f, -0.55217457f, 9.58440876f, -0.96549922f, 0.45820439f, 5.65453672f, -2.50953507f, -0.71441835f, 8.03059578f, -0.21281289f, 0.92125505f, 9.26900673f, -0.35963219f, -0.70039093f, 8.59924412f, -1.22358346f, 0.81318003f, 3.85920119f, -0.01305223f, -1.09234154f, 6.33158875f, 1.28094780f, -1.48926139f, 4.94969177f, -0.77126902f, -1.97033751f, 5.64381838f, -0.16285487f, -1.31277227f, 2.39893222f, -1.32902908f, -1.39609122f, 6.47572327f, -0.45267010f, 1.55727172f, 6.70965624f, -1.68735468f, -0.05672536f, 7.25092363f, -0.64613032f, 0.67050058f, 3.60789680f, -2.05948973f, 2.22687531f, 8.15202713f, -0.70148355f, 1.28314006f, 8.14842319f, -1.88807654f, -1.04808438f, 8.45500565f, -0.76425624f, 0.94542569f, 4.56179953f, -0.28786001f, -2.04502511f, 8.46278095f, -0.31019822f, 0.07339200f, 9.34214592f, -0.61948007f, 0.52481830f, 8.32515621f, -1.52418160f, 0.49678251f, 5.11082315f, -1.09908783f, -0.52969611f, 5.27806664f, 0.88632923f, 0.66754371f, 4.75839233f, 0.48928693f, -0.68036932f, 6.56925392f, -0.02949905f, -2.99189186f, 4.46320581f, -0.64534980f, -0.29516968f, 8.60809517f, -1.13120568f, 3.41720533f, 5.84243155f, -1.24109328f, 0.89566326f, 5.99578333f, -0.42496428f, 2.07076764f, 3.17812920f, -0.81566459f, -0.14363396f, 6.55184317f, 0.39633346f, -0.43852386f, 8.70214558f, -2.24613595f, 0.30708700f, 8.73882294f, -0.53545928f, 1.54409575f, 4.49452257f, -0.16509305f, 0.19028664f, 8.24897003f, 0.44750381f, 2.15448594f, 8.97640514f, -0.77728152f, 0.57272542f, 9.03467560f, 0.47173575f, -1.10807717f, 3.30056310f, -0.43268481f, -0.41470885f, 3.53798294f, -0.08546703f, -2.16840744f, 6.18733406f, -0.17871059f, -2.59837723f, 5.94218683f, -1.02990067f, -0.49760687f, 3.76938033f, 0.86383581f, -1.91504073f}); + + nd4j::ops::avgpool2d op; + auto result = op.execute({&input}, {}, {3,3, 3,3, 0,0, 1,1,1, 0,1}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + // z->printIndexedBuffer("z"); + // exp.printIndexedBuffer("e"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, avgpool2d_11) { + int inOutH = 5;// 35; + int inOutW = 5;// 35; + int inOutC = 10;// 192; + + auto x = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC}); + x.linspace(1.0); + + nd4j::ops::avgpool2d op; + auto result = op.execute({&x}, {}, {3,3, 1,1, 0,0, 1,1, 1, 0, 1}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + int totalPadHeight = (inOutH - 1) * 1 + 3 - inOutH; + int padTop = totalPadHeight / 2; + int padBottom = totalPadHeight - totalPadHeight / 2; + + int k = 3; + + auto m = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC}); + auto c = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC}); + + for (int h = 0; h < inOutH; h++) { + for (int w = 0; w < inOutW; w++) { + int hFrom = h - padTop; + int wFrom = w - padBottom; + + int hTo = hFrom + k; + int wTo = wFrom + k; + + hFrom = nd4j::math::nd4j_max(0, hFrom); + wFrom = nd4j::math::nd4j_max(0, wFrom); + + hTo = nd4j::math::nd4j_min(inOutH, hTo); + wTo = nd4j::math::nd4j_min(inOutW, wTo); + + int idxOut[4]; + int idxIn[4]; + for (int ch = 0; ch < inOutC; ch++) { + idxOut[1] = h; + idxOut[2] = w; + idxOut[3] = ch; + idxIn[3] = ch; + + for (int kh = hFrom; kh < hTo; kh++) { + for (int kw = wFrom; kw < wTo; kw++) { + idxIn[1] = kh; + idxIn[2] = kw; + + auto inVal = x.e(0, kh, kw, ch); + m.p(0, h, w, ch, inVal + m.e(0, h, w, ch)); + c.p(0, h, w, ch, 1 + c.e(0, h, w, ch)); + } + } + } + } + } + m /= c; + + ASSERT_EQ(m, *z); + + delete result; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, avgpool2d_12) { + + int bS=4, iH=10,iW=10, iC=3, kH=3,kW=3, sH=3,sW=3, pH=0,pW=0, dH=1,dW=1; + int oH=4, oW=4; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NHWC, 0-NDHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto expected = NDArrayFactory::create('c', {bS, oH, oW, iC}, { 17.5, 18.5, 19.5, 25. , 26. , 27. , 34. , 35. , 36. , 41.5, 42.5, 43.5, 92.5, 93.5, 94.5, 100. , 101. , 102. , 109. , 110. , 111. , 116.5, 117.5, 118.5, + 182.5, 183.5, 184.5, 190. , 191. , 192. , 199. , 200. , 201. , 206.5, 207.5, 208.5, 257.5, 258.5, 259.5, 265. , 266. , 267. , 274. , 275. , 276. , 281.5, 282.5, 283.5, + 317.5, 318.5, 319.5, 325. , 326. , 327. , 334. , 335. , 336. , 341.5, 342.5, 343.5, 392.5, 393.5, 394.5, 400. , 401. , 402. , 409. , 410. , 411. , 416.5, 417.5, 418.5, + 482.5, 483.5, 484.5, 490. , 491. , 492. , 499. , 500. , 501. , 506.5, 507.5, 508.5, 557.5, 558.5, 559.5, 565. , 566. , 567. , 574. , 575. , 576. , 581.5, 582.5, 583.5, + 617.5, 618.5, 619.5, 625. , 626. , 627. , 634. , 635. , 636. , 641.5, 642.5, 643.5, 692.5, 693.5, 694.5, 700. , 701. , 702. , 709. , 710. , 711. , 716.5, 717.5, 718.5, + 782.5, 783.5, 784.5, 790. , 791. , 792. , 799. , 800. , 801. , 806.5, 807.5, 808.5, 857.5, 858.5, 859.5, 865. , 866. , 867. , 874. , 875. , 876. , 881.5, 882.5, 883.5, + 917.5, 918.5, 919.5, 925. , 926. , 927. , 934. , 935. , 936. , 941.5, 942.5, 943.5, 992.5, 993.5, 994.5,1000. , 1001. , 1002. ,1009. , 1010. , 1011. ,1016.5, 1017.5, 1018.5, + 1082.5, 1083.5, 1084.5,1090. , 1091. , 1092. ,1099. , 1100. , 1101. ,1106.5, 1107.5, 1108.5,1157.5, 1158.5, 1159.5,1165. , 1166. , 1167. ,1174. , 1175. , 1176. ,1181.5, 1182.5, 1183.5}); + input.linspace(1.); + input.syncToDevice(); + + nd4j::ops::avgpool2d op; + auto results = op.execute({&input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat}); + auto output = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + + //output->printIndexedBuffer("output"); + //expected.printIndexedBuffer("expected"); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete results; +} + +////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, biasadd_1) { auto x = NDArrayFactory::create('c', {2, 3, 3, 2}); auto bias = NDArrayFactory::create('c', {2}, {1, 2}); @@ -1652,13 +1788,13 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_3) { ); auto exp = NDArrayFactory::create('c', {2, 2, 2, 4}, { - 0.9824562f, 0.f, 0.03822664f, 0.9824562f, - 0.67488194f, 0.f, 0.18924236f, 0.96960944f, - 0.99330735f, 0.f, 0.f, 0.37139067f, - 0.86567914f, 0.18702209f, 0.05610663f, 0.9520745f, - 0.6154575f, 0.34942827f, 0.45425674f, 0.6154575f, - 0.905509f, 0.f, 0.2824086f, 0.8361251f, - 0.57063663f, 0.41959068f, 0.629386f, 0.3504383f, + 0.9824562f, 0.f, 0.03822664f, 0.9824562f, + 0.67488194f, 0.f, 0.18924236f, 0.96960944f, + 0.99330735f, 0.f, 0.f, 0.37139067f, + 0.86567914f, 0.18702209f, 0.05610663f, 0.9520745f, + 0.6154575f, 0.34942827f, 0.45425674f, 0.6154575f, + 0.905509f, 0.f, 0.2824086f, 0.8361251f, + 0.57063663f, 0.41959068f, 0.629386f, 0.3504383f, 0.9520745f, 0.21039814f, 0.06311944f, 0.3268602f } ); @@ -1680,24 +1816,24 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_4) { auto x = NDArrayFactory::create('c', {2, 2, 2, 4}, { - 5.5f, 0.f, 0.3f, 5.5f, - 1.5f, 0.f, 1.3f, 6.5f, - 8.6f, 0.f, 0.f, 0.4f, - 2.5f, 1.f, 0.3f, 4.5f, - 1.5f, 1.f, 1.3f, 1.5f, - 3.5f, 0.f, 1.3f, 2.5f, - 2.6f, 2.f, 3.f, 1.4f, + 5.5f, 0.f, 0.3f, 5.5f, + 1.5f, 0.f, 1.3f, 6.5f, + 8.6f, 0.f, 0.f, 0.4f, + 2.5f, 1.f, 0.3f, 4.5f, + 1.5f, 1.f, 1.3f, 1.5f, + 3.5f, 0.f, 1.3f, 2.5f, + 2.6f, 2.f, 3.f, 1.4f, 4.5f, 1.f, 0.3f, 0.5f} ); auto exp = NDArrayFactory::create('c', {2, 2, 2, 4}, { - 0.70082176f, 0.f, 0.03822664f, 0.70082176f, - 0.21835658f, 0.f, 0.18924236f, 0.9462118f, - 0.9922489f, 0.f, 0.f, 0.04615111f, - 0.46755522f, 0.18702209f, 0.05610663f, 0.8415994f, - 0.5241424f, 0.34942827f, 0.45425674f, 0.5241424f, - 0.76033086f, 0.f, 0.2824086f, 0.54309344f, - 0.54546785f, 0.41959068f, 0.629386f, 0.29371348f, + 0.70082176f, 0.f, 0.03822664f, 0.70082176f, + 0.21835658f, 0.f, 0.18924236f, 0.9462118f, + 0.9922489f, 0.f, 0.f, 0.04615111f, + 0.46755522f, 0.18702209f, 0.05610663f, 0.8415994f, + 0.5241424f, 0.34942827f, 0.45425674f, 0.5241424f, + 0.76033086f, 0.f, 0.2824086f, 0.54309344f, + 0.54546785f, 0.41959068f, 0.629386f, 0.29371348f, 0.94679165f, 0.21039814f, 0.06311944f, 0.10519907f} ); @@ -1719,28 +1855,28 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_5) { auto x = NDArrayFactory::create('c', {2, 2, 2, 4}, { - 5.5f, 0.f, 0.3f, 5.5f, - 1.5f, 0.f, 1.3f, 6.5f, - 8.6f, 0.f, 0.f, 0.4f, - 2.5f, 1.f, 0.3f, 4.5f, - 1.5f, 1.f, 1.3f, 1.5f, - 3.5f, 0.f, 1.3f, 2.5f, - 2.6f, 2.f, 3.f, 1.4f, + 5.5f, 0.f, 0.3f, 5.5f, + 1.5f, 0.f, 1.3f, 6.5f, + 8.6f, 0.f, 0.f, 0.4f, + 2.5f, 1.f, 0.3f, 4.5f, + 1.5f, 1.f, 1.3f, 1.5f, + 3.5f, 0.f, 1.3f, 2.5f, + 2.6f, 2.f, 3.f, 1.4f, 4.5f, 1.f, 0.3f, 0.5f} ); auto eps = NDArrayFactory::create('c', {2, 2, 2, 4}, { - 0.70082176f, 0.f, 0.03822664f, 0.70082176f, - 0.21835658f, 0.f, 0.18924236f, 0.9462118f, + 0.70082176f, 0.f, 0.03822664f, 0.70082176f, + 0.21835658f, 0.f, 0.18924236f, 0.9462118f, - 0.9922489f, 0.f, 0.f, 0.04615111f, - 0.46755522f, 0.18702209f, 0.05610663f, 0.8415994f, + 0.9922489f, 0.f, 0.f, 0.04615111f, + 0.46755522f, 0.18702209f, 0.05610663f, 0.8415994f, - 0.5241424f, 0.34942827f, 0.45425674f, 0.5241424f, - 0.76033086f, 0.f, 0.2824086f, 0.54309344f, + 0.5241424f, 0.34942827f, 0.45425674f, 0.5241424f, + 0.76033086f, 0.f, 0.2824086f, 0.54309344f, - 0.54546785f, 0.41959068f, 0.629386f, 0.29371348f, + 0.54546785f, 0.41959068f, 0.629386f, 0.29371348f, 0.94679165f, 0.21039814f, 0.06311944f, 0.10519907f} ); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp index ef495142d..3fb90b480 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp @@ -2459,42 +2459,6 @@ TEST_F(DeclarableOpsTests8, reduceStDevBP_test4) { } -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests8, avgpool2d_test13) { - - int bS=4, iH=10,iW=10, iC=3, kH=3,kW=3, sH=3,sW=3, pH=0,pW=0, dH=1,dW=1; - int oH=4, oW=4; - int paddingMode = 1; // 1-SAME, 0-VALID - int dataFormat = 1; // 1-NHWC, 0-NDHW - - auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); - auto expected = NDArrayFactory::create('c', {bS, oH, oW, iC}, { 17.5, 18.5, 19.5, 25. , 26. , 27. , 34. , 35. , 36. , 41.5, 42.5, 43.5, 92.5, 93.5, 94.5, 100. , 101. , 102. , 109. , 110. , 111. , 116.5, 117.5, 118.5, - 182.5, 183.5, 184.5, 190. , 191. , 192. , 199. , 200. , 201. , 206.5, 207.5, 208.5, 257.5, 258.5, 259.5, 265. , 266. , 267. , 274. , 275. , 276. , 281.5, 282.5, 283.5, - 317.5, 318.5, 319.5, 325. , 326. , 327. , 334. , 335. , 336. , 341.5, 342.5, 343.5, 392.5, 393.5, 394.5, 400. , 401. , 402. , 409. , 410. , 411. , 416.5, 417.5, 418.5, - 482.5, 483.5, 484.5, 490. , 491. , 492. , 499. , 500. , 501. , 506.5, 507.5, 508.5, 557.5, 558.5, 559.5, 565. , 566. , 567. , 574. , 575. , 576. , 581.5, 582.5, 583.5, - 617.5, 618.5, 619.5, 625. , 626. , 627. , 634. , 635. , 636. , 641.5, 642.5, 643.5, 692.5, 693.5, 694.5, 700. , 701. , 702. , 709. , 710. , 711. , 716.5, 717.5, 718.5, - 782.5, 783.5, 784.5, 790. , 791. , 792. , 799. , 800. , 801. , 806.5, 807.5, 808.5, 857.5, 858.5, 859.5, 865. , 866. , 867. , 874. , 875. , 876. , 881.5, 882.5, 883.5, - 917.5, 918.5, 919.5, 925. , 926. , 927. , 934. , 935. , 936. , 941.5, 942.5, 943.5, 992.5, 993.5, 994.5,1000. , 1001. , 1002. ,1009. , 1010. , 1011. ,1016.5, 1017.5, 1018.5, - 1082.5, 1083.5, 1084.5,1090. , 1091. , 1092. ,1099. , 1100. , 1101. ,1106.5, 1107.5, 1108.5,1157.5, 1158.5, 1159.5,1165. , 1166. , 1167. ,1174. , 1175. , 1176. ,1181.5, 1182.5, 1183.5}); - input.linspace(1.); - input.syncToDevice(); - - nd4j::ops::avgpool2d op; - auto results = op.execute({&input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat}); - auto output = results->at(0); - - ASSERT_EQ(Status::OK(), results->status()); - - //output->printIndexedBuffer("output"); - //expected.printIndexedBuffer("expected"); - - ASSERT_TRUE(expected.isSameShape(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - delete results; -} - - /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test1) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp index 6df52fb54..caceaa1cd 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp @@ -2894,344 +2894,7 @@ TEST_F(DeclarableOpsTests9, Floormod_BP_Test_4) { delete result; } -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, batchnorm_bp_test1) { - NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.1, 1.2, 1.3, 1.4}, nd4j::DataType::FLOAT32); - NDArray variance('c', {4}, nd4j::DataType::FLOAT32); - NDArray gamma ('c', {4}, nd4j::DataType::FLOAT32); - NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); - NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32); - - NDArray expdLdI('c', {2,3,4}, {-0.000056, -0.000056, -0.000056, -0.000056, -0.000034, -0.000034, -0.000034, -0.000034, -0.000011, -0.000011, -0.000011, -0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000034, 0.000034, 0.000034, 0.000034, 0.000056, 0.000056, 0.000056, 0.000056}, nd4j::DataType::FLOAT32); - NDArray expdLdG('c', {4}, {6.148104, 6.148104, 6.148105, 6.148105}, nd4j::DataType::FLOAT32); - NDArray expdLdB('c', {4}, {3.6, 4.5, 5.4, 6.3}, nd4j::DataType::FLOAT32); - - input.linspace(0.1, 0.1); - variance.assign(0.46666667); - gamma.assign(1.2); - beta.assign(1.); // has no effect on gradient calculations - gradO.linspace(-0.9, 0.15); - - nd4j::ops::batchnorm_bp op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto dLdI = results->at(0); - auto dLdG = results->at(3); - auto dLdB = results->at(4); - - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); - - delete results; -} - - -//////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, batchnorm_bp_test2) { - - NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32); - NDArray mean ('c', {3}, {1.05, 1.1, 1.15}, nd4j::DataType::FLOAT32); - NDArray variance('c', {3}, {0.5, 0.6, 0.7}, nd4j::DataType::FLOAT32); - NDArray gamma ('c', {3}, {1.2, 1.3, 1.4}, nd4j::DataType::FLOAT32); - NDArray beta ('c', {3}, nd4j::DataType::FLOAT32); - NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32); - - NDArray expdLdI('c', {2,3,4}, {-0.601415, -0.521226, -0.441037, -0.360849, -0.456306, -0.395465, -0.334624, -0.273784, 0.396631, 0.343747, - 0.290863, 0.237978, 0.360849, 0.441037, 0.521226, 0.601415, 0.273784, 0.334625, 0.395465, 0.456306, -0.237978, - -0.290863, -0.343746, -0.396631}, nd4j::DataType::FLOAT32); - NDArray expdLdG('c', {3}, {5.81236 , 7.048771, 12.155388}, nd4j::DataType::FLOAT32); - NDArray expdLdB('c', {3}, {1.8, 6.6, 11.4}, nd4j::DataType::FLOAT32); - - input.linspace(0.1, 0.1); - // beta.assign(1.); // has no effect on gradient calculations - gradO.linspace(-0.9, 0.15); - - nd4j::ops::batchnorm_bp op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto dLdI = results->at(0); - auto dLdG = results->at(3); - auto dLdB = results->at(4); - - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); - - delete results; -} - -//////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, batchnorm_bp_test3) { - - NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32); - NDArray mean ('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4}, nd4j::DataType::FLOAT32); - NDArray variance('c', {2,1,4}, {0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2}, nd4j::DataType::FLOAT32); - NDArray gamma ('c', {2,1,4}, {1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9}, nd4j::DataType::FLOAT32); - NDArray beta ('c', {2,1,4}, nd4j::DataType::FLOAT32); - NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32); - - NDArray expdLdI('c', {2,3,4}, {-0.577002, -0.744041, -0.850999, -0.922373, -0.000000, -0.000000, -0.000000, -0.000000, 0.577002, - 0.744041, 0.850999, 0.922373, -0.386037, -0.350205, -0.312047, -0.271737, -0.000000, -0.000000, - -0.000000, -0.000000, 0.386037, 0.350205, 0.312047, 0.271736}, nd4j::DataType::FLOAT32); - NDArray expdLdG('c', {2,1,4}, {1.378844, 0.910144, 0.573706, 0.335408, 2.640487, 2.954985, 3.289431, 3.64234 }, nd4j::DataType::FLOAT32); - NDArray expdLdB('c', {2,1,4}, {-0.9 , -0.45, 0. , 0.45, 4.5 , 4.95, 5.4 , 5.85}, nd4j::DataType::FLOAT32); - - input.linspace(0.1, 0.1); - // beta.assign(1.); // has no effect on gradient calculations - gradO.linspace(-0.9, 0.15); - - nd4j::ops::batchnorm_bp op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,0,2}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto dLdI = results->at(0); - auto dLdG = results->at(3); - auto dLdB = results->at(4); - - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); - - delete results; -} - -//////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, batchnorm_bp_test4) { - - NDArray input ('c', {2,4}, nd4j::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); - NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); - NDArray gradO ('c', {2,4}, nd4j::DataType::FLOAT32); - - NDArray expdLdI('c', {2,4}, {0.162923, -0.289673, 0.354174, -0.386151, -0.162923, 0.289673, -0.354174, 0.386151}, nd4j::DataType::FLOAT32); - NDArray expdLdG('c', {4}, {1.442483, 0.950200, 0.569207, 0.314641}, nd4j::DataType::FLOAT32); - NDArray expdLdB('c', {4}, {-1.2, -0.9, -0.6, -0.3}, nd4j::DataType::FLOAT32); - - input.linspace(0.1, 0.1); - gradO.linspace(-0.9, 0.15); - - nd4j::ops::batchnorm_bp op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto dLdI = results->at(0); - auto dLdG = results->at(3); - auto dLdB = results->at(4); - - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); - - delete results; -} - -//////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, batchnorm_bp_test5) { - - NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); - NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); - NDArray gradO ('c', {2,4,2,2}, nd4j::DataType::FLOAT32); - - NDArray expdLdI('c', {2,4,2,2}, {-0.737512, -0.659880, -0.582247, -0.504614, 0.561404, 0.502309, 0.443214, 0.384118, -1.168243, - -1.045270, -0.922297, -0.799324, 1.899026, 1.699128, 1.499231, 1.299333, 0.504614, 0.582247, 0.659880, 0.737512, -0.384118, - -0.443214, -0.502308, -0.561404, 0.799324, 0.922297, 1.045270, 1.168243, -1.299334, -1.499231, -1.699129, -1.899026}, nd4j::DataType::FLOAT32); - NDArray expdLdG('c', {4}, {11.073181, 12.585667, 17.708657, 24.313186}, nd4j::DataType::FLOAT32); - NDArray expdLdB('c', {4}, {4.2, 9. , 13.8, 18.6}, nd4j::DataType::FLOAT32); - - input.linspace(0.1, 0.1); - gradO.linspace(-0.9, 0.15); - - nd4j::ops::batchnorm_bp op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto dLdI = results->at(0); - auto dLdG = results->at(3); - auto dLdB = results->at(4); - - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); - - delete results; -} - -//////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, batchnorm_bp_test6) { - - NDArray input ('c', {2,2,2,4}, nd4j::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); - NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); - NDArray gradO ('c', {2,2,2,4}, nd4j::DataType::FLOAT32); - - NDArray expdLdI('c', {2,2,2,4}, {-4.989124, 2.540357, -1.515022, 0.791769, -3.563660, 1.814540, -1.082159, 0.565549, -2.138196, 1.088724, -0.649295, - 0.339329, -0.712732, 0.362908, -0.216432, 0.113110, 0.712732, -0.362908, 0.216432, -0.113110, 2.138195, -1.088724, 0.649295, - -0.339330, 3.563660,-1.814540, 1.082159, -0.565549, 4.989125, -2.540356, 1.515022, -0.791770}, nd4j::DataType::FLOAT32); - NDArray expdLdG('c', {4}, {20.364472, 17.856588, 16.949714, 15.903684}, nd4j::DataType::FLOAT32); - NDArray expdLdB('c', {4}, {9.6, 10.8, 12. , 13.2}, nd4j::DataType::FLOAT32); - - input.linspace(0.1, 0.1); - gradO.linspace(-0.9, 0.15); - - nd4j::ops::batchnorm_bp op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto dLdI = results->at(0); - auto dLdG = results->at(3); - auto dLdB = results->at(4); - - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); - - delete results; -} - -//////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, batchnorm_bp_test7) { - - NDArray input ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); - NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); - NDArray gradO ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32); - - NDArray expdLdI('c', {2,2,2,2,4}, {-119.435059, 78.159744, -58.732986, 46.630123, -103.510391, 67.738441, -50.901920, 40.412773, -87.585716, 57.317142, - -43.070854, 34.195419, -71.661041, 46.895844, -35.239792, 27.978071, -55.736359, 36.474548, -27.408726, 21.760721, -39.811687, 26.053242, -19.577662, - 15.543370, -23.887009, 15.631950, -11.746595, 9.326023, -7.962326, 5.210644, -3.915531, 3.108671, 7.962341, -5.210655, 3.915535, -3.108677, 23.887032, - -15.631958, 11.746601, -9.326031, 39.811691, -26.053246, 19.577671, -15.543377, 55.736382, -36.474548, 27.408726, -21.760731, 71.661064, -46.895851, 35.239788, - -27.978077, 87.585732, -57.317154, 43.070866, -34.195431, 103.510384, -67.738464, 50.901920, -40.412777, 119.435097, -78.159744, 58.732998, -46.630131}, nd4j::DataType::FLOAT32); - NDArray expdLdG('c', {4}, {282.38734 , 244.542027, 224.140995, 207.548793}, nd4j::DataType::FLOAT32); - NDArray expdLdB('c', {4}, {57.6, 60. , 62.4, 64.8}, nd4j::DataType::FLOAT32); - - input.linspace(0.1, 0.1); - gradO.linspace(-0.9, 0.15); - - - nd4j::ops::batchnorm_bp op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,4}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto dLdI = results->at(0); - auto dLdG = results->at(3); - auto dLdB = results->at(4); - - // dLdI->printBuffer(); - - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); - - delete results; -} - -//////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, batchnorm_bp_test8) { - - NDArray input ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); - NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); - NDArray gradO ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32); - - NDArray expdLdI('c', {2,4,2,2,2}, {-34.373802, -32.611046, -30.848286, -29.085529, -27.322769, -25.560009, -23.797251, -22.034491, 36.146996, 34.293301, - 32.439610, 30.585917, 28.732227, 26.878534, 25.024841, 23.171150, -42.876553, -40.677757, -38.478958, -36.280159, -34.081367, -31.882565, -29.683767, - -27.484968, 50.674446, 48.075760, 45.477066, 42.878380, 40.279686, 37.681000, 35.082310, 32.483616, 22.034489, 23.797249, 25.560009, 27.322765, 29.085526, - 30.848286, 32.611046, 34.373802, -23.171146, -25.024837, -26.878536, -28.732231, -30.585918, -32.439613, -34.293297, -36.146996, 27.484982, 29.683773, - 31.882572, 34.081364, 36.280178, 38.478970, 40.677776, 42.876560, -32.483627, -35.082329, -37.681023, -40.279701, -42.878403, -45.477081, -48.075775, -50.674484}, nd4j::DataType::FLOAT32); - NDArray expdLdG('c', {4}, {134.490365, 179.785003, 248.933114, 330.087248}, nd4j::DataType::FLOAT32); - NDArray expdLdB('c', {4}, {32.4, 51.6, 70.8, 90.}, nd4j::DataType::FLOAT32); - - input.linspace(0.1, 0.1); - gradO.linspace(-0.9, 0.15); - - nd4j::ops::batchnorm_bp op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto dLdI = results->at(0); - auto dLdG = results->at(3); - auto dLdB = results->at(4); - - // dLdI->printBuffer(); - - ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); - ASSERT_TRUE(expdLdI.equalsTo(dLdI)); - - ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); - ASSERT_TRUE(expdLdG.equalsTo(dLdG)); - - ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); - ASSERT_TRUE(expdLdB.equalsTo(dLdB)); - - delete results; -} /* //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) {