58 namespace threadblock {
    64 template <
typename Shape, 
typename Element, 
typename Layout, 
int AdvanceRank,
    65           typename ThreadMap, 
typename AccessType>
    72 template <
typename Shape_, 
typename Element_, 
int AdvanceRank,
    73           typename ThreadMap_, 
typename AccessType_>
    75                                    AdvanceRank, ThreadMap_, AccessType_> {
    78       AdvanceRank == 0 || AdvanceRank == 1,
    79       "Specialization for pitch-linear iterator may along advance along the "    80       "contiguous(rank=0) or strided(rank=1) dimension.");
    83   using Element = Element_;
    85   static int const kAdvanceRank = AdvanceRank;
    86   using ThreadMap = ThreadMap_;
    99   static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
   101   static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), 
   102     "Vectors implied by the thread map must be divisible by the access type.");
   104   static int const kPredicatesPerByte = 4;
   105   static int const kPredicatesPerWord = 4 * kPredicatesPerByte;
   107   static int const kPredicateCount = ThreadMap::Iterations::kCount * kAccessesPerVector;
   110   static int const kPredicateByteCount = 
   111     (kPredicateCount + kPredicatesPerByte - 1) / kPredicatesPerByte;
   112   static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4;
   114   static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u;
   116   static_assert(kPredicateWordCount <= 4, 
"Too many predicates.");
   119   using Mask = Array<uint32_t, kPredicateWordCount>;
   143     Params(): stride_(0), inc_strided_(0), inc_next_(0), inc_advance_(0) { }
   147     Params(Layout 
const &layout) : stride_(layout.stride(0)) {
   148       inc_strided_ = (stride_ * ThreadMap::Delta::kStrided) *
   160       inc_next_ = inc_advance_ - (ThreadMap::Iterations::kStrided - 1) *
   161                                      ThreadMap::Delta::kStrided * stride_ *
   168   using BytePointer = 
char *;
   176   Params 
const &params_;
   179   BytePointer pointer_;
   182   uint32_t predicates_[kPredicateWordCount];
   194   bool is_residue_tile_;
   197   int iteration_vector_;
   200   int iteration_contiguous_;
   203   int iteration_strided_;
   208   void compute_predicates_(
   212       bool is_steady_state = 
false) {
   215     for (
int i = 0; i < kPredicateWordCount; ++i) {
   219     for (
int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) {
   221       int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector);
   223       int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector);
   225       int c = access_residual / kAccessesPerVector;
   226       int v = access_residual % kAccessesPerVector;
   228       TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements,
   229                                 s * ThreadMap::Delta::kStrided);
   231       TensorCoord coord = thread_offset_ + iteration_coord;
   235       if (is_steady_state) {
   236         if (kAdvanceRank == 0) {
   237           guard = (coord.strided() < extent.strided());
   239           guard = (coord.contiguous() < extent.contiguous());
   242         guard = (coord.strided() < extent.strided() &&
   243                  coord.contiguous() < extent.contiguous());
   246       int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s);
   248       int word_idx = pred_idx / kPredicatesPerWord;
   249       int residual = pred_idx % kPredicatesPerWord;
   250       int byte_idx = residual / kPredicatesPerByte;
   251       int bit_idx = residual % kPredicatesPerByte;
   253       predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx));
   265       Params 
