#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Utility functions for converting between frontend datatypes and CUTLASS datatypes
"""
import cutlass_bindings
import cutlass
from cutlass.backend.library import (
    DataTypeSize,
    MathInstruction,
    MathOperation,
    ShortLayoutTypeNames,
    TileDescription,
)
# NumPy is an optional dependency: record whether it could be imported and, when
# it could, build the table mapping library data types to NumPy data types.
try:
    import numpy as np
except ImportError:
    numpy_available = False
    _library_to_numpy_dict = {}
else:
    numpy_available = True
    _library_to_numpy_dict = {
        cutlass.DataType.f16: np.float16,
        cutlass.DataType.f32: np.float32,
        cutlass.DataType.f64: np.float64,
        cutlass.DataType.s8: np.int8,
        cutlass.DataType.s32: np.int32,
    }
def numpy_library_type(inp) -> cutlass.DataType:
    """
    Converts a NumPy data type to the corresponding library data type

    :param inp: candidate NumPy data type

    :returns: matching library data type, or ``None`` if NumPy is unavailable or
        ``inp`` does not compare equal to any supported NumPy type
    :rtype: cutlass.DataType
    """
    if not numpy_available:
        return None
    # Equality (rather than a dictionary lookup) is used on purpose so that the
    # comparison semantics of ``inp`` decide the match.
    numpy_to_library = (
        (np.float16, cutlass.DataType.f16),
        (np.float32, cutlass.DataType.f32),
        (np.float64, cutlass.DataType.f64),
        (np.int8, cutlass.DataType.s8),
        (np.int32, cutlass.DataType.s32),
    )
    for numpy_dtype, library_dtype in numpy_to_library:
        if inp == numpy_dtype:
            return library_dtype
    return None
def numpy_type(inp):
    """
    Looks up the NumPy data type registered for a library data type

    :param inp: library data type

    :returns: entry of ``_library_to_numpy_dict`` for ``inp``, or ``None`` if there is none
    """
    numpy_dtype = _library_to_numpy_dict.get(inp)
    return numpy_dtype
# CuPy is an optional dependency: record whether it could be imported and, when
# it could, build the table mapping library data types to CuPy data types.
try:
    import cupy as cp
except ImportError:
    cupy_available = False
    _library_to_cupy_dict = {}
else:
    cupy_available = True
    _library_to_cupy_dict = {
        cutlass.DataType.f16: cp.float16,
        cutlass.DataType.f32: cp.float32,
        cutlass.DataType.f64: cp.float64,
        cutlass.DataType.s8: cp.int8,
        cutlass.DataType.s32: cp.int32,
    }
def cupy_library_type(inp) -> cutlass.DataType:
    """
    Converts a CuPy data type to the corresponding library data type

    :param inp: candidate CuPy data type

    :returns: matching library data type, or ``None`` if CuPy is unavailable or
        ``inp`` does not compare equal to any supported CuPy type
    :rtype: cutlass.DataType
    """
    if not cupy_available:
        return None
    # Cover every CuPy type present in ``_library_to_cupy_dict`` (including the
    # integer types) so that a value returned by ``cupy_type`` converts back.
    cupy_to_library = (
        (cp.float16, cutlass.DataType.f16),
        (cp.float32, cutlass.DataType.f32),
        (cp.float64, cutlass.DataType.f64),
        (cp.int8, cutlass.DataType.s8),
        (cp.int32, cutlass.DataType.s32),
    )
    for cupy_dtype, library_dtype in cupy_to_library:
        if inp == cupy_dtype:
            return library_dtype
    return None
def cupy_type(inp):
    """
    Looks up the CuPy data type registered for a library data type

    :param inp: library data type

    :returns: entry of ``_library_to_cupy_dict`` for ``inp``, or ``None`` if there is none
    """
    cupy_dtype = _library_to_cupy_dict.get(inp)
    return cupy_dtype
# PyTorch is an optional dependency: record whether it could be imported and,
# when it could, build the conversion tables in both directions.
try:
    import torch
except ImportError:
    torch_available = False
    _torch_to_library_dict = {}
    _library_to_torch_dict = {}
else:
    torch_available = True
    # Both spellings of each floating-point dtype are listed so that either
    # name can be used for lookup.
    _torch_to_library_dict = {
        torch.half: cutlass.DataType.f16,
        torch.float16: cutlass.DataType.f16,
        torch.float: cutlass.DataType.f32,
        torch.float32: cutlass.DataType.f32,
        torch.double: cutlass.DataType.f64,
        torch.float64: cutlass.DataType.f64,
    }
    # One entry per library type. A dict literal keeps only the last value given
    # for a repeated key, so the sized spellings are the ones listed here.
    _library_to_torch_dict = {
        cutlass.DataType.f16: torch.float16,
        cutlass.DataType.f32: torch.float32,
        cutlass.DataType.f64: torch.float64,
    }
def torch_library_type(inp) -> cutlass.DataType:
    """
    Looks up the library data type registered for a PyTorch data type

    :param inp: candidate PyTorch data type

    :returns: entry of ``_torch_to_library_dict`` for ``inp``, or ``None`` if there is none
    :rtype: cutlass.DataType
    """
    library_dtype = _torch_to_library_dict.get(inp)
    return library_dtype
def torch_type(inp):
    """
    Looks up the PyTorch data type registered for a library data type

    :param inp: library data type

    :returns: entry of ``_library_to_torch_dict`` for ``inp``, or ``None`` if there is none
    """
    torch_dtype = _library_to_torch_dict.get(inp)
    return torch_dtype
# The bfloat16 package is an optional dependency: record whether it could be imported.
try:
    import bfloat16
except ImportError:
    bfloat16_available = False
else:
    bfloat16_available = True
def bfloat16_library_type(inp) -> cutlass.DataType:
    """
    Converts the ``bfloat16.bfloat16`` type to the corresponding library data type

    :param inp: candidate data type

    :returns: ``cutlass.DataType.bf16`` if the bfloat16 package is available and
        ``inp`` compares equal to ``bfloat16.bfloat16``; otherwise ``None``
    :rtype: cutlass.DataType
    """
    if bfloat16_available and inp == bfloat16.bfloat16:
        return cutlass.DataType.bf16
    return None
def bfloat16_type(inp):
    """
    Converts the library bf16 data type to the ``bfloat16.bfloat16`` type

    :param inp: library data type

    :returns: ``bfloat16.bfloat16`` if the bfloat16 package is available and
        ``inp`` compares equal to ``cutlass.DataType.bf16``; otherwise ``None``
    """
    # The signature deliberately carries no ``bfloat16.bfloat16`` return
    # annotation: annotations are evaluated when the function is defined, so
    # naming the optional package there raises NameError whenever it is not
    # installed.
    if bfloat16_available and inp == cutlass.DataType.bf16:
        return bfloat16.bfloat16
    return None
# Mapping from library data type to Python-bound CUTLASS data type
library_to_binding_dict = {
    cutlass.DataType.s8: cutlass_bindings.int8,
    cutlass.DataType.s32: cutlass_bindings.int32,
    cutlass.DataType.f16: cutlass_bindings.float16,
    cutlass.DataType.bf16: cutlass_bindings.bfloat16,
    cutlass.DataType.f32: cutlass_bindings.float32,
    cutlass.DataType.f64: cutlass_bindings.float64,
    cutlass.DataType.tf32: cutlass_bindings.tfloat32,
}
# Mapping from Python-bound CUTLASS data type to library data type, obtained by
# inverting the table above entry by entry (same pairs, same order)
binding_to_library = {
    bound_type: library_dtype
    for library_dtype, bound_type in library_to_binding_dict.items()
}
def binding_library_type(inp):
    """
    Looks up the library data type registered for a Python-bound CUTLASS data type

    :param inp: candidate Python-bound CUTLASS data type

    :returns: entry of ``binding_to_library`` for ``inp``, or ``None`` if there is none
    """
    return binding_to_library.get(inp)
def has_binding_type(inp: cutlass.DataType):
    """
    Reports whether a library data type has a Python-bound CUTLASS counterpart

    :param inp: library data type
    :type inp: cutlass.DataType

    :returns: ``True`` if ``inp`` is a key of ``library_to_binding_dict``, ``False`` otherwise
    :rtype: bool
    """
    is_registered = inp in library_to_binding_dict
    return is_registered
def library_to_binding(inp: cutlass.DataType):
    """
    Converts a library data type to its Python-bound CUTLASS data type

    :param inp: library data type
    :type inp: cutlass.DataType

    :returns: entry of ``library_to_binding_dict`` for ``inp``

    :raises Exception: if ``has_binding_type(inp)`` is false
    """
    if has_binding_type(inp):
        return library_to_binding_dict[inp]
    raise Exception(f"No available conversion from library type {inp} to Python-bound CUTLASS type")
def library_type(inp):
    """
    Converts ``inp`` to a library data type

    ``inp`` is returned unchanged when it is already a key of
    ``cutlass.DataTypeSize``. Otherwise the frontend converters are tried in the
    order bfloat16, CuPy, NumPy, PyTorch, Python-bound CUTLASS, and the first
    result that is not ``None`` is returned.

    :param inp: data type to convert

    :returns: library data type corresponding to ``inp``

    :raises Exception: if no converter produces a result for ``inp``
    """
    if inp in cutlass.DataTypeSize.keys():
        return inp
    converters = (
        bfloat16_library_type,
        cupy_library_type,
        numpy_library_type,
        torch_library_type,
        binding_library_type,
    )
    for convert in converters:
        converted = convert(inp)
        if converted is not None:
            return converted
    raise Exception(f"No available conversion from type {inp} to a library type.")
def library_layout(layout):
    """
    Converts ``layout`` to a library layout

    :param layout: layout to convert; returned unchanged when it is already a
        key of ``cutlass.LayoutTag``

    :returns: ``cutlass.LayoutType.RowMajor`` or ``cutlass.LayoutType.ColumnMajor``
        for the matching Python-bound CUTLASS layout

    :raises Exception: if ``layout`` is neither a key of ``cutlass.LayoutTag``
        nor one of the two handled Python-bound layouts
    """
    if layout in cutlass.LayoutTag.keys():
        return layout
    # Convert Python-bound CUTLASS layout to profiler library layout
    if layout == cutlass_bindings.RowMajor:
        return cutlass.LayoutType.RowMajor
    if layout == cutlass_bindings.ColumnMajor:
        return cutlass.LayoutType.ColumnMajor
    raise Exception(f"No conversion available for layout {layout} to library layout.")
def binding_type(inp):
    """
    Converts ``inp`` to a Python-bound CUTLASS data type

    :param inp: data type to convert; returned unchanged when it is already a
        key of ``DataTypeSize`` (as imported from ``cutlass.backend.library``)

    :returns: result of ``library_to_binding`` applied to ``library_type(inp)``
    """
    if inp in DataTypeSize.keys():
        return inp
    return library_to_binding(library_type(inp))
def binding_layout(layout):
    """
    Converts ``layout`` to a Python-bound CUTLASS layout

    :param layout: layout to convert; returned unchanged when it is already a
        key of ``ShortLayoutTypeNames``

    :returns: ``cutlass_bindings.RowMajor`` or ``cutlass_bindings.ColumnMajor``
        for the matching library layout

    :raises Exception: if ``layout`` is neither a key of ``ShortLayoutTypeNames``
        nor one of the two handled library layouts
    """
    if layout in ShortLayoutTypeNames.keys():
        return layout
    if layout == cutlass.LayoutType.RowMajor:
        return cutlass_bindings.RowMajor
    if layout == cutlass.LayoutType.ColumnMajor:
        return cutlass_bindings.ColumnMajor
    raise Exception(f"No conversion available for layout {layout} to Python-bound CUTLASS layout.")
def _tensor_from_numpy(np_tensor):
    """
    Determines the library data type and layout of an array

    ``get_datatype_and_layout`` calls this for both NumPy and CuPy arrays.

    :param np_tensor: array exposing ``dtype`` and ``flags``

    :returns: tuple of (library data type, library layout); the layout is
        row-major when the array is C-contiguous, otherwise column-major when
        it is F-contiguous

    :raises Exception: if the array is neither C- nor F-contiguous, or if its
        data type cannot be converted by ``library_type``
    """
    # Resolve the layout first so that an unsupported memory order is reported
    # explicitly instead of leaving ``layout`` unbound.
    flags = np_tensor.flags
    if flags.c_contiguous:
        layout = cutlass.LayoutType.RowMajor
    elif flags.f_contiguous:
        layout = cutlass.LayoutType.ColumnMajor
    else:
        raise Exception("Unable to determine layout: tensor is neither C-contiguous nor F-contiguous.")
    dtype = library_type(np_tensor.dtype)
    return (dtype, layout)
def _tensor_from_torch(pt_tensor):
    """
    Determines the library data type and layout of a PyTorch tensor

    :param pt_tensor: tensor exposing ``dtype``

    :returns: tuple of (library data type, ``cutlass.LayoutType.RowMajor``); the
        layout is always reported as row-major
    """
    return (library_type(pt_tensor.dtype), cutlass.LayoutType.RowMajor)
def get_datatype_and_layout(tensor):
    """
    Determines the library data type and layout of a frontend tensor

    :param tensor: NumPy array, CuPy array, or PyTorch tensor

    :returns: tuple of (library data type, library layout)

    :raises Exception: if ``tensor`` is not an instance of an array/tensor type
        from an available frontend
    """
    # Each isinstance check is guarded by the availability flag because the
    # module alias only exists when the corresponding import succeeded.
    if numpy_available and isinstance(tensor, np.ndarray):
        return _tensor_from_numpy(tensor)
    if cupy_available and isinstance(tensor, cp.ndarray):
        return _tensor_from_numpy(tensor)
    if torch_available and isinstance(tensor, torch.Tensor):
        return _tensor_from_torch(tensor)
    raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.")
def binding_opclass(opclass: cutlass.OpcodeClass):
    """
    Converts a library opcode class to the Python-bound CUTLASS opcode class

    :param opclass: library opcode class
    :type opclass: cutlass.OpcodeClass

    :returns: ``cutlass_bindings.OpClass.TensorOp`` or ``cutlass_bindings.OpClass.Simt``

    :raises Exception: if ``opclass`` is neither ``TensorOp`` nor ``Simt``
    """
    if opclass == cutlass.OpcodeClass.TensorOp:
        return cutlass_bindings.OpClass.TensorOp
    if opclass == cutlass.OpcodeClass.Simt:
        return cutlass_bindings.OpClass.Simt
    raise Exception(f"Unable to convert opcode class of type {opclass} to Python-bound CUTLASS opcode class.")
# Backend MathOperation members indexed by their ``value``
_math_operation_value_map = {member.value: member for member in MathOperation}
def backend_math_operation(math_op: cutlass.MathOperation):
    """
    Converts a library math operation to the backend math operation with the same value

    :param math_op: library math operation
    :type math_op: cutlass.MathOperation

    :returns: backend ``MathOperation`` member whose ``value`` equals ``math_op.value``

    :raises Exception: if no backend member has that value
    """
    if math_op.value in _math_operation_value_map:
        return _math_operation_value_map[math_op.value]
    raise Exception(f"Unable to convert math operation of type {math_op} to backend math operation.")
def construct_backend_td(td: cutlass.TileDescription,
                         kernel_schedule: cutlass.KernelScheduleType) -> TileDescription:
    """
    Builds a backend TileDescription from a library TileDescription

    :param td: library TileDescription
    :type td: cutlass.TileDescription
    :param kernel_schedule: kernel schedule passed through to the backend TileDescription
    :type kernel_schedule: cutlass.KernelScheduleType

    :returns: backend TileDescription
    :rtype: cutlass.backend.TileDescription
    """
    instruction = td.math_instruction
    # Translate each field of the math instruction to its backend equivalent,
    # in the order the backend MathInstruction constructor takes them.
    instruction_shape = instruction.instruction_shape
    element_a = binding_type(instruction.element_a)
    element_b = binding_type(instruction.element_b)
    element_accumulator = binding_type(instruction.element_accumulator)
    opcode_class = binding_opclass(instruction.opcode_class)
    math_operation = backend_math_operation(instruction.math_operation)
    backend_instruction = MathInstruction(
        instruction_shape,
        element_a,
        element_b,
        element_accumulator,
        opcode_class,
        math_operation,
    )
    return TileDescription(
        td.threadblock_shape,
        td.stages,
        td.warp_count,
        backend_instruction,
        td.cluster_shape,
        kernel_schedule,
    )
def td_from_profiler_op(op) -> TileDescription:
    """
    Converts the profiler's TileDescription in ``op`` into the backend TileDescription

    The kernel schedule is taken from ``op.kernel_schedule`` when ``op`` has that
    attribute and is ``None`` otherwise.

    :param op: profiler Operation

    :returns: backend TileDescription
    :rtype: cutlass.backend.TileDescription
    """
    kernel_schedule = getattr(op, 'kernel_schedule', None)
    return construct_backend_td(op.tile_description, kernel_schedule)
def td_from_profiler_td(td: cutlass.TileDescription) -> TileDescription:
    """
    Converts the profiler's TileDescription into the backend TileDescription

    The resulting backend TileDescription is built with no kernel schedule
    (``kernel_schedule=None``).

    :param td: profiler TileDescription
    :type td: cutlass.TileDescription

    :returns: backend TileDescription
    :rtype: cutlass.backend.TileDescription
    """
    # The input is the profiler-side description (the type construct_backend_td
    # accepts); only the return value is a backend TileDescription.
    return construct_backend_td(td, kernel_schedule=None)