47   typename ThreadblockSwizzle_    
    67     typename Mma::IteratorA::TensorRef 
ref_A;
    69     typename Mma::IteratorB::TensorRef 
ref_B;
    70     typename Epilogue::OutputTileIterator::Params 
params_D;
    71     typename Epilogue::OutputTileIterator::TensorRef 
ref_D;
    87       typename Mma::IteratorA::TensorRef ref_A,
    88       typename Mma::IteratorB::TensorRef ref_B,
    89       typename Epilogue::OutputTileIterator::TensorRef ref_D,
    90       typename OutputOp::Params output_op,
    91       int64_t splitk_slice_stride
    93       problem_size(problem_size),
    94       grid_tiled_shape(grid_tiled_shape),
    95       params_A(ref_A.layout()),
    97       params_B(ref_B.layout()),
    99       params_D(ref_D.layout()),
   101       output_op(output_op),
   102       splitk_slice_stride(splitk_slice_stride) {
   104       int full_gemm_k_iterations = problem_size.
k() / Mma::Shape::kK;
   105       int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.
k();
   107       gemm_k_size = gemm_k_iterations * Mma::Shape::kK;
   142       threadblock_tile_offset.
m() * Mma::Shape::kM,
   148       threadblock_tile_offset.
n() * Mma::Shape::kN
   157       problem_size_k = (threadblock_tile_offset.
k() + 1) * params.
gemm_k_size;
   161     int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
   164     int thread_idx = threadIdx.x;
   167     typename Mma::IteratorA iterator_A(
   174     typename Mma::IteratorB iterator_B(
   181     int warp_idx = threadIdx.x / 32;
   182     int lane_idx = threadIdx.x % 32;
   190     Mma mma(shared_storage.
main_loop, thread_idx, warp_idx, lane_idx);
   192     typename Mma::FragmentC accumulators;
   194     accumulators.clear();
   196     mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
   208     threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
   212       threadblock_tile_offset.
m() * Mma::Shape::kM,
   213       threadblock_tile_offset.
n() * Mma::Shape::kN
   217     typename Epilogue::OutputTileIterator iterator_D(
   235     epilogue(output_op, iterator_D, accumulators, iterator_D);
 CUTLASS_DEVICE void operator()(Params const &params, SharedStorage &shared_storage)
Executes one GEMM. 
Definition: kernel/gemm_splitk_parallel.h:126
CUTLASS_HOST_DEVICE GemmSplitKParallel()
Definition: kernel/gemm_splitk_parallel.h:122
Definition: aligned_buffer.h:35
Epilogue_ Epilogue
Definition: kernel/gemm_splitk_parallel.h:52
cutlass::gemm::GemmCoord problem_size
Definition: kernel/gemm_splitk_parallel.h:64
Shared memory storage structure. 
Definition: kernel/gemm_splitk_parallel.h:112
Epilogue::SharedStorage epilogue
Definition: kernel/gemm_splitk_parallel.h:114
Definition: include/cutlass/gemm/gemm.h:94
CUTLASS_HOST_DEVICE Coord< 2 > mn() const 
Obtains a Coord<2> from GemmCoord. 
Definition: include/cutlass/gemm/gemm.h:171
cutlass::gemm::GemmCoord grid_tiled_shape
Definition: kernel/gemm_splitk_parallel.h:65
static int const kThreadCount
Definition: kernel/gemm_splitk_parallel.h:58
Mma::SharedStorage main_loop
Definition: kernel/gemm_splitk_parallel.h:113
Defines common types used for all GEMM-like operators. 
CUTLASS_HOST_DEVICE Index const & n() const 
Returns the GEMM N coordinate. 
Definition: include/cutlass/gemm/gemm.h:137
Parameters structure. 
Definition: kernel/gemm_splitk_parallel.h:63
typename Mma::WarpCount WarpCount
Warp count (concept: GemmShape) 
Definition: kernel/gemm_splitk_parallel.h:57
ThreadblockSwizzle_ ThreadblockSwizzle
Definition: kernel/gemm_splitk_parallel.h:54
CUTLASS_HOST_DEVICE Params(cutlass::gemm::GemmCoord const &problem_size, cutlass::gemm::GemmCoord const &grid_tiled_shape, typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, typename Epilogue::OutputTileIterator::TensorRef ref_D, typename OutputOp::Params output_op, int64_t splitk_slice_stride)
Definition: kernel/gemm_splitk_parallel.h:84
OutputOp::Params output_op
Definition: kernel/gemm_splitk_parallel.h:72
CUTLASS_HOST_DEVICE Index const & k() const 
Returns the GEMM K coordinate. 
Definition: include/cutlass/gemm/gemm.h:145
Mma::IteratorA::TensorRef ref_A
Definition: kernel/gemm_splitk_parallel.h:67
Mma::IteratorB::TensorRef ref_B
Definition: kernel/gemm_splitk_parallel.h:69
int gemm_k_size
Definition: kernel/gemm_splitk_parallel.h:74
Epilogue::OutputTileIterator::TensorRef ref_D
Definition: kernel/gemm_splitk_parallel.h:71
CUTLASS_HOST_DEVICE Params()
Definition: kernel/gemm_splitk_parallel.h:81
Epilogue::OutputTileIterator::Params params_D
Definition: kernel/gemm_splitk_parallel.h:70
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Mma::IteratorA::Params params_A
Definition: kernel/gemm_splitk_parallel.h:66
static int const kAlignmentK
Definition: kernel/gemm_splitk_parallel.h:60
Defines a canonical coordinate for rank=2 matrices offering named indices. 
Definition: kernel/gemm_splitk_parallel.h:49
CUTLASS_HOST_DEVICE Index const & m() const 
Returns the GEMM M coordinate. 
Definition: include/cutlass/gemm/gemm.h:129
Mma_ Mma
Definition: kernel/gemm_splitk_parallel.h:51
Mma::IteratorB::Params params_B
Definition: kernel/gemm_splitk_parallel.h:68
int64_t splitk_slice_stride
Definition: kernel/gemm_splitk_parallel.h:73
Basic include for CUTLASS. 
Definition: matrix_coord.h:39
typename Epilogue::OutputOp OutputOp
Definition: kernel/gemm_splitk_parallel.h:53