72     typename ElementAccumulator_ = ElementC_,
    74     typename OperatorClass_ = arch::OpClassSimt,
    76     typename ArchTag_ = arch::Sm70,
    78     typename ThreadblockShape_ = 
typename DefaultGemmConfiguration<
    79         OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
    80         ElementAccumulator_>::ThreadblockShape,
    82     typename WarpShape_ = 
typename DefaultGemmConfiguration<
    83         OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
    84         ElementAccumulator_>::WarpShape,
    86     typename InstructionShape_ = 
typename DefaultGemmConfiguration<
    87         OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
    88         ElementAccumulator_>::InstructionShape,
    90     typename EpilogueOutputOp_ = 
typename DefaultGemmConfiguration<
    91         OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
    92         ElementAccumulator_>::EpilogueOutputOp,
    96         DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
    98                                  ElementAccumulator_>::EpilogueOutputOp::kCount,
   102         ElementAccumulator_, 
typename EpilogueOutputOp_::ElementAccumulator,
   103         EpilogueOutputOp_::kCount>,
   105     typename ThreadblockSwizzle_ =
   106         threadblock::GemmSplitKHorizontalThreadblockSwizzle,
   109         DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
   110                                  ElementC_, ElementAccumulator_>::kStages,
   113         DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
   114                                  ElementC_, ElementAccumulator_>::kAlignmentA,
   117         DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
   118                                  ElementC_, ElementAccumulator_>::kAlignmentB,
   120     typename Operator_ = 
typename DefaultGemmConfiguration<
   121         OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
   122         ElementAccumulator_>::Operator>
   211       typename EpilogueOutputOp::Params epilogue_ = 
   212         typename EpilogueOutputOp::Params(),
   213       int split_k_slices = 1,
   214       typename ConvertScaledOp::Params convert_ = 
   215         typename ConvertScaledOp::Params(),
   216       typename ReductionOp::Params reduction_ =
   217         typename ReductionOp::Params()
   219       problem_size(problem_size_),
   225       split_k_slices(split_k_slices),
   227       reduction(reduction_) { }
   233   typename GemmKernel::Params gemm_params_;
   255     ThreadblockSwizzle threadblock_swizzle;
   259       {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
   260       args.split_k_slices);
   262     return sizeof(ElementAccumulator_) * 
size_t(args.problem_size.m()) * 
size_t(args.problem_size.n()) * grid_shape.
k();
   269     ThreadblockSwizzle threadblock_swizzle;
   273       {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
   274       args.split_k_slices);
   282       static_cast<ElementAccumulator_ *>(workspace), 
   283       args.problem_size.n());
   285     int64_t partition_stride = int64_t(args.problem_size.m()) * int64_t(args.problem_size.n());
   288     gemm_params_ = 
