64   typename AccumulatorType,
    76   AccumulatorType initial_accum) {
    79     LayoutA::kRank == 2 &&
    80     LayoutB::kRank == 2 &&
    81     LayoutC::kRank == 2, 
"Tensors must be of rank 2");
    92     (problem_size.
m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow),
    93     (problem_size.
n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn)
   106   ><<< grid, block >>>(
   131   typename AccumulatorType,
   142   AccumulatorType initial_accum) {
   144   compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
   145                 ScalarType, AccumulatorType, InnerProductOp, ConvertOp>(
   146         problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c,
   158   typename AccumulatorType,
   159   typename InnerProductOp = cutlass::arch::OpMultiplyAdd
   166 template <
typename ElementA, 
typename LayoutA, 
typename ElementB,
   167           typename LayoutB, 
typename ElementC, 
typename LayoutC,
   168           typename ScalarType, 
typename AccumulatorType>
   169 struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
   170             ScalarType, AccumulatorType, arch::OpMultiplyAdd> {
   176                   AccumulatorType initial_accum = AccumulatorType(0)) {
   179       LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
   180       "Tensors must be of rank 2");
   182     compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
   184         problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
   192                   AccumulatorType initial_accum = AccumulatorType(0)) {
   194       LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
   195       "Tensors must be of rank 2");
   197     compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
   199         problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
   206 template <
typename ElementA, 
typename LayoutA, 
typename ElementB,
   207           typename LayoutB, 
typename ElementC, 
typename LayoutC,
   208           typename ScalarType, 
typename AccumulatorType>
   209 struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
   210             AccumulatorType, arch::OpMultiplyAddSaturate> {
   216                   AccumulatorType initial_accum = AccumulatorType(0)) {
   218         LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
   219         "Tensors must be of rank 2");
   221     compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
   222                  ScalarType, AccumulatorType, multiply_add<AccumulatorType>,
   224         problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
   232                   AccumulatorType initial_accum = AccumulatorType(0)) {
   234         LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
   235         "Tensors must be of rank 2");
   237     compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
   238                  ScalarType, AccumulatorType, multiply_add<AccumulatorType>,
   240         problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
   247 template <
typename ElementA, 
typename LayoutA, 
typename ElementB,
   248           typename LayoutB, 
typename ElementC, 
typename LayoutC,
   249           typename ScalarType, 
typename AccumulatorType>
   250 struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
   251             AccumulatorType, arch::OpXorPopc> {
   257                   AccumulatorType initial_accum = AccumulatorType(0)) {
   259         LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
   260         "Tensors must be of rank 2");
   262     compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
   264         problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
   272                   AccumulatorType initial_accum = AccumulatorType(0)) {
   274         LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
   275         "Tensors must be of rank 2");
   277     compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
   279         problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
   295   typename TensorRefCollectionA,
   296   typename TensorRefCollectionB,
   297   typename TensorRefCollectionC,
   299   typename AccumulatorType,
   300   typename InnerProductOp,
   307   TensorRefCollectionA 
const& tensor_a,
   308   TensorRefCollectionB 
const& tensor_b,
   310   TensorRefCollectionC &tensor_c,
   311   AccumulatorType initial_accum) {
   314     TensorRefCollectionA::kRank == 2 &&
   315     TensorRefCollectionB::kRank == 2 &&
   316     TensorRefCollectionC::kRank == 2, 
"Tensors must be of rank 2");
   326     (problem_size.
m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow),
   327     (problem_size.
n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn),
   333     TensorRefCollectionA,
   334     TensorRefCollectionB,
   335     TensorRefCollectionC,
   341   ><<< grid, block >>>(
   358   typename TensorRefCollectionA,
   359   typename TensorRefCollectionB,
   360   typename TensorRefCollectionC,
   362   typename AccumulatorType
   368   TensorRefCollectionA 
const& tensor_a,
   369   TensorRefCollectionB 
const& tensor_b,
   371   TensorRefCollectionC &tensor_c) {
   373   BatchedGemm(problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0));
 Fused multiply-add. 
