76   bool AccumulatorsInRowMajor = 
false,
    80   typename Enable = 
bool   143      typename Policy::Operator::Shape, 
typename Policy::OpDelta>;
   151     !(Shape::kM % Policy::Operator::Shape::kM) && 
   152     !(Shape::kN % Policy::Operator::Shape::kN),
   153     "Shape of warp-level Mma must be divisible by operator shape.");
   157     Shape::kM / Policy::Operator::Shape::kM,
   158     (Shape::kN / Policy::Operator::Shape::kN / kPartitionsN > 0) ?
   159      Shape::kN / Policy::Operator::Shape::kN / kPartitionsN :
   166   typename Policy::Operator 
mma;
   185     int const &partitionN_idx = 0)
 const {
   187     using MmaOperandA = 
typename Policy::Operator::FragmentA;
   188     using MmaOperandB = 
typename Policy::Operator::FragmentB;
   189     using MmaOperandC = 
typename Policy::Operator::FragmentC;
   193     MmaOperandA 
const *ptr_A = 
reinterpret_cast<MmaOperandA 
const *
>(&A);
   194     MmaOperandB 
const *ptr_B = 
reinterpret_cast<MmaOperandB 
const *
>(&B);
   195     MmaOperandC *ptr_D = 
reinterpret_cast<MmaOperandC *
>(&D);
   198       const int n_off = partitionN_idx * FragmentB::kElements / MmaOperandB::kElements / 
kPartitionsN;
   201       for (
int n = 0; n < MmaIterations::kColumn; ++n) {
   204         for (
int m = 0; m < MmaIterations::kRow; ++m) {
   206           int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m);
   208           if (AccumulatorsInRowMajor) {  
   210               ptr_D[n + m_serpentine * MmaIterations::kColumn],
   213               ptr_D[n + m_serpentine * MmaIterations::kColumn]);
   216               ptr_D[m_serpentine + (n + n_off) * MmaIterations::kRow],
   219               ptr_D[m_serpentine + (n + n_off) * MmaIterations::kRow]);
 Describes the size of a matrix tile. 
Definition: matrix_shape.h:42
typename IteratorA::Fragment FragmentA
Storage for A tile. 
Definition: mma_tensor_op.h:129
Definition: aligned_buffer.h:35
LayoutB_ LayoutB
Layout of multiplicand B. 
Definition: mma_tensor_op.h:97
CUTLASS_DEVICE MmaTensorOp()
Ctor. 
Definition: mma_tensor_op.h:176
Architecture-specific operators on memory added for SM75. 
Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. 
Defines common types used for all GEMM-like operators. 
static int const kThreadCount
Number of threads participating in warp-level matrix product. 
Definition: mma_tensor_op.h:112
static int const kPartitionsN
Number of partitions along the N dimension of multiplicand B. 
Definition: mma_tensor_op.h:118
Structure to compute the warp-level matrix product targeting Tensor Cores. 
Definition: mma_tensor_op.h:82
LayoutA_ LayoutA
Layout of multiplicand A. 
Definition: mma_tensor_op.h:91
typename IteratorB::Fragment FragmentB
Storage for B tile. 
Definition: mma_tensor_op.h:138
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Templates exposing architecture support for warp-level multiply-add operations. 
Defines a Shape template for matrix tiles. 
typename IteratorC::Fragment FragmentC
Storage for C tile. 
Definition: mma_tensor_op.h:146
Definition: mma_tensor_op_tile_iterator.h:1794
CUTLASS_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C, int const &partitionN_idx=0) const 
Performs a warp-level matrix multiply-accumulate operation. 
Definition: mma_tensor_op.h:180
Policy::Operator mma
Underlying matrix multiply operator (concept: arch::Mma) 
Definition: mma_tensor_op.h:166
ElementC_ ElementC
Data type of accumulator matrix C. 
Definition: mma_tensor_op.h:100
Top-level include for all CUTLASS numeric types. 
Definition: mma_tensor_op_tile_iterator.h:75
LayoutC_ LayoutC
Layout of accumulator matrix C. 
Definition: mma_tensor_op.h:103
Policy_ Policy
Policy describing implementation details of the warp-level Tensor Core operation (provides Operator and OpDelta) 
Definition: mma_tensor_op.h:106
Shape_ Shape
Shape of warp-level matrix operation (concept: GemmShape) 
Definition: mma_tensor_op.h:85
ElementB_ ElementB
Data type of multiplicand B. 
Definition: mma_tensor_op.h:94
static int const kPartitionsK
Number of partitions along K dimension. 
Definition: mma_tensor_op.h:115
arch::OpClassTensorOp OperatorClass
Indicates class of matrix operator. 
Definition: mma_tensor_op.h:109
ElementA_ ElementA
Data type of multiplicand A. 
Definition: mma_tensor_op.h:88
Matrix multiply for SM75. 
Basic include for CUTLASS. 
Policy describing implementation details of warp-level GEMM targeting Tensor Cores.