typename GemmKernel::Params{
   291       args.ref_A.non_const_ref(),
   292       args.ref_B.non_const_ref(),
   299       args.problem_size.mn(),
   318     gemm_params_.ref_A.reset(args.ref_A.data());
   319     gemm_params_.ref_B.reset(args.ref_B.data());
   320     gemm_params_.ref_D.reset(workspace);     
   322     reduction_params_.ref_D.reset(args.ref_D.data());
   323     reduction_params_.ref_C.reset(args.ref_C.data());
   335     ThreadblockSwizzle threadblock_swizzle;
   337     dim3 grid = threadblock_swizzle.get_grid_shape(gemm_params_.grid_tiled_shape);
   338     dim3 block(GemmKernel::kThreadCount, 1, 1);
   342     int smem_size = int(
sizeof(
typename GemmKernel::SharedStorage));
   343     if (smem_size >= (48 << 10)) {
   345       result = cudaFuncSetAttribute(
   347         cudaFuncAttributeMaxDynamicSharedMemorySize,
   350       if (result != cudaSuccess) {
   354       result = cudaFuncSetAttribute(
   356         cudaFuncAttributePreferredSharedMemoryCarveout, 100);
   358       if (result != cudaSuccess) {
   363     Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(gemm_params_);
   365     result = cudaGetLastError();
   367     if (result != cudaSuccess) {
   378     Kernel<ReductionKernel><<< grid, block, 0, stream >>>(reduction_params_);
   380     result = cudaGetLastError();
   382     if (result != cudaSuccess) {
   397     void *workspace = 
nullptr, 
   398     cudaStream_t stream = 
nullptr) {
   403       status = 
run(stream);
   425     typename ElementAccumulator_,
   427     typename OperatorClass_,
   431     typename ThreadblockShape_,
   435     typename InstructionShape_,
   437     typename EpilogueOutputOp_,
   439     typename ConvertScaledOp_,
   441     typename ReductionOp_,
   443     typename ThreadblockSwizzle_,
   445     int Stages, 
int kAlignmentA, 
int kAlignmentB,
   449                          layout::ColumnMajor, ElementAccumulator_,
   450                          OperatorClass_, ArchTag_, ThreadblockShape_,
   451                          WarpShape_, InstructionShape_, EpilogueOutputOp_,
   452                          ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_,
   453                          Stages, kAlignmentA, kAlignmentB, Operator_> {
   456   using ElementA = ElementA_;
   457   using LayoutA = LayoutA_;
   458   using ElementB = ElementB_;
   459   using LayoutB = LayoutB_;
   462   using ElementAccumulator = ElementAccumulator_;
   463   using OperatorClass = OperatorClass_;
   464   using ArchTag = ArchTag_;
   465   using ThreadblockShape = ThreadblockShape_;
   466   using WarpShape = WarpShape_;
   467   using InstructionShape = InstructionShape_;
   468   using ConvertScaledOp = ConvertScaledOp_;
   469   using EpilogueOutputOp = EpilogueOutputOp_;
   471   using ThreadblockSwizzle = ThreadblockSwizzle_;
   473   static int const kStages = Stages;
   535       typename EpilogueOutputOp::Params epilogue_ = 
   536         typename EpilogueOutputOp::Params(),
   537       int split_k_slices = 1,
   538       typename ConvertScaledOp::Params convert_ = 
   539         typename ConvertScaledOp::Params(),
   540       typename ReductionOp::Params reduction_ =
   541         typename ReductionOp::Params()
   543       problem_size(problem_size_),
   549       split_k_slices(split_k_slices),
   551       reduction(reduction_) { }
   567       {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()},
   568       {args.ref_B.data(), args.ref_B.stride(0)},
   569       {args.ref_A.data(), args.ref_A.stride(0)},
   570       {args.ref_C.data(), args.ref_C.stride(0)},
   571       {args.ref_D.data(), args.ref_D.stride(0)},
   582     return UnderlyingOperator::can_implement(to_underlying_arguments(args));
   588     return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
   594     return underlying_operator_.
initialize(to_underlying_arguments(args), workspace);
   600     return underlying_operator_.
update(to_underlying_arguments(args), workspace);
   606     return underlying_operator_.
run(stream);
   617     void *workspace = 
nullptr, 
   618     cudaStream_t stream = 
nullptr) {
   623       status = 
run(stream);
 Definition: conversion_op.h:53
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::WarpShape WarpShape WarpShape
Definition: device/gemm_splitk_parallel.h:136
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::GemmKernel typename UnderlyingOperator::GemmKernel GemmKernel
Definition: device/gemm_splitk_parallel.h:499
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::ref_D TensorRef< ElementC, LayoutC > ref_D
Definition: device/gemm_splitk_parallel.h:513
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::LayoutB typename layout::LayoutTranspose< LayoutA >::type LayoutB
Definition: device/gemm_splitk_parallel.h:129
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::ref_C TensorRef< ElementC const, LayoutC > ref_C
Definition: device/gemm_splitk_parallel.h:512
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::Operator Operator Operator
Definition: device/gemm_splitk_parallel.h:142
Describes the size of a matrix tile. 
Definition: matrix_shape.h:42
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::ReductionKernel typename UnderlyingOperator::ReductionKernel ReductionKernel
Definition: device/gemm_splitk_parallel.h:500
Definition: aligned_buffer.h:35
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::problem_size GemmCoord problem_size
Definition: device/gemm_splitk_parallel.h:509
Definition: default_gemm_splitk_parallel.h:88
Status operator()(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Runs the kernel using initialized state. 
Definition: device/gemm_splitk_parallel.h:395
static int const kStages
Definition: device/gemm_splitk_parallel.h:143
static CUTLASS_HOST_DEVICE dim3 block_shape()
Determines the threadblock shape. 
Definition: reduce_split_k.h:138
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::ElementC ElementC_ ElementC
Definition: device/gemm_splitk_parallel.h:460
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::ref_A TensorRef< ElementA const, LayoutA > ref_A
Definition: device/gemm_splitk_parallel.h:510
Kernel performing a reduction over densely packed tensors in global memory. 
Definition: include/cutlass/gemm/gemm.h:94
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::epilogue EpilogueOutputOp::Params epilogue
Definition: device/gemm_splitk_parallel.h:514
Functor performing conversion operations used by epilogues. 
int split_k_slices
Definition: device/gemm_splitk_parallel.h:191
ReductionOp::Params reduction
Definition: device/gemm_splitk_parallel.h:193
Mixed-precision reduction. 
Definition: reduction_operators.h:50
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::run Status run(cudaStream_t stream=nullptr)
Runs the kernel using initialized state. 
Definition: device/gemm_splitk_parallel.h:604
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::update Status update(Arguments const &args, void *workspace=nullptr)
Lightweight update given a subset of arguments. 
Definition: device/gemm_splitk_parallel.h:598
Params structure. 
Definition: reduce_split_k.h:80
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::InstructionShape InstructionShape InstructionShape
Definition: device/gemm_splitk_parallel.h:137
CUTLASS_HOST_DEVICE Arguments()
Default ctor. 
Definition: device/gemm_splitk_parallel.h:201
CUTLASS_HOST_DEVICE Index const & k() const 
Returns the GEMM K coordinate. 
Definition: include/cutlass/gemm/gemm.h:145
ConvertScaledOp::Params convert
Definition: device/gemm_splitk_parallel.h:192
Mapping function for column-major matrices. 
Definition: layout/matrix.h:142
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::GemmSplitKParallel GemmSplitKParallel()
Constructs the GEMM. 
Definition: device/gemm_splitk_parallel.h:562
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::ElementC ElementC ElementC
Definition: device/gemm_splitk_parallel.h:130
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::ThreadblockShape ThreadblockShape ThreadblockShape
Definition: device/gemm_splitk_parallel.h:135
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::ref_B TensorRef< ElementB const, LayoutB > ref_B
Definition: device/gemm_splitk_parallel.h:511
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::EpilogueOutputOp EpilogueOutputOp EpilogueOutputOp
Definition: device/gemm_splitk_parallel.h:139
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::get_workspace_size static size_t get_workspace_size(Arguments const &args)
Gets the workspace size. 
Definition: device/gemm_splitk_parallel.h:586
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::ArchTag ArchTag ArchTag
Definition: device/gemm_splitk_parallel.h:134
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::operator() Status operator()(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Runs the kernel using initialized state. 
Definition: device/gemm_splitk_parallel.h:615
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::operator() Status operator()(cudaStream_t stream=nullptr)
Runs the kernel using initialized state. 
Definition: device/gemm_splitk_parallel.h:610
static Status can_implement(Arguments const &args)
Determines whether the GEMM can execute the given problem. 
Definition: device/gemm_splitk_parallel.h:244
TensorRef< ElementC, LayoutC > ref_D
Definition: device/gemm_splitk_parallel.h:189
GemmSplitKParallel()
Constructs the GEMM. 
Definition: device/gemm_splitk_parallel.h:241
Defines transposes of matrix layouts. 
Definition: layout/matrix.h:921
GemmCoord problem_size
Definition: device/gemm_splitk_parallel.h:185
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::ReductionOp ReductionOp_ ReductionOp
Definition: device/gemm_splitk_parallel.h:470
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::GemmKernel typename kernel::DefaultGemmSplitKParallel< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, ConvertScaledOp, ThreadblockSwizzle, kStages, Operator >::GemmKernel GemmKernel
GEMM kernel. 
Definition: device/gemm_splitk_parallel.h:165
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::reduction ReductionOp::Params reduction
Definition: device/gemm_splitk_parallel.h:517
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::LayoutA typename layout::LayoutTranspose< LayoutB >::type LayoutA
Definition: device/gemm_splitk_parallel.h:127
Status operator()(cudaStream_t stream=nullptr)
Runs the kernel using initialized state. 
Definition: device/gemm_splitk_parallel.h:390
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::ReductionOp ReductionOp ReductionOp
Definition: device/gemm_splitk_parallel.h:140
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::ElementB ElementA ElementB
Definition: device/gemm_splitk_parallel.h:128
An error within CUTLASS occurred. 
Template for generic CUTLASS kernel. 
Kernel performing a reduction over densely packed tensors in global memory. 
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::UnderlyingArguments typename UnderlyingOperator::Arguments UnderlyingArguments
Definition: device/gemm_splitk_parallel.h:498
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Operator Operator_ Operator
Definition: device/gemm_splitk_parallel.h:472
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::convert ConvertScaledOp::Params convert
Definition: device/gemm_splitk_parallel.h:516
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types. 
Definition: reduce_split_k.h:55
static CUTLASS_HOST_DEVICE dim3 grid_shape(cutlass::MatrixCoord problem_size)
Computes the grid size given a chosen threadblock shape. 
Definition: reduce_split_k.h:128
Definitions for GEMM structures. 
CUTLASS_HOST_DEVICE NonConstTensorRef non_const_ref() const 
Definition: tensor_ref.h:229
static size_t get_workspace_size(Arguments const &args)
Gets the workspace size. 
Definition: device/gemm_splitk_parallel.h:252
Definition: device/gemm_splitk_parallel.h:123
Mapping function for row-major matrices. 
Definition: layout/matrix.h:50
TensorRef< ElementC const, LayoutC > ref_C
Definition: device/gemm_splitk_parallel.h:188
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::initialize Status initialize(Arguments const &args, void *workspace)
Initializes GEMM state from arguments. 
Definition: device/gemm_splitk_parallel.h:592
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::ConvertScaledOp ConvertScaledOp ConvertScaledOp
Definition: device/gemm_splitk_parallel.h:138
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::can_implement static Status can_implement(Arguments const &args)
Determines whether the GEMM can execute the given problem. 
Definition: device/gemm_splitk_parallel.h:580
CUTLASS_HOST_DEVICE Arguments(GemmCoord problem_size_, TensorRef< ElementA const, LayoutA > ref_A_, TensorRef< ElementB const, LayoutB > ref_B_, TensorRef< ElementC const, LayoutC > ref_C_, TensorRef< ElementC, LayoutC > ref_D_, typename EpilogueOutputOp::Params epilogue_=typename EpilogueOutputOp::Params(), int split_k_slices=1, typename ConvertScaledOp::Params convert_=typename ConvertScaledOp::Params(), typename ReductionOp::Params reduction_=typename ReductionOp::Params())
Constructs an Arguments structure. 
Definition: device/gemm_splitk_parallel.h:205
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::ElementAccumulator ElementAccumulator ElementAccumulator
Definition: device/gemm_splitk_parallel.h:132
The given workspace is null when it is required to be non-null. 
Operation was successful. 
TensorRef< ElementB const, LayoutB > ref_B
Definition: device/gemm_splitk_parallel.h:187
Implements several possible threadblock-swizzling functions mapping blockIdx to GEMM problems...
Defines tags for architecture-specific configurations. 
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::ThreadblockSwizzle ThreadblockSwizzle ThreadblockSwizzle
Definition: device/gemm_splitk_parallel.h:141
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::split_k_slices int split_k_slices
Definition: device/gemm_splitk_parallel.h:515
Status update(Arguments const &args, void *workspace=nullptr)
Lightweight update given a subset of arguments. 
Definition: device/gemm_splitk_parallel.h:312
Status run(cudaStream_t stream=nullptr)
Runs the kernel using initialized state. 
Definition: device/gemm_splitk_parallel.h:329
EpilogueOutputOp::Params epilogue
Definition: device/gemm_splitk_parallel.h:190
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with the appropr...
Argument structure. 
Definition: device/gemm_splitk_parallel.h:179
LayoutC_ LayoutC
Definition: device/gemm_splitk_parallel.h:131
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::Arguments CUTLASS_HOST_DEVICE Arguments(GemmCoord problem_size_, TensorRef< ElementA const, LayoutA > ref_A_, TensorRef< ElementB const, LayoutB > ref_B_, TensorRef< ElementC const, LayoutC > ref_C_, TensorRef< ElementC, LayoutC > ref_D_, typename EpilogueOutputOp::Params epilogue_=typename EpilogueOutputOp::Params(), int split_k_slices=1, typename ConvertScaledOp::Params convert_=typename ConvertScaledOp::Params(), typename ReductionOp::Params reduction_=typename ReductionOp::Params())
Constructs an Arguments structure. 
Definition: device/gemm_splitk_parallel.h:529
Status initialize(Arguments const &args, void *workspace)
Initializes GEMM state from arguments. 
Definition: device/gemm_splitk_parallel.h:266
TensorRef< ElementA const, LayoutA > ref_A
Definition: device/gemm_splitk_parallel.h:186
Basic include for CUTLASS. 
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::Arguments CUTLASS_HOST_DEVICE Arguments()
Default ctor. 
Definition: device/gemm_splitk_parallel.h:525
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::to_underlying_arguments static UnderlyingArguments to_underlying_arguments(Arguments const &args)
Helper to construct a transposed equivalent for the underying GEMM operator. 
Definition: device/gemm_splitk_parallel.h:565
Status
Status code returned by CUTLASS operations. 
Definition: cutlass.h:39
Template for a pipelined GEMM kernel. Does not compute batching or support split-K. 
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::ElementA ElementB ElementA
Definition: device/gemm_splitk_parallel.h:126
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::OperatorClass OperatorClass OperatorClass
Definition: device/gemm_splitk_parallel.h:133