47 #include <cuda_runtime.h>   161 template <
typename T> T 
from_string(std::string 
const &);
   258     instruction_shape(instruction_shape), element_accumulator(element_accumulator), opcode_class(opcode_class) {}
   289     int threadblock_stages = 0,
   292     int minimum_compute_capability = 0,
   293     int maximum_compute_capability = 0
   295     threadblock_shape(threadblock_shape), 
   296     threadblock_stages(threadblock_stages), 
   297     warp_count(warp_count),
   298     math_instruction(math_instruction),
   299     minimum_compute_capability(minimum_compute_capability),
   300     maximum_compute_capability(maximum_compute_capability) { }
   319     char const * name = 
"unknown",
   323     name(name), kind(kind), tile_description(tile_description) { }
   351     int log_extent_range = 24,
   352     int log_stride_range = 24
   356     alignment(alignment), 
   357     log_extent_range(log_extent_range), 
   358     log_stride_range(log_stride_range) { }
   404     gemm_kind(gemm_kind),
   408     element_epilogue(element_epilogue),
   409     split_k_mode(split_k_mode),
   410     transform_A(transform_A),
   411     transform_B(transform_B) {} 
   425   virtual Status can_implement(
   426     void const *configuration, 
   427     void const *arguments) 
const = 0;
   429   virtual uint64_t get_host_workspace_size(
   430     void const *configuration) 
const = 0;
   432   virtual uint64_t get_device_workspace_size(
   433     void const *configuration) 
const = 0;
   435   virtual Status initialize(
   436     void const *configuration, 
   437     void *host_workspace, 
   438     void *device_workspace, 
   439     cudaStream_t stream = 
nullptr) 
const = 0;
   442     void const *arguments,
   443     void *host_workspace, 
   444     void *device_workspace = 
nullptr, 
   445     cudaStream_t stream = 
nullptr) 
const = 0;
   565   void const * 
const *
A;
   566   void const * 
const *
B;
   567   void const * 
const *
C;
 int64_t lda
