// Copyright (c) 2023 Huawei Technologies Co., Ltd
// Copyright (c) 2019, Facebook CORPORATION.
// All rights reserved.
//
// Licensed under the BSD 3-Clause License  (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "torch_npu/csrc/framework/utils/RandomOpAdapter.h"
#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
#include "torch_npu/csrc/core/npu/NPUGraphsUtils.h"

namespace op_api {

// Packed-mask sizing: the element count is rounded up to a multiple of
// BIT_NUMBER bits, then converted to bytes (UINT8_BIT_NUMBER bits per byte).
static constexpr int64_t BIT_NUMBER = 128;
static constexpr int64_t UINT8_BIT_NUMBER = 8;

// Bounds used to validate the scale argument of native_dropout_backward.
static constexpr double NUMBER_ZERO = 0.0;
static constexpr double NUMBER_ONE = 1.0;

// Virtual launch-grid parameters used only to size the Philox counter advance
// in the aclnn-only dropout path (they do not configure any real launch here).
static constexpr int64_t BLOCK_SIZE = 256;
static constexpr int64_t MAX_THREADS_PER_MULTI_PROCESSOR = 2048;
static constexpr int64_t MAX_PROCESSOR_COUNT = 78;
static constexpr int UNROLL = 4;

/**
 * Dropout forward on NPU.
 *
 * Returns the dropped-out tensor together with a packed uint8 mask whose byte
 * length is the element count rounded up to a multiple of BIT_NUMBER bits,
 * divided by UINT8_BIT_NUMBER.
 *
 * Two implementations:
 *  - aclnn-only mode: one fused aclnnDropoutV3 call, seeded from the default
 *    NPU generator's Philox state.
 *  - otherwise: the mask is generated on the secondary stream
 *    (aclnnDropoutGenMaskV2, or aclnnDropoutGenMaskV2Tensor while the current
 *    stream is being captured), then applied on the original stream with
 *    aclnnDropoutDoMask.
 *
 * @param self input tensor
 * @param p    drop probability, forwarded unchanged to the aclnn kernels
 * @return (result, mask)
 */
std::tuple<at::Tensor, at::Tensor> _npu_dropout(const at::Tensor& self, double p)
{
    if (!c10_npu::IsAclnnOnly()) {
        DO_COMPATIBILITY(aclnnDropoutGenMaskV2, acl_op::_npu_dropout(self, p));
        DO_COMPATIBILITY(aclnnDropoutDoMask, acl_op::_npu_dropout(self, p));
    }

    // One mask bit per element; the bit count is rounded up to a multiple of BIT_NUMBER.
    int64_t length = (self.numel() + BIT_NUMBER - 1) / BIT_NUMBER * BIT_NUMBER / UINT8_BIT_NUMBER;
    at::Tensor result = at_npu::native::OpPreparation::apply_tensor_without_format(self);
    at::Tensor mask;

    if (c10_npu::IsAclnnOnly()) {
        mask = at_npu::native::OpPreparation::apply_tensor_without_format({length}, self.options().dtype(at::kByte));
        const auto gen = at_npu::detail::getDefaultNPUGenerator();
        const int64_t nelem = self.numel();
        if (nelem == 0) {
            return std::tie(result, mask);
        }
        // Size the Philox counter advance from a virtual launch grid. Everything stays in
        // int64_t: narrowing the block count to unsigned int wraps for very large tensors
        // and can make gridX zero, which would divide by zero in counterOffset.
        const int64_t blocksPerSm = MAX_THREADS_PER_MULTI_PROCESSOR / BLOCK_SIZE;
        const int64_t gridX = std::min(MAX_PROCESSOR_COUNT * blocksPerSm, (nelem + BLOCK_SIZE - 1) / BLOCK_SIZE);
        const int64_t counterOffset = ((nelem - 1) / (BLOCK_SIZE * gridX * UNROLL) + 1) * UNROLL;
        auto pair = at::check_generator<at_npu::NPUGeneratorImpl>(gen)->philox_engine_inputs(counterOffset);
        const uint64_t seed = pair.first;
        const uint64_t offset = pair.second;
        // No explicit noise shape is supplied: pass an undefined tensor.
        at::Tensor optionalNoiseShape;
        EXEC_NPU_CMD(aclnnDropoutV3, self, optionalNoiseShape, p, seed, offset, result, mask);
        return std::tie(result, mask);
    }

    auto original_stream = c10_npu::getCurrentNPUStream();
    auto secondary_stream = c10_npu::getCurrentSecondaryStream();
    auto is_capture = c10_npu::currentStreamCaptureStatusMayInitCtx();
    // The mask's seed and offset are taken from the default NPU generator's Philox state.
    const auto gen = at_npu::detail::getDefaultNPUGenerator();
    {
        if (is_capture == c10_npu::CaptureStatus::None) {
            // While this guard is alive the secondary stream is the current stream: the
            // mask kernel is queued there and, following the one-stream-one-pool principle,
            // the mask memory is allocated from the secondary stream's pool.
            c10_npu::SecondaryStreamGuard guard(secondary_stream);
            mask = at_npu::native::OpPreparation::apply_tensor_without_format({length}, self.options().dtype(at::kByte));
            at::IntArrayRef shapeArray(self.sizes());
            auto pair = at::check_generator<at_npu::NPUGeneratorImpl>(gen)->philox_engine_inputs(10);
            const uint64_t seed = pair.first;
            const uint64_t offset = pair.second;
            aclDataType dataType = at_npu::native::OpPreparation::convert_to_acl_data_type(self.scalar_type());
            EXEC_NPU_CMD(aclnnDropoutGenMaskV2, shapeArray, p, seed, offset, dataType, mask);
        } else {
#if VERSION_BETWEEN(V2R5, VERSION_NEWEST)
            // Graph capture: seed and offset are passed as tensors (plus an intra-graph
            // offset) through aclnnDropoutGenMaskV2Tensor instead of host scalars.
            auto gen_state_ = at::check_generator<at_npu::NPUGeneratorImpl>(gen)->philox_npu_state(10);
            const at::Tensor* seed_ptr = gen_state_.seed_.ptr;
            const at::Tensor* offset_ptr = gen_state_.offset_.ptr;
            const uint64_t offset_intragraph = gen_state_.offset_intragraph_;
            if (!gen_state_.secondary_stream_capture_state_) {
                // First use of the secondary stream under this capture state: make it wait on
                // the original stream (presumably so it joins the capture — confirm).
                c10_npu::NPUEvent capture_event_begin = c10_npu::NPUEvent();
                capture_event_begin.record(original_stream);
                capture_event_begin.block(secondary_stream);
                ASCEND_LOGI("Event: record and block in dropout op capture begin is successfully executed, event=%p", capture_event_begin.event());
            }
            // NOTE(review): the returned status is not checked.
            NPUStatus ret = c10_npu::emptyAllNPUStream();
            c10_npu::SecondaryStreamGuard guard(secondary_stream);
            mask = at_npu::native::OpPreparation::apply_tensor_without_format({length}, self.options().dtype(at::kByte));
            at::IntArrayRef shapeArray(self.sizes());
            aclDataType dataType = at_npu::native::OpPreparation::convert_to_acl_data_type(self.scalar_type());
            EXEC_NPU_CMD(aclnnDropoutGenMaskV2Tensor, shapeArray, p, *seed_ptr, *offset_ptr, offset_intragraph, dataType, mask);
            if (!gen_state_.secondary_stream_capture_state_) {
                ASCEND_LOGI("Event: record and block in dropout op capture end is successfully executed");
                at::check_generator<at_npu::NPUGeneratorImpl>(gen)->set_secondary_stream_capture_state(true);
            }
#else
            // Without the capture-aware kernel path the mask would stay undefined and the
            // recordStream call below would fail obscurely; report the real reason instead.
            TORCH_CHECK(false, "_npu_dropout does not support NPU graph capture in this build.",
                        OPS_ERROR(ErrCode::NOT_SUPPORT));
#endif
        }
    }
    // When tasks on multiple streams read and write the same block of memory,
    // recordStream needs to be called to ensure the correctness of memory reuse.
    c10_npu::NPUCachingAllocator::recordStream(mask.storage().data_ptr(), original_stream);

    EXEC_NPU_CMD(aclnnDropoutDoMask, self, mask, p, result);
    return std::tie(result, mask);
}

/**
 * Applies a packed dropout mask to grad_output via aclnnDropoutDoMask.
 *
 * @param grad_output incoming gradient
 * @param mask        packed uint8 mask produced by _npu_dropout
 * @param scale       despite its name, this is passed to aclnnDropoutDoMask in the
 *                    same position where _npu_dropout passes the drop probability p;
 *                    native_dropout_backward calls this function with p.
 * @return gradient with respect to the dropout input, shaped like grad_output
 */
at::Tensor npu_dropout_backward(const at::Tensor& grad_output, const at::Tensor& mask, double scale)
{
    DO_COMPATIBILITY(aclnnDropoutDoMask, acl_op::npu_dropout_backward(grad_output, mask, scale));
    if (grad_output.numel() == 0) {
        // Nothing to compute. Shape the empty result like grad_output rather than like the
        // packed 1-D mask, so the gradient matches the shape of the dropout input.
        return at_npu::native::OpPreparation::apply_tensor_without_format(grad_output);
    }
    if (mask.numel() == 0) {
        // NOTE(review): an empty mask with a non-empty grad_output is inconsistent input;
        // the historical mask-shaped empty result is kept for this case.
        return at_npu::native::OpPreparation::apply_tensor_without_format(mask.sizes(), grad_output.options());
    }
    at::Tensor result = at_npu::native::OpPreparation::apply_tensor_without_format(grad_output);
    EXEC_NPU_CMD(aclnnDropoutDoMask, grad_output, mask, scale, result);
    return result;
}

/**
 * NPU implementation of native_dropout.
 *
 * - p == 0 or not training: returns a copy of input and an all-true bool mask
 *   shaped like input.
 * - p == 1: returns zeros and an all-false bool mask shaped like input.
 * - otherwise: delegates to _npu_dropout, whose mask is a packed uint8 tensor.
 *
 * @param input input tensor
 * @param p     drop probability
 * @param train training flag; an absent value is treated as true
 * @return (output, mask)
 */
std::tuple<at::Tensor, at::Tensor> native_dropout(const at::Tensor& input, double p, c10::optional<bool> train)
{
    // Mirror _npu_dropout: in aclnn-only mode the forward uses aclnnDropoutV3, so the
    // GenMaskV2/DoMask availability checks and their acl_op fallback must not apply.
    if (!c10_npu::IsAclnnOnly()) {
        DO_COMPATIBILITY(aclnnDropoutGenMaskV2, acl_op::native_dropout(input, p, train));
        DO_COMPATIBILITY(aclnnDropoutDoMask, acl_op::native_dropout(input, p, train));
    }

    const bool dropout_train = train.value_or(true);
    if (p == 0 || !dropout_train) {
        at::Tensor mask = at::ones(input.sizes(), input.options().dtype(at::kBool));
        return std::make_tuple(input.clone(), mask);
    }
    if (p == 1) {
        at::Tensor output = at::zeros(input.sizes(), input.options());
        at::Tensor mask = at::zeros(input.sizes(), input.options().dtype(at::kBool));
        return std::make_tuple(output, mask);
    }
    return op_api::_npu_dropout(input, p);
}

/**
 * NPU implementation of native_dropout_backward.
 *
 * The drop probability is recovered from scale = 1 / (1 - p), where scale == 0
 * encodes p == 1. p == 1 yields zeros and p == 0 yields a copy of grad_output.
 * Any other p is handed to npu_dropout_backward together with the mask.
 *
 * @param grad_output incoming gradient
 * @param mask        mask returned by native_dropout
 * @param scale       must be 0 or >= 1; anything else raises
 * @return gradient with respect to the dropout input
 */
at::Tensor native_dropout_backward(const at::Tensor& grad_output, const at::Tensor& mask, double scale)
{
    DO_COMPATIBILITY(aclnnDropoutDoMask, acl_op::native_dropout_backward(grad_output, mask, scale));
    const bool scaleIsValid = (scale == NUMBER_ZERO) || (scale >= NUMBER_ONE);
    TORCH_CHECK(scaleIsValid, "native_dropout_backward scale has to be 0 or greater than or equal to 1, but got ", scale,
                OPS_ERROR(ErrCode::VALUE));

    const double dropProb = (scale == 0.0) ? 1.0 : 1.0 - 1.0 / scale;
    if (dropProb == 1) {
        // Every element was dropped: the gradient is identically zero.
        return at::zeros(grad_output.sizes(), grad_output.options());
    }
    if (dropProb == 0) {
        // Nothing was dropped: the gradient passes through unchanged.
        return grad_output.clone();
    }
    return op_api::npu_dropout_backward(grad_output, mask, dropProb);
}

}