45 namespace threadblock {
    59   static int const kRow = Row;
    64   static int const kCount = kColumn * kRow * kGroup * kCluster * 
kTile;
    82   static int const kThreads = ThreadMap::kThreads;
    85   static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
   107     Index cluster = coord.
strided() / (Shape::kGroup * Shape::kRow);
   108     Index cluster_residual = coord.
strided() % (Shape::kGroup * Shape::kRow);
   110     Index group = cluster_residual / (Shape::kRow);
   111     Index row = cluster_residual % (Shape::kRow);
   114       row + group * Shape::kRow * Count::kRow 
   115         + cluster * Shape::kGroup * Count::kGroup * Shape::kRow * Count::kRow,
   129   int ElementsPerAccess,
   139   int ElementsPerAccess,
   142 struct RowArrangement<Shape, WarpsRemaining, ElementsPerAccess, ElementSize, false> {
   143   static int const kWarpSize = 32;
   144   static int const kElementsPerAccess = ElementsPerAccess;
   145   static int const kElementSize = ElementSize;
   147   static int const kIterationsRow = 1;
   148   static int const kDeltaRow = 1;
   149   static int const kIterationsColumn = Shape::kColumn / kElementsPerAccess / kWarpSize;
   150   static int const kDeltaColumn = kWarpSize * kElementsPerAccess;
   152   static int const kAccessWidth = kWarpSize;
   153   static int const kAccessRows = 1;
   154   static int const kWarpPartitionsRow = 1;
   155   static int const kWarpPartitionsColumn = WarpsRemaining;
   162   int ElementsPerAccess,
   165 struct RowArrangement<Shape, WarpsRemaining, ElementsPerAccess, ElementSize, true> {
   167   static int const kMemoryAccessSize = 128;
   168   static int const kWarpSize = 32;
   170   static int const kElementsPerAccess = ElementsPerAccess;
   171   static int const kElementSize = ElementSize;
   174     static int const kShapeRow = Shape::kRow / WarpsRemaining;
   175     static int const kShapeWidth = Shape::kColumn / kElementsPerAccess;
   177     static int const kTargetMemoryAccessWidth = 
   178       kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8);
   180     static int const kTargetAccessRows = kWarpSize / kTargetMemoryAccessWidth;
   183   static int const kAccessWidth = 
   184     (Detail::kTargetAccessRows > Detail::kShapeRow ?
   185       kWarpSize / Detail::kShapeRow
   188         const_min(kWarpSize, kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8))
   191   static int const kAccessRows =
   192     (Detail::kTargetAccessRows > Detail::kShapeRow ?
   194       : 
const_min(Shape::kRow, kWarpSize / kAccessWidth));
   196   static int const kIterationsRow = Detail::kShapeRow / kAccessRows;
   197   static int const kDeltaRow = kAccessRows;
   199   static int const kIterationsColumn = Detail::kShapeWidth / kAccessWidth;
   200   static int const kDeltaColumn = kAccessWidth * kElementsPerAccess;
   202   static_assert( kAccessWidth * kElementsPerAccess <= Shape::kColumn, 
"Accessing too many elements per access");
   203   static_assert( kIterationsColumn > 0, 
"Iteration Count Column must be > 0" );
   204   static_assert( kIterationsRow > 0, 
"Iteration Count Row must be > 0" );
   206   static int const kWarpPartitionsRow = 1;
   207   static int const kWarpPartitionsColumn = 1;
   225   int ElementsPerAccess,
   233   static int const kWarpSize = 32;
   234   static int const kThreads = Threads;
   235   static int const kWarpCount = kThreads / kWarpSize;
   237   static int const kElementsPerAccess = ElementsPerAccess;
   238   static int const kElementSize = ElementSize;
   247     static int const kIterationsCluster = 
   248       ((Shape::kCluster > kWarpCount) ?
   249         Shape::kCluster / kWarpCount
   252     static int const kDeltaCluster =
   253       ((Shape::kCluster > kWarpCount) ?
   254         Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup * Shape::kCluster / kIterationsCluster
   257     static int const kCompactedDeltaCluster =
   258       ((Shape::kCluster > kWarpCount) ?
   259         Shape::kRow * Shape::kGroup * Shape::kCluster / kIterationsCluster
   262     static int const kWarpPartitionsCluster =
   263       ((Shape::kCluster > kWarpCount) ?
   265         : kWarpCount / Shape::kCluster);
   267     static int const kWarpsRemainingForGroups =
   268       ((Shape::kCluster > kWarpCount) ? 1 : kWarpCount / Shape::kCluster);
   271     static int const kIterationsGroup =
   272       ((Shape::kGroup > kWarpsRemainingForGroups) ?
   273         Shape::kGroup / kWarpsRemainingForGroups
   276     static int const kDeltaGroup =
   277       ((Shape::kGroup > kWarpsRemainingForGroups) ?
   278         Shape::kRow * Count::kRow * Shape::kGroup / kIterationsGroup
   281     static int const kCompactedDeltaGroup =
   282       ((Shape::kGroup > kWarpsRemainingForGroups) ?
   283         Shape::kRow * Shape::kGroup / kIterationsGroup
   286     static int const kWarpPartitionsGroup =
   287       ((Shape::kGroup > kWarpsRemainingForGroups) ?
   289         : kWarpsRemainingForGroups / Shape::kGroup);
   291     static int const kWarpsRemainingForRows =
   292       ((Shape::kGroup > kWarpsRemainingForGroups) ?
   294         : kWarpsRemainingForGroups / Shape::kGroup);
   299       kWarpsRemainingForRows,
   302       (Shape::kRow > kWarpsRemainingForRows)
   307       RowArrangement::kWarpPartitionsColumn,
   308       RowArrangement::kWarpPartitionsRow,
   309       kWarpPartitionsGroup,
   310       kWarpPartitionsCluster,
   313     static int const kAccessWidth = RowArrangement::kAccessWidth;
   314     static int const kAccessRows = RowArrangement::kAccessRows;
   322     Detail::RowArrangement::kIterationsColumn, 
   323     Detail::RowArrangement::kIterationsRow, 
   324     Detail::kIterationsGroup, 
   325     Detail::kIterationsCluster, 
   329     Detail::RowArrangement::kDeltaColumn,
   330     Detail::RowArrangement::kDeltaRow,
   332     Detail::kDeltaCluster,
   339     int warp_idx = thread_idx / kWarpSize;
   340     int lane_idx = thread_idx % kWarpSize;
   343     int cluster_idx = warp_idx / Detail::WarpPartitions::kCluster;
   344     int residual_cluster = warp_idx % Detail::WarpPartitions::kCluster;
   346     int group_idx = residual_cluster / Detail::WarpPartitions::kGroup;
   347     int residual_group = residual_cluster % Detail::WarpPartitions::kGroup;
   349     int row_idx = residual_group / Detail::WarpPartitions::kRow;
   350     int col_idx = residual_group % Detail::WarpPartitions::kRow;
   353     int lane_row_offset = lane_idx / Detail::kAccessWidth;
   354     int lane_col_offset = lane_idx % Detail::kAccessWidth;
   357     int cluster_offset = cluster_idx * Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup;
   358     int group_offset = group_idx * Shape::kRow * Count::kRow;
   359     int row_offset = row_idx * Iterations::kRow * Detail::kAccessRows;
   360     int column_offset = col_idx * Iterations::kColumn * Detail::kAccessWidth * kElementsPerAccess;
   363       cluster_offset + group_offset + row_offset + lane_row_offset,
   364       (column_offset + lane_col_offset) * kElementsPerAccess
   375       Detail::RowArrangement::kIterationsColumn,
   376       Detail::RowArrangement::kIterationsRow,
   377       Detail::kIterationsGroup,
   378       Detail::kIterationsCluster,
   382       Detail::RowArrangement::kDeltaColumn,
   383       Detail::RowArrangement::kDeltaRow,
   384       Detail::kCompactedDeltaGroup,
   385       Detail::kCompactedDeltaCluster,
   389     static int const kElementsPerAccess = ElementsPerAccess;
   392     static int const kThreads = Threads;
   398       int warp_idx = thread_idx / kWarpSize;
   399       int lane_idx = thread_idx % kWarpSize;
   402       int cluster_idx = warp_idx / Detail::WarpPartitions::kCluster;
   403       int residual_cluster = warp_idx % Detail::WarpPartitions::kCluster;
   405       int group_idx = residual_cluster / Detail::WarpPartitions::kGroup;
   406       int residual_group = residual_cluster % Detail::WarpPartitions::kGroup;
   408       int row_idx = residual_group / Detail::WarpPartitions::kRow;
   409       int col_idx = residual_group % Detail::WarpPartitions::kRow;
   412       int lane_row_offset = lane_idx / Detail::kAccessWidth;
   413       int lane_col_offset = lane_idx % Detail::kAccessWidth;
   416       int cluster_offset = cluster_idx * Shape::kRow * Shape::kGroup;
   417       int group_offset = group_idx * Shape::kRow;
   418       int row_offset = row_idx * Iterations::kRow * Detail::kAccessRows;
   419       int column_offset = col_idx * Iterations::kColumn * Detail::kAccessWidth * kElementsPerAccess;
   422         cluster_offset + group_offset + row_offset + lane_row_offset,
   423         (column_offset + lane_col_offset) * kElementsPerAccess
   440 template <
typename WarpCount_, 
typename MmaCount_, 
int Threads,
   441           int ElementsPerAccess, 
int ElementSize>
   446   static int const kWarpSize = 32;
   447   static int const kThreads = Threads;
   448   static int const kWarpCount = kThreads / kWarpSize;
   450   static int const kElementsPerAccess = ElementsPerAccess;
   451   static int const kElementSize = ElementSize;
   470     int warp_idx = thread_idx / kWarpSize;
   471     int lane_idx = thread_idx % kWarpSize;
   475         Delta::kContiguous * Iterations::kContiguous,
   476         Delta::kStrided * Iterations::kStrided};
   479                                          warp_idx / WarpCount::kContiguous};
   483         lane_idx * kElementsPerAccess, 0};
   486         warp_footprint * warp_offset + thread_offset_in_warp;
   488     return thread_offset_in_threadblock_tile;
 int Index
Integer-valued index. 
Definition: pitch_linear.h:56
ThreadMap_ ThreadMap
Conventional thread map (concept: ThreadMap) 
Definition: output_tile_thread_map.h:79
Definition: output_tile_thread_map.h:228
Definition: aligned_buffer.h:35
Coordinate in pitch-linear space. 
Definition: pitch_linear.h:52
Defines a structure containing strides, bounds, and a pointer to tensor data. 
Count_ Count
Definition: output_tile_thread_map.h:231
static int const kGroup
Definition: output_tile_thread_map.h:60
Tuple defining point in output tile. 
Definition: output_tile_thread_map.h:57
WarpCount_ WarpCount
Definition: output_tile_thread_map.h:443
Iterations_ Iterations
Iterations performed by each thread. 
Definition: output_tile_thread_map.h:91
static int const kColumn
Definition: output_tile_thread_map.h:58
RowArrangement determines how one or more warps cover a region of consecutive rows. 
Definition: output_tile_thread_map.h:133
Definition: output_tile_thread_map.h:442
Template defining a shape used by pitch-linear operators. 
Definition: pitch_linear.h:43
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe to use in a union.
Compacted thread map in which the 4D region is contiguous. 
Definition: output_tile_thread_map.h:369
Count_ Count
Number of iterator iterations. 
Definition: output_tile_thread_map.h:97
static CUTLASS_HOST_DEVICE MatrixCoord initial_offset(int thread_idx)
Function to compute each thread's initial offset. 
Definition: output_tile_thread_map.h:396
Defines a Shape template for matrix tiles. 
Shape_ Shape
Definition: output_tile_thread_map.h:230
static CUTLASS_HOST_DEVICE layout::PitchLinearCoord initial_offset(int thread_idx)
Initial offset function. 
Definition: output_tile_thread_map.h:469
detail::RowArrangement< Shape, kWarpsRemainingForRows, kElementsPerAccess, kElementSize,(Shape::kRow > kWarpsRemainingForRows) > RowArrangement
Definition: output_tile_thread_map.h:303
MmaCount Iterations
Definition: output_tile_thread_map.h:463
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types. 
CUTLASS_HOST_DEVICE Index const & contiguous() const 
Returns the contiguous dimension. 
Definition: pitch_linear.h:89
Shape_ Shape
Definition: output_tile_thread_map.h:372
Delta_ Delta
Delta between accesses. 
Definition: output_tile_thread_map.h:94
static CUTLASS_HOST_DEVICE MatrixCoord initial_offset(int thread_idx)
Initial offset function. 
Definition: output_tile_thread_map.h:101
Definition: output_tile_thread_map.h:457
static int const kRow
Definition: output_tile_thread_map.h:59
Defines layout functions used by TensorRef and derived classes. 
Definition: output_tile_thread_map.h:76
Shape_ Shape
Shape of the tile. 
Definition: output_tile_thread_map.h:88
static int const kTile
Definition: output_tile_thread_map.h:62
static int const kCount
Definition: output_tile_thread_map.h:64
MmaCount_ MmaCount
Definition: output_tile_thread_map.h:444
static CUTLASS_HOST_DEVICE MatrixCoord initial_offset(int thread_idx)
Initial offset function. 
Definition: output_tile_thread_map.h:337
CUTLASS_HOST_DEVICE constexpr int const_min(int a, int b)
Definition: fast_math.h:219
Basic include for CUTLASS. 
Definition: matrix_coord.h:39
Definition: output_tile_thread_map.h:244
static int const kCluster
Definition: output_tile_thread_map.h:61
CUTLASS_HOST_DEVICE Index const & strided() const 
Returns the strided dimension.
Definition: pitch_linear.h:97