#if defined(CUTLASS_ARCH_WMMA_ENABLED)

////////////////////////////////////////////////////////////////////////////////

/// Tile iterator for multiplicand operands of warp-level WMMA matrix multiply operations
template <
  typename Shape_,     ///< Size of the tile to load (concept: MatrixShape)
  Operand Operand_,    ///< Identifies the A or B multiplicand
  typename Element_,   ///< Data type of elements
  typename Layout_,    ///< Memory layout of the operand tile
  int OpDelta_,        ///< Delta between *WMMA operations
  int Threads,         ///< Number of participating threads
  typename Policy_     ///< Policy describing the underlying WMMA operator
>
class MmaTensorOpWmmaMultiplicandTileIterator;

////////////////////////////////////////////////////////////////////////////////

/// Specialization for the A operand of a 32-thread warp-level WMMA operation
template <
  typename Shape_,
  typename Element_,
  typename Layout_,
  int OpDelta_,
  typename Policy_
>
class MmaTensorOpWmmaMultiplicandTileIterator<
    Shape_, Operand::kA, Element_, Layout_,
    OpDelta_, 32, Policy_> {
public:

  /// Shape of the warp-level tile to load (concept: MatrixShape)
  using Shape = Shape_;
  /// Operand tag
  static Operand const kOperand = Operand::kA;
  /// Element type
  using Element = Element_;
  /// Memory layout of the operand tile
  using Layout = Layout_;
  /// Delta between *WMMA operations
  static int const kOpDelta = OpDelta_;
  /// Policy describing the underlying WMMA operator
  using Policy = Policy_;

  /// TensorRef type and the index/coordinate types derived from it
  using TensorRef = TensorRef<Element, Layout>;
  using Index = typename TensorRef::Index;
  using LongIndex = typename TensorRef::LongIndex;
  using TensorCoord = typename TensorRef::TensorCoord;

  /// Shape of an individual WMMA load for operand A (M-by-K)
  using WmmaShape = MatrixShape<
    Policy::Operator::Shape::kM,
    Policy::Operator::Shape::kK
  >;

  /// Maps the CUTLASS data type to the corresponding nvcuda::wmma data type
  using WmmaDataType = typename cutlass::arch::CutlassToWmmaDataType<Element>::Type;

  /// Number of WMMA operations covered by this iterator (one k-group per position)
  using Iterations = MatrixShape<
    Shape::kRow / WmmaShape::kRow,
    1
  >;

  /// Fragment object holding this warp's portion of the A operand
  using Fragment = WmmaFragmentArray<typename Policy::Operator::FragmentA, Iterations::kCount>;

  static_assert(kOperand == Operand::kA,
    "MmaTensorOpWmmaMultiplicandTileIterator may only be instantiated for A operands to warp-level Mma.");

  static_assert(
    platform::is_same<cutlass::layout::RowMajor, Layout>::value ||
    platform::is_same<cutlass::layout::ColumnMajor, Layout>::value,
    "Supported list of memory layouts for WMMA are: RowMajor, ColumnMajor");

  static_assert(kOpDelta == 1,
    "Alternative arrangements not supported at present.");

private:

  char const *pointer_;   ///< pointer to the start of the operand tile
  Index byte_offset_;     ///< byte offset advanced as the iterator moves
  Index stride_;          ///< leading-dimension stride in units of elements
  Layout layout_;         ///< layout object converting logical coordinates to offsets

public:

  /// Default constructor produces a null iterator
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator() { }

  /// Constructs an iterator from a TensorRef
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator(
    TensorRef const &ref,
    int lane_id
  ): pointer_(reinterpret_cast<char const*>(ref.data())), byte_offset_(0), stride_(ref.stride(0)), layout_(ref.stride(0)) { }

  /// Adds a pointer offset in units of elements
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator &add_pointer_offset(LongIndex offset) {
    byte_offset_ += offset * sizeof_bits<Element>::value / 8;
    return *this;
  }

  /// Advances the iterator by whole tiles along the logical dimensions of the matrix
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) {
    Index elements_offset = layout_({tile_offset.row() * Shape::kRow, tile_offset.column() * WmmaShape::kColumn});
    byte_offset_ += elements_offset * sizeof_bits<Element>::value / 8;
    return *this;
  }

  /// Advances the iterator one WMMA tile along the K dimension
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator & operator++() {
    Index elements_offset = layout_({0, WmmaShape::kColumn});
    byte_offset_ += elements_offset * sizeof_bits<Element>::value / 8;
    return *this;
  }

  /// Moves the iterator one WMMA tile backward along the K dimension
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator & operator--() {
    Index elements_offset = layout_({0, WmmaShape::kColumn});
    byte_offset_ -= elements_offset * sizeof_bits<Element>::value / 8;
    return *this;
  }

  /// Advances the iterator by whole tiles
  CUTLASS_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) {
    add_tile_offset(tile_offset);
    return *this;
  }

  /// Moves the iterator backward by whole tiles
  CUTLASS_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) {
    add_tile_offset(-tile_offset);
    return *this;
  }

  /// Loads a fragment from memory at the location pointed to by the iterator
  CUTLASS_DEVICE
  void load_with_byte_offset(Fragment &frag, Index byte_offset) const {
    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < Iterations::kColumn; ++k) {
      CUTLASS_PRAGMA_UNROLL
      for (int m = 0; m < Iterations::kRow; ++m) {
        // byte offset of the (m, k)-th WMMA sub-tile within the warp tile
        Index load_byte_offset = layout_({m * WmmaShape::kRow, k * WmmaShape::kColumn}) * sizeof_bits<Element>::value / 8;
        const WmmaDataType *ptr = reinterpret_cast<const WmmaDataType *>(pointer_ + byte_offset_ + load_byte_offset + byte_offset);
        nvcuda::wmma::load_matrix_sync(frag[m], ptr, stride_);
      }
    }
  }

  /// Loads a fragment at the iterator's current location
  CUTLASS_HOST_DEVICE
  void load(Fragment &frag) const {
    load_with_byte_offset(frag, 0);
  }

  /// Stores a fragment to memory at the location pointed to by the iterator
  CUTLASS_DEVICE
  void store_with_byte_offset(Fragment const &frag, Index byte_offset) const {
    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < Iterations::kColumn; ++k) {
      CUTLASS_PRAGMA_UNROLL
      for (int m = 0; m < Iterations::kRow; ++m) {
        // byte offset of the (m, k)-th WMMA sub-tile within the warp tile
        Index store_byte_offset = layout_({m * WmmaShape::kRow, k * WmmaShape::kColumn}) * sizeof_bits<Element>::value / 8;
        WmmaDataType *ptr = reinterpret_cast<WmmaDataType *>(pointer_ + byte_offset_ + store_byte_offset + byte_offset);
        nvcuda::wmma::store_matrix_sync(ptr, frag[m], stride_);
      }
    }
  }

  /// Stores a fragment at the iterator's current location
  CUTLASS_HOST_DEVICE
  void store(Fragment const &frag) const {
    store_with_byte_offset(frag, 0);
  }

  /// Notifies the iterator of the current k-group index; a no-op for WMMA iterators
  CUTLASS_DEVICE
  void set_kgroup_index(int k_group) { }
};
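////////////////////////////////////////////////////////////////////////////////
//
// For intuition, a worked sizing example for the A-operand iterator. The
// concrete numbers below are illustrative assumptions, not values taken from
// this header. Suppose the underlying WMMA operator is 16x16x16 and the
// warp-level A tile is 64 (M) x 16 (K):
//
//   WmmaShape  = MatrixShape<16, 16>       // Policy::Operator::Shape::{kM, kK}
//   Iterations = MatrixShape<64 / 16, 1>   // = MatrixShape<4, 1>
//   Fragment   = WmmaFragmentArray<Policy::Operator::FragmentA, 4>
//
// so each load() issues four nvcuda::wmma::load_matrix_sync calls, one per
// 16x16 sub-tile stacked along the warp tile's M dimension, while operator++
// advances the iterator by one WMMA tile (16 elements) along K.
//
////////////////////////////////////////////////////////////////////////////////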
/// Specialization for the B operand of a 32-thread warp-level WMMA operation
template <
  typename Shape_,
  typename Element_,
  typename Layout_,
  int OpDelta_,
  typename Policy_
>
class MmaTensorOpWmmaMultiplicandTileIterator<
    Shape_, Operand::kB, Element_, Layout_,
    OpDelta_, 32, Policy_> {
public:

  /// Shape of the warp-level tile to load (concept: MatrixShape)
  using Shape = Shape_;
  /// Operand tag
  static Operand const kOperand = Operand::kB;
  /// Element type
  using Element = Element_;
  /// Memory layout of the operand tile
  using Layout = Layout_;
  /// Delta between *WMMA operations
  static int const kOpDelta = OpDelta_;
  /// Policy describing the underlying WMMA operator
  using Policy = Policy_;

  /// TensorRef type and the index/coordinate types derived from it
  using TensorRef = TensorRef<Element, Layout>;
  using Index = typename TensorRef::Index;
  using LongIndex = typename TensorRef::LongIndex;
  using TensorCoord = typename TensorRef::TensorCoord;

  /// Shape of an individual WMMA load for operand B (K-by-N)
  using WmmaShape = MatrixShape<
    Policy::Operator::Shape::kK,
    Policy::Operator::Shape::kN
  >;

  /// Maps the CUTLASS data type to the corresponding nvcuda::wmma data type
  using WmmaDataType = typename cutlass::arch::CutlassToWmmaDataType<Element>::Type;

  /// Number of WMMA operations covered by this iterator (one k-group per position)
  using Iterations = MatrixShape<
    1,
    Shape::kColumn / WmmaShape::kColumn
  >;

  /// Fragment object holding this warp's portion of the B operand
  using Fragment = WmmaFragmentArray<typename Policy::Operator::FragmentB, Iterations::kCount>;

  static_assert(kOperand == Operand::kB,
    "MmaTensorOpWmmaMultiplicandTileIterator may only be instantiated for B operands to warp-level Mma.");

  static_assert(
    platform::is_same<cutlass::layout::RowMajor, Layout>::value ||
    platform::is_same<cutlass::layout::ColumnMajor, Layout>::value,
    "Supported list of memory layouts for WMMA are: RowMajor, ColumnMajor");

  static_assert(kOpDelta == 1,
    "Alternative arrangements not supported at present.");

private:

  char const *pointer_;   ///< pointer to the start of the operand tile
  Index byte_offset_;     ///< byte offset advanced as the iterator moves
  Index stride_;          ///< leading-dimension stride in units of elements
  Layout layout_;         ///< layout object converting logical coordinates to offsets

public:

  /// Default constructor produces a null iterator
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator() { }

  /// Constructs an iterator from a TensorRef
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator(
    TensorRef const &ref,
    int lane_id
  ): pointer_(reinterpret_cast<char const*>(ref.data())), byte_offset_(0), stride_(ref.stride(0)), layout_(ref.stride(0)) { }

  /// Adds a pointer offset in units of elements
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator &add_pointer_offset(LongIndex offset) {
    byte_offset_ += offset * sizeof_bits<Element>::value / 8;
    return *this;
  }

  /// Advances the iterator by whole tiles along the logical dimensions of the matrix
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) {
    Index elements_offset = layout_({tile_offset.row() * WmmaShape::kRow, tile_offset.column() * Shape::kColumn});
    byte_offset_ += elements_offset * sizeof_bits<Element>::value / 8;
    return *this;
  }

  /// Advances the iterator one WMMA tile along the K dimension
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator & operator++() {
    Index elements_offset = layout_({WmmaShape::kRow, 0});
    byte_offset_ += elements_offset * sizeof_bits<Element>::value / 8;
    return *this;
  }

  /// Moves the iterator one WMMA tile backward along the K dimension
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator & operator--() {
    Index elements_offset = layout_({WmmaShape::kRow, 0});
    byte_offset_ -= elements_offset * sizeof_bits<Element>::value / 8;
    return *this;
  }

  /// Advances the iterator by whole tiles
  CUTLASS_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) {
    add_tile_offset(tile_offset);
    return *this;
  }

  /// Moves the iterator backward by whole tiles
  CUTLASS_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) {
    add_tile_offset(-tile_offset);
    return *this;
  }

  /// Loads a fragment from memory at the location pointed to by the iterator
  CUTLASS_DEVICE
  void load_with_byte_offset(Fragment &frag, Index byte_offset) const {
    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < Iterations::kRow; ++k) {
      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < Iterations::kColumn; ++n) {
        // byte offset of the (k, n)-th WMMA sub-tile within the warp tile
        Index load_byte_offset = layout_({k * WmmaShape::kRow, n * WmmaShape::kColumn}) * sizeof_bits<Element>::value / 8;
        const WmmaDataType *ptr = reinterpret_cast<const WmmaDataType *>(pointer_ + byte_offset_ + load_byte_offset + byte_offset);
        nvcuda::wmma::load_matrix_sync(frag[n], ptr, stride_);
      }
    }
  }

  /// Loads a fragment at the iterator's current location
  CUTLASS_HOST_DEVICE
  void load(Fragment &frag) const {
    load_with_byte_offset(frag, 0);
  }

  /// Stores a fragment to memory at the location pointed to by the iterator
  CUTLASS_DEVICE
  void store_with_byte_offset(Fragment const &frag, Index byte_offset) const {
    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < Iterations::kRow; ++k) {
      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < Iterations::kColumn; ++n) {
        // byte offset of the (k, n)-th WMMA sub-tile within the warp tile
        Index store_byte_offset = layout_({k * WmmaShape::kRow, n * WmmaShape::kColumn}) * sizeof_bits<Element>::value / 8;
        WmmaDataType *ptr = reinterpret_cast<WmmaDataType *>(pointer_ + byte_offset_ + store_byte_offset + byte_offset);
        nvcuda::wmma::store_matrix_sync(ptr, frag[n], stride_);
      }
    }
  }

  /// Stores a fragment at the iterator's current location
  CUTLASS_HOST_DEVICE
  void store(Fragment const &frag) const {
    store_with_byte_offset(frag, 0);
  }

  /// Notifies the iterator of the current k-group index; a no-op for WMMA iterators
  CUTLASS_DEVICE
  void set_kgroup_index(int k_group) { }
};
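////////////////////////////////////////////////////////////////////////////////
//
// The B-operand iterator mirrors the A-operand case with the roles of rows and
// columns swapped. Continuing the illustrative 16x16x16 assumption (not taken
// from this header) with a warp-level B tile of 16 (K) x 64 (N):
//
//   WmmaShape  = MatrixShape<16, 16>       // Policy::Operator::Shape::{kK, kN}
//   Iterations = MatrixShape<1, 64 / 16>   // = MatrixShape<1, 4>
//   Fragment   = WmmaFragmentArray<Policy::Operator::FragmentB, 4>
//
// Note the advance direction: operator++ on the A iterator steps by
// {0, WmmaShape::kColumn} (columns of A span the K dimension), whereas on the
// B iterator it steps by {WmmaShape::kRow, 0} (rows of B span the K dimension),
// so incrementing both keeps them aligned on the same k-group.
//
////////////////////////////////////////////////////////////////////////////////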
/// Tile iterator for accumulator operands of warp-level WMMA matrix multiply operations
template <
  typename Shape_,     ///< Size of the accumulator tile (concept: MatrixShape)
  typename Element_,   ///< Data type of accumulator elements
  typename Layout_,    ///< Memory layout of the accumulator tile
  typename OpDelta_,   ///< Delta between *WMMA operations (concept: MatrixShape)
  typename Policy_     ///< Policy describing the underlying WMMA operator
>
class MmaTensorOpWmmaAccumulatorTileIterator;

////////////////////////////////////////////////////////////////////////////////

/// Accumulator tile iterator for 32-thread warp-level WMMA operations
template <
  typename Shape_,
  typename Element_,
  typename Layout_,
  typename OpDelta_,
  typename Policy_
>
class MmaTensorOpWmmaAccumulatorTileIterator {
public:

  /// Shape of the warp-level accumulator tile (concept: MatrixShape)
  using Shape = Shape_;
  /// Element type
  using Element = Element_;
  /// Memory layout of the accumulator tile
  using Layout = Layout_;
  /// Delta between *WMMA operations (concept: MatrixShape)
  using OpDelta = OpDelta_;
  /// Number of participating threads
  static int const kThreads = 32;
  /// Policy describing the underlying WMMA operator
  using Policy = Policy_;

  /// TensorRef type and the index/coordinate types derived from it
  using TensorRef = TensorRef<Element, Layout>;
  using Index = typename TensorRef::Index;
  using LongIndex = typename TensorRef::LongIndex;
  using TensorCoord = typename TensorRef::TensorCoord;

  /// Shape of an individual WMMA accumulator tile (M-by-N)
  using WmmaShape = MatrixShape<
    Policy::Operator::Shape::kM,
    Policy::Operator::Shape::kN
  >;

  /// Maps the CUTLASS data type to the corresponding nvcuda::wmma data type
  using WmmaDataType = typename cutlass::arch::CutlassToWmmaDataType<Element>::Type;

  /// Maps the CUTLASS layout to the corresponding nvcuda::wmma layout
  static nvcuda::wmma::layout_t const WmmaLayout = cutlass::arch::CutlassToWmmaLayout<Layout>::value;

  /// Number of WMMA operations covered by this iterator
  using Iterations = MatrixShape<
    Shape::kRow / WmmaShape::kRow,
    Shape::kColumn / WmmaShape::kColumn
  >;

  /// Fragment object holding this warp's portion of the accumulator
  using Fragment = WmmaFragmentArray<typename Policy::Operator::FragmentC, Iterations::kCount>;

  static_assert(
    platform::is_same<cutlass::layout::RowMajor, Layout>::value ||
    platform::is_same<cutlass::layout::ColumnMajor, Layout>::value,
    "Supported list of memory layouts for WMMA are: RowMajor, ColumnMajor");

private:

  /// Reference to the accumulator tile in memory
  TensorRef ref_;

public:

  /// Default constructor produces a null iterator
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaAccumulatorTileIterator() { }

  /// Constructs an iterator from a TensorRef
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaAccumulatorTileIterator(
    TensorRef const &ref,
    int lane_id
  ): ref_(ref) { }

  /// Adds a pointer offset in units of elements
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaAccumulatorTileIterator &add_pointer_offset(LongIndex offset) {
    ref_.add_pointer_offset(offset);
    return *this;
  }

  /// Advances the iterator by whole tiles along the logical dimensions of the matrix
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) {
    ref_.add_coord_offset({tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn});
    return *this;
  }

  /// Advances the iterator along the advance dimension
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaAccumulatorTileIterator & operator++() {
    // ...
    return *this;
  }

  /// Moves the iterator backward along the advance dimension
  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaAccumulatorTileIterator & operator--() {
    // ...
    return *this;
  }

  /// Advances the iterator by whole tiles
  CUTLASS_DEVICE
  MmaTensorOpWmmaAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) {
    add_tile_offset(tile_offset);
    return *this;
  }

  /// Moves the iterator backward by whole tiles
  CUTLASS_DEVICE
  MmaTensorOpWmmaAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) {
    add_tile_offset(-tile_offset);
    return *this;
  }

  /// Loads a fragment from memory at the location pointed to by the iterator
  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {
    CUTLASS_PRAGMA_UNROLL
    for (int m = 0; m < Iterations::kRow; ++m) {
      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < Iterations::kColumn; ++n) {
        const WmmaDataType *ptr = reinterpret_cast<const WmmaDataType *>(ref_.data() + ref_.offset({m * WmmaShape::kRow, n * WmmaShape::kColumn}) + pointer_offset);
        nvcuda::wmma::load_matrix_sync(frag[m * Iterations::kColumn + n], ptr, ref_.stride()[0], WmmaLayout);
      }
    }
  }

  /// Loads a fragment at the iterator's current location
  CUTLASS_HOST_DEVICE
  void load(Fragment &frag) const {
    load_with_pointer_offset(frag, 0);
  }

  /// Stores a fragment to memory at the location pointed to by the iterator
  CUTLASS_DEVICE
  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {
    CUTLASS_PRAGMA_UNROLL
    for (int m = 0; m < Iterations::kRow; ++m) {
      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < Iterations::kColumn; ++n) {
        WmmaDataType *ptr = reinterpret_cast<WmmaDataType *>(ref_.data() + ref_.offset({m * WmmaShape::kRow, n * WmmaShape::kColumn}) + pointer_offset);
        nvcuda::wmma::store_matrix_sync(ptr, frag[m * Iterations::kColumn + n], ref_.stride()[0], WmmaLayout);
      }
    }
  }

  /// Stores a fragment at the iterator's current location
  CUTLASS_HOST_DEVICE
  void store(Fragment const &frag) const {
    store_with_pointer_offset(frag, 0);
  }

  /// Notifies the iterator of the current k-group index; a no-op for WMMA iterators
  CUTLASS_DEVICE
  void set_kgroup_index(int k_group) { }
};
#endif // if defined(CUTLASS_ARCH_WMMA_ENABLED)
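////////////////////////////////////////////////////////////////////////////////
//
// Taken together, the three iterators cover a warp-level WMMA GEMM: the
// multiplicand iterators feed nvcuda::wmma fragments to the warp-level
// operator, and the accumulator iterator moves C/D fragments between memory
// and registers. The sketch below is illustrative only: the iterator
// interfaces come from this header, but the WarpMma functor, the k_groups
// loop bound, and the surrounding kernel scaffolding are assumptions, not
// part of the header.
//
// template <typename IteratorA, typename IteratorB, typename IteratorC,
//           typename WarpMma>
// __device__ void warp_tile_gemm(
//     typename IteratorA::TensorRef ref_A,   // A operand tile in memory
//     typename IteratorB::TensorRef ref_B,   // B operand tile in memory
//     typename IteratorC::TensorRef ref_C,   // accumulator tile in memory
//     int k_groups,                          // number of WMMA steps along K
//     int lane_id) {
//
//   IteratorA iter_A(ref_A, lane_id);
//   IteratorB iter_B(ref_B, lane_id);
//   IteratorC iter_C(ref_C, lane_id);
//
//   typename IteratorA::Fragment frag_A;
//   typename IteratorB::Fragment frag_B;
//   typename IteratorC::Fragment accum;
//
//   iter_C.load(accum);                       // load_matrix_sync per C sub-tile
//
//   WarpMma warp_mma;
//   for (int k = 0; k < k_groups; ++k) {
//     iter_A.load(frag_A);                    // Iterations::kRow A fragments
//     iter_B.load(frag_B);                    // Iterations::kColumn B fragments
//     warp_mma(accum, frag_A, frag_B, accum); // issues nvcuda::wmma::mma_sync
//     ++iter_A;                               // step one WMMA tile along K
//     ++iter_B;
//   }
//
//   iter_C.store(accum);                      // store_matrix_sync per C sub-tile
// }
//
////////////////////////////////////////////////////////////////////////////////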