Definition: functional.h:92
void operator()(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, TensorRef< ElementB, LayoutB > tensor_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, TensorRef< ElementC, LayoutC > tensor_d, AccumulatorType initial_accum=AccumulatorType(0))
Definition: tools/util/include/cutlass/util/reference/device/gemm.h:267
Describes the size of a matrix tile. 
Definition: matrix_shape.h:42
Definition: aligned_buffer.h:35
Definition: numeric_conversion.h:254
A Coord is a coordinate of arbitrary rank into a tensor or matrix. 
Definition: include/cutlass/gemm/gemm.h:94
Defines common types used for all GEMM-like operators. 
CUTLASS_HOST_DEVICE Index const & n() const 
Returns the GEMM N coordinate. 
Definition: include/cutlass/gemm/gemm.h:137
Defines a structure containing strides and a pointer to tensor data. 
__global__ void BatchedGemm(gemm::GemmCoord problem_size, ScalarType alpha, TensorRefCollectionA tensor_collection_a, TensorRefCollectionB tensor_collection_b, ScalarType beta, TensorRefCollectionC tensor_collection_c, AccumulatorType initial_accum)
Definition: tools/util/include/cutlass/util/reference/device/kernel/gemm.h:108
Boost-like numeric conversion operator for CUTLASS numeric types. 
Definition: tools/util/include/cutlass/util/reference/device/gemm.h:161
void BatchedGemm(gemm::GemmCoord problem_size, int batch_count, ScalarType alpha, TensorRefCollectionA const &tensor_a, TensorRefCollectionB const &tensor_b, ScalarType beta, TensorRefCollectionC &tensor_c, AccumulatorType initial_accum)
Computes a batch of GEMMs over a set of matrices of common dimension. 
Definition: tools/util/include/cutlass/util/reference/device/gemm.h:303
void operator()(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, TensorRef< ElementB, LayoutB > tensor_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, AccumulatorType initial_accum=AccumulatorType(0))
Definition: tools/util/include/cutlass/util/reference/device/gemm.h:212
Top-level include for all CUTLASS numeric types. 
Definition: numeric_conversion.h:59
void operator()(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, TensorRef< ElementB, LayoutB > tensor_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, TensorRef< ElementC, LayoutC > tensor_d, AccumulatorType initial_accum=AccumulatorType(0))
Definition: tools/util/include/cutlass/util/reference/device/gemm.h:227
Fused multiply-add. 
Definition: functional.h:101
CUTLASS_HOST_DEVICE Index const & m() const 
Returns the GEMM M coordinate. 
Definition: include/cutlass/gemm/gemm.h:129
void compute_gemm(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, TensorRef< ElementB, LayoutB > tensor_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, TensorRef< ElementC, LayoutC > tensor_d, AccumulatorType initial_accum)
Definition: tools/util/include/cutlass/util/reference/device/gemm.h:68
void operator()(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, TensorRef< ElementB, LayoutB > tensor_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, TensorRef< ElementC, LayoutC > tensor_d, AccumulatorType initial_accum=AccumulatorType(0))
Definition: tools/util/include/cutlass/util/reference/device/gemm.h:187
void operator()(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, TensorRef< ElementB, LayoutB > tensor_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, AccumulatorType initial_accum=AccumulatorType(0))
Definition: tools/util/include/cutlass/util/reference/device/gemm.h:253
Defines properties of matrices used to denote layout and operands to GEMM kernels. 
__global__ void Gemm(gemm::GemmCoord problem_size, ScalarType alpha, TensorRefA tensor_a, TensorRefB tensor_b, ScalarType beta, TensorRefC tensor_c, TensorRefC tensor_d, AccumulatorType initial_accum)
Definition: tools/util/include/cutlass/util/reference/device/kernel/gemm.h:57
Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...
void operator()(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, TensorRef< ElementB, LayoutB > tensor_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, AccumulatorType initial_accum=AccumulatorType(0))
Definition: tools/util/include/cutlass/util/reference/device/gemm.h:172