46 namespace threadblock {
    59   typename SmemIteratorA_,
    65   typename SmemIteratorB_,
    73   typename TransformA_ = NumericArrayConverter<
    74     typename SmemIteratorA_::Element, 
    75     typename IteratorA_::Element, 
    76     IteratorA_::Fragment::kElements>,
    79   typename TransformB_ = NumericArrayConverter<
    80     typename SmemIteratorB_::Element, 
    81     typename IteratorB_::Element, 
    82     IteratorB_::Fragment::kElements>,
    84   typename Enable = 
bool   126   using WarpFragmentA = 
typename Operator::FragmentA;
   127   using WarpFragmentB = 
typename Operator::FragmentB;
   142     typename Base::SharedStorage &shared_storage,       
   147     Base(shared_storage, thread_idx, warp_idx, lane_idx),
   148     smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
   149     smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {
   171     int gemm_k_iterations,                            
   193     iterator_A.load(tb_frag_A);
   194     iterator_B.load(tb_frag_B);
   199     this->smem_iterator_A_.store(transform_A(tb_frag_A));
   200     this->smem_iterator_B_.store(transform_B(tb_frag_B));
   208     WarpFragmentA warp_frag_A[2];
   209     WarpFragmentB warp_frag_B[2];
   222     int smem_write_stage_idx = 1;
   225     if (gemm_k_iterations <= 1) {
   226       iterator_A.clear_mask();
   227       iterator_B.clear_mask();
   239     for (; gemm_k_iterations > 0; --gemm_k_iterations) {
   250         if (warp_mma_k == Base::kWarpGemmIterations - 1) {
   253           this->smem_iterator_A_.store(transform_A(tb_frag_A));
   255           this->smem_iterator_B_.store(transform_B(tb_frag_B));
   263           if (smem_write_stage_idx == 1) {
   269                 {0, -
Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
   275           smem_write_stage_idx ^= 1;
   287         if (warp_mma_k == 0) {
   289           iterator_A.load(tb_frag_A);
   290           iterator_B.load(tb_frag_B);
   296           if (gemm_k_iterations <= 2) {
   297             iterator_A.clear_mask();
   298             iterator_B.clear_mask();
   302         warp_mma(accum, warp_frag_A[warp_mma_k % 2], warp_frag_B[warp_mma_k % 2], accum);
 static int const kM
Definition: include/cutlass/gemm/gemm.h:58
LayoutC_ LayoutC
Layout of accumulator matrix. 
Definition: mma_pipelined.h:96
TransformB_ TransformB
Definition: mma_pipelined.h:103
Definition: aligned_buffer.h:35
Policy_ Policy
Policy describing tuning details. 
Definition: mma_pipelined.h:97
Operator::IteratorB warp_tile_iterator_B_
Iterator to load a warp-scoped tile of B operand from shared memory. 
Definition: mma_base.h:193
Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. 
Definition: mma_pipelined.h:86
IteratorB_ IteratorB
Iterates over tiles of B operand in global memory. 
Definition: mma_pipelined.h:94
Defines common types used for all GEMM-like operators. 
CUTLASS_DEVICE void operator()(int gemm_k_iterations, FragmentC &accum, IteratorA iterator_A, IteratorB iterator_B, FragmentC const &src_accum, TransformA transform_A=TransformA(), TransformB transform_B=TransformB())
Perform a threadblock-scoped matrix multiply-accumulate. 
Definition: mma_pipelined.h:170
IteratorA_ IteratorA
Iterates over tiles of A operand in global memory. 
Definition: mma_pipelined.h:93
typename IteratorB::Fragment FragmentB
Fragment of operand B loaded from global memory. 
Definition: mma_pipelined.h:113
SmemIteratorA_ SmemIteratorA
Definition: mma_pipelined.h:99
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Boost-like numeric conversion operator for CUTLASS numeric types. 
Defines a Shape template for matrix tiles. 
static int const kWarpGemmIterations
Number of warp-level GEMM operations. 
Definition: mma_base.h:108
Template for a double-buffered threadblock-scoped GEMM kernel. 
AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared me...
Shape_ Shape
Size of the Gemm problem - concept: gemm::GemmShape<> 
Definition: mma_pipelined.h:92
static int const kStages
Number of stages. 
Definition: mma_base.h:112
typename IteratorA::Fragment FragmentA
Fragment of operand A loaded from global memory. 
Definition: mma_pipelined.h:110
Top-level include for all CUTLASS numeric types. 
Definition: mma_base.h:83
typename Policy::Operator::FragmentC FragmentC
Fragment of accumulator tile. 
Definition: mma_pipelined.h:116
Operator::IteratorA warp_tile_iterator_A_
Iterator to load a warp-scoped tile of A operand from shared memory. 
Definition: mma_base.h:190
#define CUTLASS_GEMM_LOOP
Definition: cutlass.h:112
ElementC_ ElementC
Data type of accumulator matrix. 
Definition: mma_pipelined.h:95
SmemIteratorA smem_iterator_A_
Iterator to write threadblock-scoped tile of A operand to shared memory. 
Definition: mma_pipelined.h:132
SmemIteratorB_ SmemIteratorB
Definition: mma_pipelined.h:100
SmemIteratorB smem_iterator_B_
Iterator to write threadblock-scoped tile of B operand to shared memory. 
Definition: mma_pipelined.h:135
CUTLASS_DEVICE MmaPipelined(typename Base::SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)
Construct from tensor references. 
Definition: mma_pipelined.h:141
Basic include for CUTLASS. 
TransformA_ TransformA
Definition: mma_pipelined.h:102
typename Policy::Operator Operator
Warp-level Mma. 
Definition: mma_pipelined.h:119
static int const kN
Definition: include/cutlass/gemm/gemm.h:59