/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*!
 * \file kernel_reg.h
 * \brief Thin inline wrappers around the vector-mask, pipe-barrier, dcci and
 *        ctrl-register intrinsics used by the AscendC kernel layer.
 *
 * NOTE(review): in this copy of the file every `template` keyword has lost its
 * parameter list and most `static_cast` calls have lost their target type
 * (only the `__gm__ void *` / `__ubuf__ void *` casts are intact). The bodies
 * reference `T`, `mode`, `pipe`, `entireType` and `dcciDst`, so those names are
 * expected to be template parameters. Restore the lists from the upstream
 * header before building; they are deliberately not guessed here.
 */
#ifndef ASCENDC_KERNEL_REG_IMPL_H
#define ASCENDC_KERNEL_REG_IMPL_H

#include "kernel_utils.h"

namespace AscendC {
// Zero-valued placeholder constants (a scalar and a two-element list).
// They are not referenced anywhere in this header; presumably callers pass them
// where a mask argument is required but unused -- TODO confirm against callers.
constexpr uint64_t MASK_PLACEHOLDER = 0;
constexpr uint64_t MASK_PLACEHOLDER_LIST[2] = {0, 0};

// Selects how SetVectorMaskImpl(int32_t) interprets its argument:
// NORMAL (0) builds a bit mask from the length, COUNTER (1) forwards the
// length unchanged as the low mask word.
enum class MaskMode : uint8_t {
    NORMAL = 0,
    COUNTER
};

/*!
 * \brief Write the two 64-bit vector-mask words via set_vector_mask().
 * \param maskHigh value passed as the first argument of set_vector_mask
 * \param maskLow  value passed as the second argument of set_vector_mask
 *
 * When __CCE_KT_TEST__ == 1 the arguments are validated first:
 *   - for sizeof(T) >= sizeof(int32_t), maskHigh must be 0;
 *   - maskLow and maskHigh must not both be 0.
 * The intrinsic is only issued when ASCEND_IS_NOT_AIC holds (macro defined
 * elsewhere).
 *
 * NOTE(review): template parameter list missing in this copy; the body uses `T`.
 */
template __aicore__ static inline void SetVectorMaskImpl(const uint64_t maskHigh, const uint64_t maskLow)
{
#if defined(__CCE_KT_TEST__) && __CCE_KT_TEST__ == 1
    // Test-build-only argument checks; compiled out otherwise.
    if constexpr (sizeof(T) >= sizeof(int32_t)) {
        ASCENDC_ASSERT((maskHigh == 0ULL), { KERNEL_LOG(KERNEL_ERROR, "maskHigh must be 0 for type b32 and b64"); });
    }
    ASCENDC_ASSERT(((maskLow != 0ULL) || (maskHigh != 0ULL)), { KERNEL_LOG(KERNEL_ERROR, "maskLow and maskHigh can not be zero at the same time"); });
#endif
    if ASCEND_IS_NOT_AIC {
        set_vector_mask(maskHigh, maskLow);
    }
}

/*!
 * \brief Set the vector mask from an element count.
 * \param len element count
 *
 * Behaviour, in order of evaluation:
 *   - mode == MaskMode::COUNTER: forwards (0, len) to the two-word overload.
 *   - len == 64:                 (high, low) = (0, FULL_MASK).
 *   - len == typeLen:            (high, low) = (FULL_MASK, FULL_MASK),
 *                                where typeLen = DEFAULT_BLOCK_SIZE / sizeof(T).
 *   - len >= 128:                (high, low) = (FULL_MASK, FULL_MASK).
 *   - otherwise: the lowest `len` bits are set, split across the two words
 *     (low word saturates to FULL_MASK once len exceeds 64).
 *
 * FULL_MASK and DEFAULT_BLOCK_SIZE are defined outside this file.
 *
 * NOTE(review): the `len == typeLen` branch sets the high word to FULL_MASK
 * regardless of typeLen. If typeLen can be smaller than 64 for a T with
 * sizeof(T) >= 4 (depends on DEFAULT_BLOCK_SIZE), that conflicts with the
 * "maskHigh must be 0" test-build assertion in the two-word overload -- confirm
 * this is intended.
 * NOTE(review): no guard for len <= 0 is visible; len == 0 appears to yield
 * (0, 0) in the fall-through path, which the test-build assertion rejects.
 * NOTE(review): template parameter list and static_cast target types are
 * missing in this copy; the body uses `T` and `mode`.
 */
template __aicore__ static inline void SetVectorMaskImpl(int32_t len)
{
    if constexpr (mode == MaskMode::COUNTER) {
        // Counter mode: the length itself is written as the low word.
        SetVectorMaskImpl(0, len);
        return;
    }
    constexpr int32_t typeLen = DEFAULT_BLOCK_SIZE / sizeof(T);
    constexpr int32_t halfTypeLen = 64;  // number of bits in one mask word
    constexpr int32_t lenCoeff = 2;      // two mask words in total
    if (len == halfTypeLen) {
        SetVectorMaskImpl(0, FULL_MASK);
        return;
    } else if (len == typeLen) {
        SetVectorMaskImpl(FULL_MASK, FULL_MASK);
        return;
    } else if (len >= halfTypeLen * lenCoeff) {
        SetVectorMaskImpl(FULL_MASK, FULL_MASK);
        return;
    }
    // General case: high word gets the (len - 64) lowest bits when len > 64,
    // else 0; low word is FULL_MASK when len > 64, else the `len` lowest bits.
    SetVectorMaskImpl(static_cast(
        (len > halfTypeLen) ? (((static_cast(1)) << static_cast(len - halfTypeLen)) - 1) : 0),
        static_cast((len > halfTypeLen) ? FULL_MASK : (((static_cast(1)) << static_cast(len)) - 1)));
    return;
}

/*!
 * \brief Set both vector-mask words to FULL_MASK.
 *
 * Only issues set_vector_mask when ASCEND_IS_NOT_AIC holds.
 */
__aicore__ inline void ResetMaskImpl()
{
    if ASCEND_IS_NOT_AIC {
        set_vector_mask(FULL_MASK, FULL_MASK);
    }
}

/*!
 * \brief Issue pipe_barrier(pipe), except on __CCE_AICORE__ 300 / 310.
 *
 * For __CCE_AICORE__ == 300 or 310 the function returns immediately, so the
 * pipe_barrier call below is unreachable on those targets.
 *
 * NOTE(review): template parameter list missing in this copy; `pipe` is not
 * declared in the visible text and is presumably a non-type template parameter.
 */
template __aicore__ inline void PipeBarrierImpl()
{
#if __CCE_AICORE__ == 300 || __CCE_AICORE__ == 310
    return;
#endif
    pipe_barrier(pipe);
}

// Value forwarded (after a cast) as the second argument of dcci().
enum class CacheLine : uint64_t {
    SINGLE_CACHE_LINE = 0,
    ENTIRE_DATA_CACHE
};

// Value forwarded (after a cast) as the third argument of the three-argument
// dcci() call.
enum class DcciDst : uint64_t {
    CACHELINE_ALL = 0,
    CACHELINE_UB,
    CACHELINE_OUT,
    CACHELINE_ATOMIC
};

#if __CCE_AICORE__ == 220
/*!
 * \brief Call the three-argument dcci() on a __gm__ pointer.
 * \param dst pointer passed to dcci as `__gm__ void *`
 *
 * Available only when __CCE_AICORE__ == 220.
 * NOTE(review): template parameter list and two static_cast target types are
 * missing in this copy; the body uses `T`, `entireType` and `dcciDst`.
 */
template __aicore__ inline void DcciGMImpl(__gm__ T* dst)
{
    dcci(static_cast<__gm__ void *>(dst), static_cast(entireType), static_cast(dcciDst));
}

/*!
 * \brief Call the three-argument dcci() on a __ubuf__ pointer.
 * \param dst pointer passed to dcci as `__ubuf__ void *`
 *
 * Available only when __CCE_AICORE__ == 220.
 * NOTE(review): template parameter list and two static_cast target types are
 * missing in this copy; the body uses `T`, `entireType` and `dcciDst`.
 */
template __aicore__ inline void DcciUBImpl(__ubuf__ T* dst)
{
    dcci(static_cast<__ubuf__ void *>(dst), static_cast(entireType), static_cast(dcciDst));
}
#endif

#if (__CCE_AICORE__ == 220) || (__CCE_AICORE__ == 200)
/*!
 * \brief Call the two-argument dcci() on a __gm__ pointer.
 * \param dst pointer passed to dcci as `__gm__ void *`
 *
 * Available when __CCE_AICORE__ is 220 or 200. On 220 this coexists with the
 * three-argument variant above; presumably the two are distinguished by their
 * template parameter lists -- TODO confirm once those lists are restored.
 * NOTE(review): template parameter list and one static_cast target type are
 * missing in this copy; the body uses `T` and `entireType`.
 */
template __aicore__ inline void DcciGMImpl(__gm__ T* dst)
{
    dcci(static_cast<__gm__ void *>(dst), static_cast(entireType));
}
#endif

/*! \brief Forward to the set_mask_count() intrinsic. */
__aicore__ inline void SetMaskCountImpl()
{
    set_mask_count();
}

/*! \brief Forward to the set_mask_norm() intrinsic. */
__aicore__ inline void SetMaskNormImpl()
{
    set_mask_norm();
}

/*!
 * \brief Update LEAKY_RELU_MODE_BIT in the ctrl register.
 * \param lreluMode true applies sbitset1 to the bit, false applies sbitset0
 *
 * Read-modify-write: the current value from get_ctrl() is transformed and
 * written back with set_ctrl(). sbitset1 / sbitset0 are defined elsewhere;
 * by their names they set / clear the given bit -- confirm.
 */
__aicore__ inline void SetLreluModeImpl(bool lreluMode)
{
    if (lreluMode) {
        set_ctrl(sbitset1(get_ctrl(), LEAKY_RELU_MODE_BIT));
    } else {
        set_ctrl(sbitset0(get_ctrl(), LEAKY_RELU_MODE_BIT));
    }
}

/*!
 * \brief Update HF32_MODE_BIT in the ctrl register.
 * \param hf32Mode true applies sbitset1 to the bit, false applies sbitset0
 *
 * Same get_ctrl() / set_ctrl() read-modify-write pattern as SetLreluModeImpl.
 */
__aicore__ inline void SetHF32ModeImpl(bool hf32Mode)
{
    if (hf32Mode) {
        set_ctrl(sbitset1(get_ctrl(), HF32_MODE_BIT));
    } else {
        set_ctrl(sbitset0(get_ctrl(), HF32_MODE_BIT));
    }
}

/*!
 * \brief Update HF32_TRANS_MODE_BIT in the ctrl register.
 * \param hf32TransMode true applies sbitset1 to the bit, false applies sbitset0
 *
 * Same get_ctrl() / set_ctrl() read-modify-write pattern as SetLreluModeImpl.
 */
__aicore__ inline void SetHF32TransModeImpl(bool hf32TransMode)
{
    if (hf32TransMode) {
        set_ctrl(sbitset1(get_ctrl(), HF32_TRANS_MODE_BIT));
    } else {
        set_ctrl(sbitset0(get_ctrl(), HF32_TRANS_MODE_BIT));
    }
}

/*!
 * \brief Update MM_LAYOUT_MODE_BIT in the ctrl register.
 * \param mmLayoutMode true applies sbitset1 to the bit, false applies sbitset0
 *
 * Same get_ctrl() / set_ctrl() read-modify-write pattern as SetLreluModeImpl.
 */
__aicore__ inline void SetMMLayoutTransformImpl(bool mmLayoutMode)
{
    if (mmLayoutMode) {
        set_ctrl(sbitset1(get_ctrl(),
            MM_LAYOUT_MODE_BIT));
    } else {
        set_ctrl(sbitset0(get_ctrl(), MM_LAYOUT_MODE_BIT));
    }
}

/*!
 * \brief Return the value of the get_acc_val() intrinsic.
 * \return result of get_acc_val() as int64_t
 */
__aicore__ inline int64_t GetAccValImpl()
{
    return get_acc_val();
}
} // namespace AscendC
#endif // ASCENDC_KERNEL_REG_IMPL_H