#pragma once

#include <cuda_fp16.h>

#include "cutlass/arch/mma.h"
#include "cutlass/layout/matrix.h"

namespace cutlass {
namespace arch {

/// Matrix multiply-add operation: F16 = F16 * F16 + F16
template <typename LayoutA, typename LayoutB, typename LayoutC>
struct Mma<
  gemm::GemmShape<2, 1, 1>,
  1,
  half_t, LayoutA,
  half_t, LayoutB,
  half_t, LayoutC,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<2, 1, 1>;
  using Operator = OpMultiplyAdd;
  using ElementC = half_t;

  CUTLASS_HOST_DEVICE
  void operator()(
    Array<half_t, 2> &d,
    Array<half_t, 2> const &a,
    Array<half_t, 1> const &b,
    Array<half_t, 2> const &c
  ) {

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))

    // Vectorized path: broadcast the single B element and issue one HFMA2
    // that produces both output elements at once.
    __half2 const & A = reinterpret_cast<__half2 const &>(a);
    __half2 B = __half2half2(reinterpret_cast<__half const &>(b));
    __half2 const & C = reinterpret_cast<__half2 const &>(c);

    __half2 D = __hfma2(A, B, C);

    d = reinterpret_cast<Array<half_t, 2> &>(D);

#else
    // Scalar fallback for architectures without native FP16 arithmetic.
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 2; ++i) {
      d[i] = a[i] * b[0] + c[i];
    }
#endif
  }
};
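// Worked example (added for illustration; not part of the original header): with
// a = {2, 3}, b = {4}, and c = {10, 20}, the SM60 path forms A = {2, 3},
// B = __half2half2(4) = {4, 4}, C = {10, 20}, and the single __hfma2 yields
// D = A * B + C = {18, 32}, which is exactly what the scalar fallback
// d[i] = a[i] * b[0] + c[i] computes element by element.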
/// Matrix multiply-add operation: F16 = F16 * F16 + F16
template <typename LayoutA, typename LayoutB>
struct Mma<
  gemm::GemmShape<1, 2, 1>,
  1,
  half_t, LayoutA,
  half_t, LayoutB,
  half_t, layout::RowMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<1, 2, 1>;
  using Operator = OpMultiplyAdd;
  using ElementC = half_t;

  CUTLASS_HOST_DEVICE
  void operator()(
    Array<half_t, 2> &d,
    Array<half_t, 1> const &a,
    Array<half_t, 2> const &b,
    Array<half_t, 2> const &c
  ) {

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))

    // Vectorized path: broadcast the single A element and issue one HFMA2.
    __half2 const & A = __half2half2(reinterpret_cast<__half const &>(a));
    __half2 B = reinterpret_cast<__half2 const &>(b);
    __half2 const & C = reinterpret_cast<__half2 const &>(c);

    __half2 D = __hfma2(A, B, C);

    d = reinterpret_cast<Array<half_t, 2> &>(D);

#else
    // Scalar fallback for architectures without native FP16 arithmetic.
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 2; ++i) {
      d[i] = a[0] * b[i] + c[i];
    }
#endif
  }
};
/// Matrix multiply-add operation: F16 = F16 * F16 + F16 (column-major accumulator)
template <>
struct Mma<
  gemm::GemmShape<2, 2, 1>,
  1,
  half_t, layout::ColumnMajor,
  half_t, layout::RowMajor,
  half_t, layout::ColumnMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<2, 2, 1>;
  using Operator = OpMultiplyAdd;
  using ElementC = half_t;

  CUTLASS_HOST_DEVICE
  void operator()(
    Array<half_t, 4> &d,
    Array<half_t, 2> const &a,
    Array<half_t, 2> const &b,
    Array<half_t, 4> const &c
  ) {

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))

    // 2x2 outer product: broadcast each B element and issue two HFMA2s,
    // one per column of the column-major accumulator.
    __half2 const & A = reinterpret_cast<__half2 const &>(a);
    __half2 Blo = __low2half2(reinterpret_cast<__half2 const &>(b));
    __half2 Bhi = __high2half2(reinterpret_cast<__half2 const &>(b));

    __half2 const *C = reinterpret_cast<__half2 const *>(&c);

    __half2 Dlo = __hfma2(A, Blo, C[0]);
    __half2 Dhi = __hfma2(A, Bhi, C[1]);

    Array<half_t, 2> *D = reinterpret_cast<Array<half_t, 2> *>(&d);

    D[0] = reinterpret_cast<Array<half_t, 2> const &>(Dlo);
    D[1] = reinterpret_cast<Array<half_t, 2> const &>(Dhi);

#else
    // Scalar fallback: d is column-major, so element (i, j) is d[i + 2 * j].
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < 2; ++j) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < 2; ++i) {
        d[i + 2 * j] = a[i] * b[j] + c[i + 2 * j];
      }
    }
#endif
  }
};
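// Note (added for illustration; not part of the original header): the two __half2
// halves of the result pack different slices of the 2x2 output depending on the
// accumulator layout. In the column-major specialization above,
//   D[0] = { d(0,0), d(1,0) }   // column 0 = A * b[0] + C[0]
//   D[1] = { d(0,1), d(1,1) }   // column 1 = A * b[1] + C[1]
// whereas the row-major specialization below packs rows instead:
//   D[0] = { d(0,0), d(0,1) }   // row 0 = a[0] * B + C[0]
//   D[1] = { d(1,0), d(1,1) }   // row 1 = a[1] * B + C[1]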
/// Matrix multiply-add operation: F16 = F16 * F16 + F16 (row-major accumulator)
template <>
struct Mma<
  gemm::GemmShape<2, 2, 1>,
  1,
  half_t, layout::ColumnMajor,
  half_t, layout::RowMajor,
  half_t, layout::RowMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<2, 2, 1>;
  using Operator = OpMultiplyAdd;
  using ElementC = half_t;

  CUTLASS_HOST_DEVICE
  void operator()(
    Array<half_t, 4> &d,
    Array<half_t, 2> const &a,
    Array<half_t, 2> const &b,
    Array<half_t, 4> const &c
  ) {

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))

    // 2x2 outer product: broadcast each A element and issue two HFMA2s,
    // one per row of the row-major accumulator.
    __half2 Alo = __low2half2(reinterpret_cast<__half2 const &>(a));
    __half2 Ahi = __high2half2(reinterpret_cast<__half2 const &>(a));
    __half2 const & B = reinterpret_cast<__half2 const &>(b);

    __half2 const *C = reinterpret_cast<__half2 const *>(&c);

    __half2 Dlo = __hfma2(Alo, B, C[0]);
    __half2 Dhi = __hfma2(Ahi, B, C[1]);  // second row accumulates C[1], matching the scalar path below

    Array<half_t, 2> *D = reinterpret_cast<Array<half_t, 2> *>(&d);

    D[0] = reinterpret_cast<Array<half_t, 2> &>(Dlo);
    D[1] = reinterpret_cast<Array<half_t, 2> &>(Dhi);

#else
    // Scalar fallback: d is row-major, so element (i, j) is d[i * 2 + j].
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 2; ++i) {
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < 2; ++j) {
        d[i * 2 + j] = a[i] * b[j] + c[i * 2 + j];
      }
    }
#endif
  }
};

} // namespace arch
} // namespace cutlass
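The sketch below shows one way these specializations can be invoked on register fragments. It is a minimal illustration written against the public CUTLASS headers, assuming the CUTLASS include tree is on the include path; the kernel name, pointer arguments, and launch shape are placeholders, not part of this file.

#include "cutlass/half.h"
#include "cutlass/array.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/arch/mma.h"   // primary Mma template, operator tags, per-architecture specializations

// Computes the 2x2 outer product d = a * b + 0 for one 2-element column of A
// and one 2-element row of B, writing the row-major 2x2 result to D.
__global__ void mma_2x2_demo(
    cutlass::half_t const *A,
    cutlass::half_t const *B,
    cutlass::half_t *D) {

  using Mma = cutlass::arch::Mma<
      cutlass::gemm::GemmShape<2, 2, 1>, 1,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::arch::OpMultiplyAdd>;

  cutlass::Array<cutlass::half_t, 2> a, b;
  cutlass::Array<cutlass::half_t, 4> c, d;

  // Load the operand fragments into registers.
  a[0] = A[0]; a[1] = A[1];
  b[0] = B[0]; b[1] = B[1];

  // Start from a zero accumulator.
  for (int i = 0; i < 4; ++i) {
    c[i] = cutlass::half_t(0);
  }

  Mma mma;
  mma(d, a, b, c);   // d(i, j) = a[i] * b[j] + c(i, j)

  for (int i = 0; i < 4; ++i) {
    D[i] = d[i];
  }
}

On SM60 and newer this lowers to the paired HFMA2 path shown above; on older architectures the same call falls back to the unrolled scalar loop.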