42 namespace threadblock {
    51     typename SmemPaddingA_,
    53     typename SmemPaddingB_,
    82     typename Enable = 
bool>
   104                               Shape::kN / WarpGemm::kN,
   105                               Shape::kK / WarpGemm::kK>;
   108   static int const kWarpGemmIterations =
   109       (WarpGemm::kK / Operator::Policy::MmaShape::kK);
   112   static int const kStages = Stages;
   133                                Shape::kK * kStages +
   134                                    Policy::SmemPaddingA::kColumn>;
   138         MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
   139                     Shape::kN + Policy::SmemPaddingB::kColumn>;
   161       return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
   167       return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
   201       SharedStorage &shared_storage,
   209       warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
   210       warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {
 Policy_ Policy
Definition: mma_base.h:89
Describes the size of a matrix tile. 
Definition: matrix_shape.h:42
Definition: aligned_buffer.h:35
Architecture-specific operators on memory. 
AlignedBuffer< typename Operator::ElementB, ShapeB::kCount > operand_B
Buffer for B operand. 
Definition: mma_base.h:150
Operator::IteratorB warp_tile_iterator_B_
Iterator to load a warp-scoped tile of B operand from shared memory. 
Definition: mma_base.h:193
typename Policy::Operator::Shape WarpGemm
Definition: mma_base.h:100
Defines common types used for all GEMM-like operators. 
Shared storage object needed by threadblock-scoped GEMM. 
Definition: mma_base.h:125
Shape_ Shape
Policy describing tuning details. 
Definition: mma_base.h:88
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe to use in a union.
Operator_ Operator
Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt) ...
Definition: mma_base.h:58
SmemPaddingA_ SmemPaddingA
Padding used for A operand in shared memory. 
Definition: mma_base.h:61
Defines a Shape template for matrix tiles. 
static CUTLASS_HOST_DEVICE Operator::LayoutB LayoutB()
Returns a layout object for the B matrix. 
Definition: mma_base.h:166
Policy object describing MmaTensorOp. 
Definition: mma_base.h:56
Definition: tensor_ref.h:146
AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared memory.
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types. 
Modifies semantics of cutlass::Array<> to provide guaranteed alignment. 
Definition: aligned_buffer.h:45
CUTLASS_HOST_DEVICE TensorRefA operand_A_ref()
Returns a TensorRef to the A operand. 
Definition: mma_base.h:172
Shape of a matrix multiply-add operation. 
Definition: include/cutlass/gemm/gemm.h:57
CUTLASS_HOST_DEVICE pointer data()
Definition: aligned_buffer.h:84
CUTLASS_DEVICE MmaBase(SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)
Construct from tensor references. 
Definition: mma_base.h:199
Definition: mma_base.h:83
typename Policy::Operator Operator
Warp-level Mma. 
Definition: mma_base.h:96
Operator::IteratorA warp_tile_iterator_A_
Iterator to load a warp-scoped tile of A operand from shared memory. 
Definition: mma_base.h:190
AlignedBuffer< typename Operator::ElementA, ShapeA::kCount > operand_A
Buffer for A operand. 
Definition: mma_base.h:147
static CUTLASS_DEVICE Operator::LayoutA LayoutA()
Returns a layout object for the A matrix. 
Definition: mma_base.h:160
SmemPaddingB_ SmemPaddingB
Padding used for B operand in shared memory. 
Definition: mma_base.h:64
static int const kPartitionsK
Number of partitions of K dimension. 
Definition: mma_base.h:67
CUTLASS_HOST_DEVICE TensorRefB operand_B_ref()
Returns a TensorRef to the B operand. 
Definition: mma_base.h:178
Basic include for CUTLASS.