67   int PartitionGroupSize = 1
    87   int PartitionGroupSize
   124     "The warp-level GEMM M size must be divisible by the number of threads arranged along the M dimension.");
    126   static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");
    127   static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");
    128   static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero.");
    129   static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero.");
   133     Shape::kRow / Policy::WarpShape::kRow,
   137   static_assert(!(ThreadShape::kRow % Policy::LaneMmaShape::kM), 
   138     "Thread-level GEMM must be divisible by Policy::LaneMmaShape.");
   142     ThreadShape::kRow / Policy::LaneMmaShape::kM,
   147   using Fragment = Array<Element, ThreadShape::kCount>;
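A worked example of the decomposition above may help; the numbers below are hypothetical and only mirror the arithmetic of the ThreadShape / Iterations / Fragment typedefs (they are not a configuration taken from the library):

  // Hypothetical A-operand warp tile: 32 rows of M, 4 columns of K,
  // 4 threads arranged along M, LaneMmaShape::kM == 4.
  constexpr int kShapeRow       = 32;  // Shape::kRow
  constexpr int kShapeColumn    = 4;   // Shape::kColumn
  constexpr int kWarpShapeRow   = 4;   // Policy::WarpShape::kRow
  constexpr int kLaneMmaM       = 4;   // Policy::LaneMmaShape::kM

  constexpr int kThreadShapeRow = kShapeRow / kWarpShapeRow;       // ThreadShape::kRow  == 8
  constexpr int kIterationsRow  = kThreadShapeRow / kLaneMmaM;     // Iterations::kRow   == 2
  constexpr int kFragmentCount  = kThreadShapeRow * kShapeColumn;  // ThreadShape::kCount == 32

  static_assert(kThreadShapeRow % kLaneMmaM == 0, "mirrors the divisibility check at line 137");
  static_assert(kFragmentCount == 32, "each lane's Fragment holds 32 elements in this example");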
   168     typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
   170     MatrixCoord lane_offset = lane_layout.inverse(lane_id) * 
    176       reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> *>(ref.data()),
    177       ref.stride(0) / Policy::LaneMmaShape::kM);
   193       coord.row() * Shape::kRow / Policy::LaneMmaShape::kM, 
   194       coord.column() * Shape::kColumn});
   220     Array<Element, Policy::LaneMmaShape::kM> *dst_ptr = 
    221       reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> *>(&frag);
    224     for (int k = 0; k < Iterations::kColumn; ++k) {
    226       for (int m = 0; m < Iterations::kRow; ++m) {
   227         dst_ptr[m + k * Iterations::kRow] = 
    228           *(ref_.data() + ref_.offset({m * Policy::WarpShape::kRow, k}) + pointer_offset / Policy::LaneMmaShape::kM);
   235     load_with_pointer_offset(frag, 0);
    242     Array<Element, Policy::LaneMmaShape::kM> const *src_ptr = 
    243       reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> const *>(&frag);
    246     for (int k = 0; k < Iterations::kColumn; ++k) {
    248       for (int m = 0; m < Iterations::kRow; ++m) {
    249         *(ref_.data() + ref_.offset({m * Policy::WarpShape::kRow, k}) + pointer_offset / Policy::LaneMmaShape::kM) = 
    250           src_ptr[m + k * Iterations::kRow];
   258     store_with_pointer_offset(frag, 0);
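A minimal usage sketch of this iterator follows; IteratorA, ref_A, lane_id, and kKIterations are assumed names introduced only for illustration, and the body is a sketch rather than code taken from the library:

  // Sketch: walk the K dimension of a column-major A operand, loading one
  // warp-tile fragment per step. Only the members excerpted above are used.
  typename IteratorA::Fragment frag_A;

  IteratorA iter_A(ref_A, lane_id);      // per-lane view constructed from a TensorRef
  CUTLASS_PRAGMA_UNROLL
  for (int k = 0; k < kKIterations; ++k) {
    iter_A.load(frag_A);                 // forwards to load_with_pointer_offset(frag, 0)
    ++iter_A;                            // advance one tile along the advance dimension
    // ... hand frag_A to the warp-level MMA here ...
  }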
   290   int PartitionGroupSize
   326   static_assert(!(Shape::kColumn % Policy::WarpShape::kColumn), 
   327     "The warp-level GEMM N size must be divisible by the number of threads arranged along the N dimension.");
    329   static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");
    330   static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");
    331   static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero.");
    332   static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero.");
   337     Shape::kColumn / Policy::WarpShape::kColumn
   340   static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN), 
   341     "Thread-level GEMM must be divisible by Policy::LaneMmaShape.");
   346     ThreadShape::kColumn / Policy::LaneMmaShape::kN
   350   using Fragment = Array<Element, ThreadShape::kCount>;
   372     typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
   374     MatrixCoord lane_offset = lane_layout.inverse(lane_id) * 
    380       reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(ref.data()),
    381       ref.stride(0) / Policy::LaneMmaShape::kN);
   396       coord.row() * Shape::kRow, 
   397       coord.column() * Shape::kColumn / Policy::LaneMmaShape::kN});
   424     Array<Element, Policy::LaneMmaShape::kN> *dst_ptr = 
    425       reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(&frag);
    428     for (int k = 0; k < Iterations::kRow; ++k) {
    430       for (int n = 0; n < Iterations::kColumn; ++n) {
   431         dst_ptr[n + k * Iterations::kColumn] = 
    432           *(ref_.data() + ref_.offset({k, n * Policy::WarpShape::kColumn}) + pointer_offset / Policy::LaneMmaShape::kN);
   440     load_with_pointer_offset(frag, 0);
    447     Array<Element, Policy::LaneMmaShape::kN> const *src_ptr = 
    448       reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> const *>(&frag);
    451     for (int k = 0; k < Iterations::kRow; ++k) {
    453       for (int n = 0; n < Iterations::kColumn; ++n) {
    454         *(ref_.data() + ref_.offset({k, n * Policy::WarpShape::kColumn}) + pointer_offset / Policy::LaneMmaShape::kN) = 
    455           src_ptr[n + k * Iterations::kColumn];
   463     store_with_pointer_offset(frag, 0);
   528     (!(Shape::kRow % Policy::WarpShape::kRow)) && (!(Shape::kColumn % Policy::WarpShape::kColumn)),
   529     "Warp-level GEMM shape must be divisible by the arrangement of threads in the warp.");
    531   static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");
    532   static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");
    533   static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero.");
    534   static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero.");
    535   static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero.");
    536   static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero.");
   540     Shape::kRow / Policy::WarpShape::kRow,
   541     Shape::kColumn / Policy::WarpShape::kColumn
   545     (!(ThreadShape::kRow % Policy::LaneMmaShape::kM)) && (!(ThreadShape::kColumn % Policy::LaneMmaShape::kN)),
   546     "Warp-level GEMM shape must be divisible by the arrangement of threads in the warp.");
   550     ThreadShape::kRow / Policy::LaneMmaShape::kM,
   551     ThreadShape::kColumn / Policy::LaneMmaShape::kN
   555     Policy::WarpShape::kRow * Policy::LaneMmaShape::kM,
   556     Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN
   560   using Fragment = Array<Element, ThreadShape::kCount>;
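Delta, defined just above, is the element spacing between consecutive per-lane MMA blocks; a hypothetical example of that arithmetic (not a shipped configuration):

  // Hypothetical accumulator tile: WarpShape 4x8 lanes, LaneMmaShape 4x2.
  constexpr int kWarpRows = 4, kWarpColumns = 8;   // Policy::WarpShape
  constexpr int kLaneM = 4, kLaneN = 2;            // Policy::LaneMmaShape::kM / ::kN

  constexpr int kDeltaRow    = kWarpRows * kLaneM;      // Delta::kRow    == 16
  constexpr int kDeltaColumn = kWarpColumns * kLaneN;   // Delta::kColumn == 16

  // With Shape = 64x64, ThreadShape is 16x8 and Iterations is 4x4: each lane
  // owns a 4x4 grid of 4x2 accumulator blocks spaced 16 elements apart in
  // both dimensions.
  static_assert(kDeltaRow == 16 && kDeltaColumn == 16, "spacing between one lane's blocks");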
   581     typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
   583     MatrixCoord lane_offset = lane_layout.inverse(lane_id) * 
   584       MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN);
   601       coord.row() * Shape::kRow, 
   602       coord.column() * Shape::kColumn});
    629     Index pointer_offset) const {
    632     for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {
    634       for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) {
    636         Array<Element, Policy::LaneMmaShape::kM> const *src_ptr = 
    637           reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> const *>(
    638             ref_.data() + pointer_offset + ref_.offset({0, mma_n * Delta::kColumn + n}));
    641         for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {
   643           Array<Element, Policy::LaneMmaShape::kM> *dst_ptr = 
    644             reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> *>(&frag) + 
    645             mma_m + Iterations::kRow * (n + mma_n * Policy::LaneMmaShape::kN);
    647           *dst_ptr = src_ptr[mma_m * Policy::WarpShape::kRow];
   656     load_with_pointer_offset(frag, 0);
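The fragment index used in the load above is mma_m + Iterations::kRow * (n + mma_n * LaneMmaShape::kN), so the M iteration varies fastest; a small check of that mapping with hypothetical extents:

  // Hypothetical extents; the expression mirrors the indexing in lines 643-647.
  constexpr int kIterRow = 4;   // Iterations::kRow (assumed)
  constexpr int kLaneMmaN = 2;  // Policy::LaneMmaShape::kN (assumed)

  constexpr int fragment_index(int mma_m, int n, int mma_n) {
    return mma_m + kIterRow * (n + mma_n * kLaneMmaN);
  }

  static_assert(fragment_index(1, 0, 0) == 1, "mma_m varies fastest");
  static_assert(fragment_index(0, 1, 0) == 4, "next n starts after Iterations::kRow entries");
  static_assert(fragment_index(0, 0, 1) == 8, "next mma_n starts after kLaneMmaN * kIterRow entries");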
    664     for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {
    666       for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) {
    668         Array<Element, Policy::LaneMmaShape::kM> *dst_ptr = 
    669           reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> *>(
    670             ref_.data() + pointer_offset + ref_.offset({0, mma_n * Delta::kColumn + n}));
    673         for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {
    675           Array<Element, Policy::LaneMmaShape::kM> const *src_ptr = 
    676             reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> const *>(&frag) + 
    677             mma_m + Iterations::kRow * (n + mma_n * Policy::LaneMmaShape::kN);
   679           dst_ptr[mma_m * Policy::WarpShape::kRow] = *src_ptr;
   687     store_with_pointer_offset(frag, 0);
   740     (!(Shape::kRow % Policy::WarpShape::kRow)) && (!(Shape::kColumn % Policy::WarpShape::kColumn)),
   741     "Warp-level GEMM shape must be divisible by the arrangement of threads in the warp.");
    743   static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");
    744   static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");
    745   static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero.");
    746   static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero.");
    747   static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero.");
    748   static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero.");
   752     Shape::kRow / Policy::WarpShape::kRow,
   753     Shape::kColumn / Policy::WarpShape::kColumn
   757     (!(ThreadShape::kRow % Policy::LaneMmaShape::kM)) && (!(ThreadShape::kColumn % Policy::LaneMmaShape::kN)),
   758     "Warp-level GEMM shape must be divisible by the arrangement of threads in the warp.");
   762     ThreadShape::kRow / Policy::LaneMmaShape::kM,
   763     ThreadShape::kColumn / Policy::LaneMmaShape::kN
   767     Policy::WarpShape::kRow * Policy::LaneMmaShape::kM,
   768     Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN
   772   using Fragment = Array<Element, ThreadShape::kCount>;
   793     typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
   795     MatrixCoord lane_offset = lane_layout.inverse(lane_id) * 
   796       MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN);
   798     ref_.add_coord_offset(lane_offset);
   804     ref_.add_pointer_offset(offset);
   812     ref_.add_coord_offset({
   813       coord.row() * Shape::kRow, 
   814       coord.column() * Shape::kColumn});
   823     ref_.add_coord_offset({Shape::kRow, 0});
   832     ref_.add_coord_offset({-Shape::kRow, 0});
    841     Index pointer_offset) const {
    844     for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {
    846       for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) {
    848         Array<Element, Policy::LaneMmaShape::kN> const *src_ptr = 
    849           reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> const *>(
   850             ref_.data() + pointer_offset + ref_.offset({mma_m * Delta::kRow + m, 0}));
    853         for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {
   855           Array<Element, Policy::LaneMmaShape::kN> *dst_ptr = 
    856             reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(&frag) + 
   857             mma_n + Iterations::kColumn * (m + mma_m * Policy::LaneMmaShape::kM);
   859           *dst_ptr = src_ptr[mma_n * Policy::WarpShape::kColumn];
   868     load_with_pointer_offset(frag, 0);
    876     for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {
    878       for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) {
   880         Array<Element, Policy::LaneMmaShape::kN> *dst_ptr = 
    881           reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(
   882             ref_.data() + pointer_offset + ref_.offset({mma_m * Delta::kRow + m, 0}));
    885         for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {
    887           Array<Element, Policy::LaneMmaShape::kN> const *src_ptr = 
    888             reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> const *>(&frag) + 
   889             mma_n + Iterations::kColumn * (m + mma_m * Policy::LaneMmaShape::kM);
   891           dst_ptr[mma_n * Policy::WarpShape::kColumn] = *src_ptr;
   900     store_with_pointer_offset(frag, 0);
   922   int PartitionGroupSize
   955   static const int kInterleave = 4;
   958   static const int kPartitionsK = PartitionsK;
   961   static const int kGroupPerTile = PartitionGroupSize / Shape::kColumn;
   968     "The warp-level GEMM M size must be divisible by the number of threads arranged along the M dimension.");
    970   static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");
    971   static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");
    972   static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero.");
    973   static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero.");
   977     Shape::kRow / Policy::WarpShape::kRow,
   981   static_assert(!(ThreadShape::kRow % Policy::LaneMmaShape::kM) && !(ThreadShape::kColumn % Policy::LaneMmaShape::kK), 
   982     "Thread-level GEMM must be divisible by Policy::LaneMmaShape.");
   986     ThreadShape::kRow / Policy::LaneMmaShape::kM,
   987     ThreadShape::kColumn / Policy::LaneMmaShape::kK
   991   using Fragment = Array<Element, ThreadShape::kCount>;
  1013     typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
  1015     MatrixCoord lane_offset = lane_layout.inverse(lane_id) * 
   1021     ref_.reset(reinterpret_cast<Array<Element, Policy::LaneMmaShape::kMK> *>(ref.data()), ref.stride(0) / Policy::LaneMmaShape::kMK);
  1037       coord.row() * Shape::kRow / Policy::LaneMmaShape::kMK, 
  1038       coord.column() * Shape::kColumn});
  1047     add_tile_offset({0, 1});
  1049     if (kPartitionsK > 1) {
  1052       if (k_group_idx_ == kGroupPerTile) {
  1054         add_tile_offset({0, kGroupPerTile * (kPartitionsK-1)});
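A worked example of the k-group bookkeeping in operator++ above, with hypothetical split-K values (PartitionGroupSize, Shape::kColumn, and PartitionsK are assumed, not taken from a shipped configuration):

  // kGroupPerTile == PartitionGroupSize / Shape::kColumn (line 961).
  constexpr int kPartitionGroupSize = 32;  // PartitionGroupSize (assumed)
  constexpr int kShapeColumnK       = 8;   // Shape::kColumn (assumed)
  constexpr int kNumPartitionsK     = 2;   // PartitionsK (assumed)

  constexpr int kGroupsPerTile = kPartitionGroupSize / kShapeColumnK;  // == 4

  // After this partition consumes its 4 groups, operator++ adds
  // kGroupPerTile * (kPartitionsK - 1) == 4 extra tiles (line 1054), so the
  // next group it reads is tile 8, skipping the 4 tiles owned by the other
  // k-partition.
  static_assert(kGroupsPerTile == 4, "groups per tile in this example");
  static_assert(kGroupsPerTile * (kNumPartitionsK - 1) == 4, "size of the skipped region");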
   1074     Array<Element, Policy::LaneMmaShape::kMK> *dst_ptr = 
   1075       reinterpret_cast<Array<Element, Policy::LaneMmaShape::kMK> *>(&frag);
   1078     for (int k = 0; k < Iterations::kColumn; ++k) {
   1081       for (int m = 0; m < Iterations::kRow; ++m) {
  1083         dst_ptr[m + k * Iterations::kRow] = 
   1084           *((ref_.data() + ref_.offset({m * Policy::WarpShape::kRow / kInterleave, 
   1085                   k * Policy::LaneMmaShape::kK}) + pointer_offset / Policy::LaneMmaShape::kM));
  1093     load_with_pointer_offset(frag, 0);
   1100     Array<Element, Policy::LaneMmaShape::kMK> const *src_ptr = 
   1101       reinterpret_cast<Array<Element, Policy::LaneMmaShape::kMK> const *>(&frag);
   1104     for (int k = 0; k < Iterations::kColumn; ++k) {
   1106       for (int m = 0; m < Iterations::kRow; ++m) {
   1107         *(ref_.data() + ref_.offset({m * Policy::WarpShape::kRow, k}) + pointer_offset / Policy::LaneMmaShape::kM) = 
   1108           src_ptr[m + k * Iterations::kRow];
  1116     store_with_pointer_offset(frag, 0);
  1148   int PartitionGroupSize
  1181   static const int kInterleave = 4;
  1184   static const int kPartitionsK = PartitionsK;
  1187   static const int kGroupPerTile = PartitionGroupSize / Shape::kRow;
  1193   static_assert(!(Shape::kColumn % Policy::WarpShape::kColumn), 
  1194     "The warp-level GEMM N size must be divisible by the number of threads arranged along the N dimension.");
   1196   static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");
   1197   static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");
   1198   static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero.");
   1199   static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero.");
  1204     Shape::kColumn / Policy::WarpShape::kColumn
  1207   static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN) && !(ThreadShape::kRow % Policy::LaneMmaShape::kK), 
  1208     "Thread-level GEMM must be divisible by Policy::LaneMmaShape.");
  1212     ThreadShape::kRow / Policy::LaneMmaShape::kK,
  1213     ThreadShape::kColumn / Policy::LaneMmaShape::kN
  1242     typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
  1244     MatrixCoord lane_offset = lane_layout.inverse(lane_id) * 
   1252       reinterpret_cast<Array<Element, Policy::LaneMmaShape::kKN> *>(ref.data()),
   1253       ref.stride(0) / Policy::LaneMmaShape::kKN);
  1268       coord.row() * Shape::kRow, 
  1269       coord.column() * Shape::kColumn / Policy::LaneMmaShape::kKN});
  1278     add_tile_offset({1, 0});
  1280     if (kPartitionsK > 1) {
  1283       if (k_group_idx_ == kGroupPerTile) {
  1285         add_tile_offset({kGroupPerTile * (kPartitionsK-1), 0});
  1305     Array<Element, Policy::LaneMmaShape::kKN> *dst_ptr = 
   1306       reinterpret_cast<Array<Element, Policy::LaneMmaShape::kKN> *>(&frag);
   1309     for (int k = 0; k < Iterations::kRow; ++k) {
   1311       for (int n = 0; n < Iterations::kColumn; ++n) {
  1312         dst_ptr[n + k * Iterations::kColumn] = 
   1313           *(ref_.data() + ref_.offset({k * Policy::LaneMmaShape::kK, 
   1314                 n * Policy::WarpShape::kColumn / kInterleave}) + pointer_offset / Policy::LaneMmaShape::kN);
  1322     load_with_pointer_offset(frag, 0);
   1329     Array<Element, Policy::LaneMmaShape::kN> const *src_ptr = 
   1330       reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> const *>(&frag);
   1333     for (int k = 0; k < Iterations::kRow; ++k) {
   1335       for (int n = 0; n < Iterations::kColumn; ++n) {
   1336         *(ref_.data() + ref_.offset({k, n * Policy::WarpShape::kColumn}) + pointer_offset / Policy::LaneMmaShape::kN) = 
   1337           src_ptr[n + k * Iterations::kColumn];
  1345     store_with_pointer_offset(frag, 0);