Definition: library.h:609
int alignment
Alignment restriction on pointers, strides, and extents. 
Definition: library.h:336
void const *const * A
Definition: library.h:565
virtual ~Operation()
Definition: library.h:421
High-level description of an operation. 
Definition: library.h:304
char const * to_string(OperationKind type, bool pretty=false)
Converts an OperationKind enumerant to a string. 
Definition: aligned_buffer.h:35
bool is_complex_type(NumericTypeID type)
Returns true if the numeric type is a complex data type or false if real-valued. 
void *const * D
Definition: library.h:568
LayoutTypeID layout
Enumerant identifying the layout function for the tensor. 
Definition: library.h:333
GemmKind gemm_kind
Indicates the kind of GEMM performed. 
Definition: library.h:367
int64_t ldc
Definition: library.h:587
Arguments for GEMM. 
Definition: library.h:477
int batch_count
Definition: library.h:560
ComplexTransform
Enumerated type describing a transformation on a complex value. 
Definition: library.h:111
void const *const * C
Definition: library.h:567
gemm::GemmCoord problem_size
Definition: library.h:583
Configuration for batched GEMM in which multiple matrix products are computed. 
Definition: library.h:551
bool is_signed_integer(NumericTypeID type)
Returns true if numeric type is a signed integer. 
GemmKind
Enumeration indicating what kind of GEMM operation to perform. 
Definition: library.h:149
NumericTypeID get_real_type(NumericTypeID type)
Returns the real-valued type underlying a type (only different from 'type' if complex) ...
OperationKind from_string< OperationKind >(std::string const &str)
Parses an OperationKind enumerant from a string. 
Definition: include/cutlass/gemm/gemm.h:94
int get_layout_stride_rank(LayoutTypeID layout_id)
Returns the rank of a layout's stride based on the LayoutTypeID. 
int64_t ldb
Leading dimension of B matrix. 
Definition: library.h:517
int64_t const * ldc
Definition: library.h:557
int64_t batched_stride_B
Definition: library.h:620
Complex valued GEMM in which real and imaginary parts are separated by a stride. 
Definition: library.h:581
int log_stride_range
log2() of the maximum value each relevant stride may have 
Definition: library.h:342
Defines common types used for all GEMM-like operators. 
ComplexTransform transform_A
Transformation on A operand. 
Definition: library.h:385
int64_t imag_stride_B
Definition: library.h:615
GemmDescription(GemmKind gemm_kind=GemmKind::kGemm, TensorDescription const &A=TensorDescription(), TensorDescription const &B=TensorDescription(), TensorDescription const &C=TensorDescription(), NumericTypeID element_epilogue=NumericTypeID::kInvalid, SplitKMode split_k_mode=SplitKMode::kNone, ComplexTransform transform_A=ComplexTransform::kNone, ComplexTransform transform_B=ComplexTransform::kNone)
Definition: library.h:394
int sizeof_bits(NumericTypeID type)
Returns the size of a data type in bits. 
Base class for all device-wide operations. 
Definition: library.h:418
int64_t imag_stride_A
Definition: library.h:590
NumericTypeID from_string< NumericTypeID >(std::string const &str)
Parses a NumericType enumerant from a string. 
LayoutTypeID
Layout type identifier. 
Definition: library.h:63
OpcodeClassID
Indicates the classification of the math instruction. 
Definition: library.h:139
int64_t ldc
Definition: library.h:611
std::string lexical_cast(int64_t int_value)
Lexical cast from int64_t to string. 
ScalarPointerMode pointer_mode
Enumerant indicating whether alpha/beta point to host or device memory. 
Definition: library.h:498
int64_t const * ldb
Definition: library.h:556
gemm::GemmCoord problem_size
Definition: library.h:553
int64_t batched_stride_C
Definition: library.h:621
OperationDescription(char const *name="unknown", OperationKind kind=OperationKind::kInvalid, TileDescription const &tile_description=TileDescription())
Definition: library.h:318
int maximum_compute_capability
Maximum compute capability (e.g. 70, 75) of a device eligible to run the operation. 
Definition: library.h:281
Configuration for basic GEMM operations. 
Definition: library.h:455
int64_t imag_stride_D
Definition: library.h:617
Definition: library.h:238
void const * B
Pointer to B matrix. 
Definition: library.h:483
int64_t imag_stride_A
Definition: library.h:614
TensorDescription A
Describes the A operand. 
Definition: library.h:370
Structure describing the tiled structure of a GEMM-like computation. 
Definition: library.h:263
int split_k_slices
Number of partitions of K dimension. 
Definition: library.h:473
int64_t imag_stride_B
Definition: library.h:591
OpcodeClassID from_string< OpcodeClassID >(std::string const &str)
Parses an OpcodeClassID enumerant from a string. 
void const * A
Pointer to A matrix. 
Definition: library.h:480
Defines layout functions used by TensorRef and derived classes for common 4-D and 5-D tensor formats...
int64_t ldd
Leading dimension of D matrix. 
Definition: library.h:470
ComplexTransform transform_B
Transformation on B operand. 
Definition: library.h:388
int64_t ldb
Definition: library.h:610
bool is_signed_type(NumericTypeID type)
Returns true if numeric type is signed. 
NumericTypeID element_epilogue
Describes the data type of the scalars passed to the epilogue. 
Definition: library.h:379
int64_t const * lda
Definition: library.h:555
int64_t const * ldd
Definition: library.h:558
int minimum_compute_capability
Minimum compute capability (e.g. 70, 75) of a device eligible to run the operation. 
Definition: library.h:278
int64_t ldd
Definition: library.h:588
int64_t batch_stride_C
Stride between instances of the C matrix in memory. 
Definition: library.h:532
void const *const * B
Definition: library.h:566
NumericTypeID
Numeric data type. 
Definition: library.h:77
int64_t lda
Definition: library.h:585
cutlass::gemm::GemmCoord warp_count
Number of warps in each logical dimension. 
Definition: library.h:272
int64_t lda
Leading dimension of A matrix. 
Definition: library.h:461
bool is_float_type(NumericTypeID type)
Returns true if numeric type is floating-point type. 
TileDescription(cutlass::gemm::GemmCoord threadblock_shape=cutlass::gemm::GemmCoord(), int threadblock_stages=0, cutlass::gemm::GemmCoord warp_count=cutlass::gemm::GemmCoord(), MathInstructionDescription math_instruction=MathInstructionDescription(), int minimum_compute_capability=0, int maximum_compute_capability=0)
Definition: library.h:287
bool cast_from_double(std::vector< uint8_t > &bytes, NumericTypeID type, double src)
Casts from a real value represented as a double to the destination type. Returns true if successful...
NumericTypeID element_accumulator
Describes the data type of the internal accumulator. 
Definition: library.h:244
Defines a canonical coordinate for rank=4 tensors offering named indices. 
TensorDescription B
Describes the B operand. 
Definition: library.h:373
void const * alpha
Definition: library.h:569
void const * beta
Host or device pointer to beta scalar. 
Definition: library.h:495
int64_t ldd
Leading dimension of D matrix. 
Definition: library.h:523
OpcodeClassID opcode_class
Classification of math instruction. 
Definition: library.h:247
gemm::GemmCoord problem_size
GEMM problem size. 
Definition: library.h:511
int64_t ldc
Leading dimension of C matrix. 
Definition: library.h:467
void * D
Pointer to D matrix. 
Definition: library.h:489
bool cast_from_uint64(std::vector< uint8_t > &bytes, NumericTypeID type, uint64_t src)
Casts from an unsigned int64 to the destination type. Returns true if successful. ...
TensorDescription C
Describes the source and destination matrices. 
Definition: library.h:376
int64_t ldb
Definition: library.h:586
int64_t imag_stride_C
Definition: library.h:616
Batched complex valued GEMM in which real and imaginary parts are separated by a stride. 
Definition: library.h:605
int64_t batched_stride_D
Definition: library.h:622
Configuration for batched GEMM in which multiple matrix products are computed. 
Definition: library.h:508
int64_t batch_stride_A
Stride between instances of the A matrix in memory. 
Definition: library.h:526
bool is_integer_type(NumericTypeID type)
Returns true if numeric type is integer. 
ScalarPointerMode
Enumeration indicating whether scalars are in host or device memory. 
Definition: library.h:123
NumericTypeID element
Numeric type of an individual element. 
Definition: library.h:330
int batch_count
Number of GEMMs in batch. 
Definition: library.h:538
void const * C
Pointer to C matrix. 
Definition: library.h:486
T from_string(std::string const &)
Lexical cast from string. 
int threadblock_stages
Describes the number of pipeline stages in the threadblock-scoped mainloop. 
Definition: library.h:269
Defines a canonical coordinate for rank=2 matrices offering named indices. 
int64_t imag_stride_D
Definition: library.h:593
LayoutTypeID from_string< LayoutTypeID >(std::string const &str)
Parses a LayoutType enumerant from a string. 
ScalarPointerMode pointer_mode
Definition: library.h:571
MathInstructionDescription(cutlass::gemm::GemmCoord instruction_shape=cutlass::gemm::GemmCoord(), NumericTypeID element_accumulator=NumericTypeID::kInvalid, OpcodeClassID opcode_class=OpcodeClassID::kInvalid)
Definition: library.h:253
int64_t batched_stride_A
Definition: library.h:619
Description of all GEMM computations. 
Definition: library.h:364
int64_t lda
Leading dimension of A matrix. 
Definition: library.h:514
gemm::GemmCoord problem_size
GEMM problem size. 
Definition: library.h:458
SplitKMode split_k_mode
Describes the structure of parallel reductions. 
Definition: library.h:382
bool cast_from_int64(std::vector< uint8_t > &bytes, NumericTypeID type, int64_t src)
Casts from a signed int64 to the destination type. Returns true if successful. 
int log_extent_range
log2() of the maximum extent of each dimension 
Definition: library.h:339
char const * name
Unique identifier describing the operation. 
Definition: library.h:307
int64_t batch_stride_B
Stride between instances of the B matrix in memory. 
Definition: library.h:529
cutlass::gemm::GemmCoord instruction_shape
Shape of the target math instruction. 
Definition: library.h:241
void const * beta
Definition: library.h:570
TileDescription tile_description
Describes the tiled structure of a GEMM-like computation. 
Definition: library.h:313
int64_t ldc
Leading dimension of C matrix. 
Definition: library.h:520
Structure describing the properties of a tensor. 
Definition: library.h:327
int64_t ldb
Leading dimension of B matrix. 
Definition: library.h:464
gemm::GemmCoord problem_size
Definition: library.h:607
bool is_unsigned_integer(NumericTypeID type)
Returns true if numeric type is an unsigned integer. 
Arguments for GEMM - used by all the GEMM operations. 
Definition: library.h:564
OperationKind
Enumeration indicating the kind of operation. 
Definition: library.h:117
void const * alpha
Host or device pointer to alpha scalar. 
Definition: library.h:492
OperationKind kind
Kind of operation. 
Definition: library.h:310
cutlass::gemm::GemmCoord threadblock_shape
Describes the shape of a threadblock (in elements) 
Definition: library.h:266
int64_t ldd
Definition: library.h:612
int64_t imag_stride_C
Definition: library.h:592
MathInstructionDescription math_instruction
Core math instruction. 
Definition: library.h:275
Basic include for CUTLASS. 
SplitKMode
Describes how reductions are performed across threadblocks. 
Definition: library.h:130
Status
Status code returned by CUTLASS operations. 
Definition: cutlass.h:39
int64_t batch_stride_D
Stride between instances of the D matrix in memory. 
Definition: library.h:535
TensorDescription(NumericTypeID element=NumericTypeID::kInvalid, LayoutTypeID layout=LayoutTypeID::kInvalid, int alignment=1, int log_extent_range=24, int log_stride_range=24)
Definition: library.h:347