const &params,
   275         pointer_(reinterpret_cast<BytePointer>(
   278         is_residue_tile_(true) {
   283       Index residue_size = (extent_[kAdvanceRank] % Shape::kStrided);
   285         residue_size = Shape::kStrided;
   288       residue_offset_ = 
make_Coord(0, residue_size);
   290         extent_.contiguous(), 
   291         min(threadblock_offset.strided() + residue_offset_.strided(), extent_.strided())
   296       Index residue_size = (extent_[kAdvanceRank] % Shape::kContiguous);
   298         residue_size = Shape::kContiguous;
   300       residue_offset_ = 
make_Coord(residue_size, 0);
   302         min(extent_.contiguous(), threadblock_offset.contiguous() + residue_offset_.contiguous()),
   308     thread_offset_ = threadblock_offset + ThreadMap::initial_offset(thread_id);
   311     Layout layout(params_.stride_);
   312     add_pointer_offset(layout(thread_offset_));
   314     compute_predicates_(residue_extent, 
false);
   316     set_iteration_index(0);
   323       Params 
const &params,
   337     iteration_vector_ = index % kAccessesPerVector;
   338     int residual_access = index / kAccessesPerVector;
   340     iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
   341     iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
   355     if (is_residue_tile_) {
   357       thread_offset_ += residue_offset_;
   359       Layout layout(params_.stride_);
   360       add_pointer_offset(layout(residue_offset_));
   362       compute_predicates_(extent_, 
true);
   365         pointer_ += params_.inc_advance_ * (tile_offset.strided() - 1);
   366         pointer_ += Shape::kContiguous * tile_offset.contiguous();
   368         pointer_ += params_.inc_advance_ * (tile_offset.contiguous() - 1);
   369         pointer_ += Shape::kStrided * tile_offset.strided();
   373         pointer_ += params_.inc_advance_ * tile_offset.strided();
   374         pointer_ += Shape::kContiguous * tile_offset.contiguous();
   376         pointer_ += params_.inc_advance_ * tile_offset.contiguous();
   377         pointer_ += Shape::kStrided * tile_offset.strided();
   380     is_residue_tile_ = 
false;
   396     if (iteration_vector_ < kAccessesPerVector) {
   400     iteration_vector_ = 0;
   401     ++iteration_contiguous_;
   403     if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
   409     iteration_contiguous_ = 0;
   410     ++iteration_strided_;
   412     if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
   413       pointer_ += params_.inc_strided_;
   419     iteration_strided_ = 0;
   422     pointer_ += params_.inc_next_;
   427     pointer_ -= params_.inc_advance_;
   444     for (
int i = 0; i < kPredicateWordCount; ++i) {
   454     for (
int i = 0; i < kPredicateWordCount; ++i) {
   455       predicates_[i] = 0xffffffff;
   463     for (
int i = 0; i < kPredicateWordCount; ++i) {
   464       predicates_[i] = mask[i];
   473     for (
int i = 0; i < kPredicateWordCount; ++i) {
   474       mask[i] = predicates_[i];
   484       iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous);
   486     int word_idx = pred_idx / kPredicatesPerWord;
   487     int residual = pred_idx % kPredicatesPerWord;
   488     int byte_idx = residual / kPredicatesPerByte;
   489     int bit_idx = residual % kPredicatesPerByte;
   491     bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0;
   508 template <
typename Shape_, 
typename Element_, 
int AdvanceRank,
   509           typename ThreadMap_, 
typename AccessType_>
   511                                    AdvanceRank, ThreadMap_, AccessType_> {
   514       AdvanceRank == 0 || AdvanceRank == 1,
   515       "Specialization for pitch-linear iterator may along advance along the "   516       "contiguous(rank=0) or strided(rank=1) dimension.");
   519   using Element = Element_;
   521   static int const kAdvanceRank = AdvanceRank;
   522   using ThreadMap = ThreadMap_;
   540   using Mask = 
typename UnderlyingIterator::Mask;
   542   static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
   550     typename UnderlyingIterator::Params params_;
   561         : params_(layout::PitchLinear(layout.stride(0))){};
   578       Params 
const &params,
   587       : iterator_(params.params_, pointer,
   588                   layout::PitchLinearCoord(extent.row(), extent.column()),
   590                   layout::PitchLinearCoord(threadblock_offset.row(),
   591                                            threadblock_offset.column())) {}
   596       Params 
const &params,
   611     iterator_.add_pointer_offset(pointer_offset);
   618     iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
   624     return reinterpret_cast<AccessType *
>(iterator_.get());
   671     return iterator_.valid();
   684 template <
typename Shape_, 
typename Element_, 
int AdvanceRank,
   685           typename ThreadMap_, 
typename AccessType_>
   687                                    AdvanceRank, ThreadMap_, AccessType_> {
   690       AdvanceRank == 0 || AdvanceRank == 1,
   691       "Specialization for pitch-linear iterator may along advance along the "   692       "contiguous(rank=0) or strided(rank=1) dimension.");
   695   using Element = Element_;
   697   static int const kAdvanceRank = AdvanceRank;
   698   using ThreadMap = ThreadMap_;
   715   static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
   718   using Mask = 
typename UnderlyingIterator::Mask;
   726     typename UnderlyingIterator::Params params_;
   737         : params_(layout::PitchLinear(layout.stride(0))){};
   754       Params 
const &params,
   763       : iterator_(params.params_, pointer,
   764                   layout::PitchLinearCoord(extent.column(), extent.row()),
   766                   layout::PitchLinearCoord(threadblock_offset.column(),
   767                                            threadblock_offset.row())) {}
   772       Params 
const &params,
   787     iterator_.add_pointer_offset(pointer_offset);
   794     iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
   800     return reinterpret_cast<AccessType *
>(iterator_.get());
   847     return iterator_.valid();
   862 template <
typename Shape_, 
typename Element_, 
int AdvanceRank,
   863           typename ThreadMap_, 
typename AccessType_, 
int InterleavedK>
   865                                    layout::ColumnMajorInterleaved<InterleavedK>,
   866                                    AdvanceRank, ThreadMap_, AccessType_> {
   869       AdvanceRank == 0 || AdvanceRank == 1,
   870       "Specialization for pitch-linear iterator may along advance along the "   871       "contiguous(rank=0) or strided(rank=1) dimension.");
   874   using Element = Element_;
   875   static int const kInterleavedK = InterleavedK;
   877   static int const kAdvanceRank = AdvanceRank;
   878   using ThreadMap = ThreadMap_;
   893                                Shape::kColumn / kInterleavedK>,
   897   static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
   900   using Mask = 
typename UnderlyingIterator::Mask;
   908     typename UnderlyingIterator::Params params_;
   917         : params_(layout::PitchLinear(layout.stride(0))) {}
   934       Params 
const &params,
   943       : iterator_(params.params_, pointer,
   944                   layout::PitchLinearCoord(extent.row() * kInterleavedK,
   945                                            extent.column() / kInterleavedK),
   947                   layout::PitchLinearCoord(
   948                       threadblock_offset.row() * kInterleavedK,
   949                       threadblock_offset.column() / kInterleavedK)) {}
   954       Params 
const &params,
   969     iterator_.add_pointer_offset(pointer_offset);
   976     iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
   982     return reinterpret_cast<AccessType *
>(iterator_.get());
  1028   bool valid() { 
return iterator_.valid(); }
  1041 template <
typename Shape_, 
typename Element_, 
int AdvanceRank,
  1042           typename ThreadMap_, 
typename AccessType_, 
int InterleavedK>
  1044                                    layout::RowMajorInterleaved<InterleavedK>,
  1045                                    AdvanceRank, ThreadMap_, AccessType_> {
  1048       AdvanceRank == 0 || AdvanceRank == 1,
  1049       "Specialization for pitch-linear iterator may along advance along the "  1050       "contiguous(rank=0) or strided(rank=1) dimension.");
  1053   using Element = Element_;
  1054   static int const kInterleavedK = InterleavedK;
  1056   static int const kAdvanceRank = AdvanceRank;
  1057   using ThreadMap = ThreadMap_;
  1072                                Shape::kRow / kInterleavedK>,
  1077   static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
  1080   using Mask = 
typename UnderlyingIterator::Mask;
  1088     typename UnderlyingIterator::Params params_;
  1097         : params_(layout::PitchLinear(layout.stride(0))) {}
  1114       Params 
const &params,
  1123       : iterator_(params.params_, pointer,
  1124                   layout::PitchLinearCoord(extent.column() * kInterleavedK,
  1125                                            extent.row() / kInterleavedK),
  1127                   layout::PitchLinearCoord(
  1128                       threadblock_offset.column() * kInterleavedK,
  1129                       threadblock_offset.row() / kInterleavedK)) {}
  1134       Params 
const &params,
  1149     iterator_.add_pointer_offset(pointer_offset);
  1156     iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
  1162     return reinterpret_cast<AccessType *
>(iterator_.get());
  1208   bool valid() { 
return iterator_.valid(); }
 
int64_t LongIndex
Long index type used for offsets. 
Definition: layout/matrix.h:62
int64_t LongIndex
Long index type used for offsets. 
Definition: layout/matrix.h:355
Definition: aligned_buffer.h:35
Coordinate in pitch-linear space. 
Definition: pitch_linear.h:52
Defines a structure containing strides, bounds, and a pointer to tensor data. 
int64_t LongIndex
Long index type used for offsets. 
Definition: layout/matrix.h:249
Mapping function for pitch-linear memory. 
Definition: pitch_linear.h:163
int32_t Index
Index type used for coordinates. 
Definition: layout/matrix.h:352
A Coord is a coordinate of arbitrary rank into a tensor or matrix. 
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate. 
Definition: coord.h:387
int64_t LongIndex
Long index type used for offsets. 
Definition: layout/matrix.h:154
int32_t Index
Index type used for coordinates. 
Definition: layout/matrix.h:246
Defines a structure containing strides and a pointer to tensor data. 
Mapping function for column-major matrices. 
Definition: layout/matrix.h:142
Template defining a shape used by pitch-linear operators. 
Definition: pitch_linear.h:43
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Defines container classes and iterators for managing a statically sized vector of boolean predicates...
int32_t Index
Index type used for coordinates. 
Definition: layout/matrix.h:59
CUTLASS_HOST_DEVICE half_t & operator++(half_t &lhs)
Definition: half.h:694
int64_t LongIndex
Long index type used for offsets. 
Definition: pitch_linear.h:175
Defines a Shape template for matrix tiles. 
Defines the size of an element in bits. 
Definition: numeric_types.h:42
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
int32_t Index
Index type used for coordinates. 
Definition: pitch_linear.h:172
Mapping function for row-major matrices. 
Definition: layout/matrix.h:50
Defines layout functions used by TensorRef and derived classes. 
Defines layout functions used by TensorRef and derived classes for pitch-linear memory. 
Definition: layout/matrix.h:343
int32_t Index
Index type used for coordinates. 
Definition: layout/matrix.h:151
Basic include for CUTLASS. 
Definition: matrix_coord.h:39
Definition: layout/matrix.h:237