58 namespace threadblock {
    64 template <
typename Shape, 
typename Element, 
typename Layout, 
int AdvanceRank,
    65           typename ThreadMap, 
typename AccessType>
    72 template <
typename Shape_, 
typename Element_, 
int AdvanceRank,
    73           typename ThreadMap_, 
typename AccessType_>
    75                                    AdvanceRank, ThreadMap_, AccessType_> {
    78       AdvanceRank == 0 || AdvanceRank == 1,
    79       "Specialization for pitch-linear iterator may along advance along the "    80       "contiguous(rank=0) or strided(rank=1) dimension.");
    83   using Element = Element_;
    85   static int const kAdvanceRank = AdvanceRank;
    86   using ThreadMap = ThreadMap_;
    99   static int const kPredicatesPerByte = 4;
   100   static int const kPredicatesPerWord = 4 * kPredicatesPerByte;
   103   static int const kPredicateByteCount = (ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kStrided + kPredicatesPerByte - 1) / kPredicatesPerByte;
   104   static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4;
   106   static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u;
   108   static_assert(kPredicateWordCount <= 4, 
"Too many predicates.");
   111   using Mask = Array<uint32_t, kPredicateWordCount>;
   135     Params(): stride_(0), inc_strided_(0), inc_next_(0), inc_advance_(0) { }
   139     Params(Layout 
const &layout) : stride_(layout.stride(0)) {
   142           (stride_ * ThreadMap::Delta::kStrided) * 
int(
sizeof(Element));
   146         inc_advance_ = Shape::kStrided * stride_ * int(
sizeof(Element));
   149         inc_advance_ = Shape::kContiguous * int(
sizeof(Element));
   152       inc_next_ = inc_advance_ - (ThreadMap::Iterations::kStrided - 1) *
   153                                      ThreadMap::Delta::kStrided * stride_ *
   154                                      int(
sizeof(Element));
   160   using BytePointer = 
char *;
   168   Params 
const ¶ms_;
   171   BytePointer pointer_;
   174   uint32_t predicates_[kPredicateWordCount];
   183   int residue_tile_idx_;
   186   bool is_residue_tile_;
   189   int iteration_contiguous_;
   192   int iteration_strided_;
   195   int iteration_thread_;
   200   void compute_predicates_(
   202       bool is_steady_state = 
false) {
   205     for (
int i = 0; i < kPredicateWordCount; ++i) {
   210     for (
int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
   212       for (
int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
   214         for (
int ts = 0; ts < ThreadMap::ThreadAccessShape::kStrided; ts++) {
   216           TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous,
   217                                       ts + s * ThreadMap::Delta::kStrided);
   219           TensorCoord coord = thread_offset_ + iteration_coord;
   223           if (is_steady_state) {
   224             if (kAdvanceRank == 0) {
   225               guard = (coord.strided() < extent_.strided());
   227               guard = (coord.contiguous() < extent_.contiguous());
   230             guard = (coord.strided() < extent_.strided() &&
   231                      coord.contiguous() < extent_.contiguous());
   234           int pred_idx = ts + c *  ThreadMap::ThreadAccessShape::kStrided + s * ThreadMap::Iterations::kContiguous *  ThreadMap::ThreadAccessShape::kStrided;
   235           int word_idx = pred_idx / kPredicatesPerWord;
   236           int residual = pred_idx % kPredicatesPerWord;
   237           int byte_idx = residual / kPredicatesPerByte;
   238           int bit_idx = residual % kPredicatesPerByte;
   240           predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx));
   254       Params 
const ¶ms,
   264         pointer_(reinterpret_cast<BytePointer>(
   267         is_residue_tile_(true) {
   273           (extent_[kAdvanceRank] - threadblock_offset[kAdvanceRank] - 1) /
   275       residue_offset = 
make_Coord(0, residue_tile_idx_ * Shape::kStrided);
   278           (extent_[kAdvanceRank] - threadblock_offset[kAdvanceRank] - 1) /
   280       residue_offset = 
make_Coord(residue_tile_idx_ * Shape::kContiguous, 0);
   284     thread_offset_ = threadblock_offset + residue_offset +
   285                      ThreadMap::initial_offset(thread_id);
   288     Layout layout(params_.stride_);
   289     add_pointer_offset(layout(thread_offset_));
   291     compute_predicates_(
false);
   293     set_iteration_index(0);
   300       Params 
const ¶ms,
   314     int residual = index % (ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided);
   315     iteration_strided_ = index / (ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided);
   317     iteration_contiguous_ = residual / ThreadMap::ThreadAccessShape::kStrided;
   318     iteration_thread_ = residual % ThreadMap::ThreadAccessShape::kStrided;
   325     pointer_ += int(
sizeof(Element)) * pointer_offset;
   332     if (is_residue_tile_) {
   335         residue_offset = 
TensorCoord(0, residue_tile_idx_ * Shape::kStrided);
   337         residue_offset = 
TensorCoord(residue_tile_idx_ * Shape::kContiguous, 0);
   340       thread_offset_ -= residue_offset;
   342       Layout layout(params_.stride_);
   343       add_pointer_offset(-layout(residue_offset));
   345       compute_predicates_(
true);
   348         pointer_ += params_.inc_advance_ * (tile_offset.strided() - 1);
   349         pointer_ += Shape::kContiguous * tile_offset.contiguous();
   351         pointer_ += params_.inc_advance_ * (tile_offset.contiguous() - 1);
   352         pointer_ += Shape::kStrided * tile_offset.strided();
   356         pointer_ += params_.inc_advance_ * tile_offset.strided();
   357         pointer_ += Shape::kContiguous * tile_offset.contiguous();
   359         pointer_ += params_.inc_advance_ * tile_offset.contiguous();
   360         pointer_ += Shape::kStrided * tile_offset.strided();
   363     is_residue_tile_ = 
false;
   370                 pointer_ + (iteration_thread_ * params_.stride_  + iteration_contiguous_ * ThreadMap::Delta::kContiguous) * 
int(
sizeof(Element)));
   381     if (iteration_thread_ < ThreadMap::ThreadAccessShape::kStrided)
   384     iteration_thread_ = 0;
   386     ++iteration_contiguous_;
   388     if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous)
   393     iteration_contiguous_ = 0;
   394     ++iteration_strided_;
   396     if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
   397       pointer_ += params_.inc_strided_;
   403     iteration_strided_ = 0;
   406     pointer_ += params_.inc_next_;
   411     pointer_ -= params_.inc_advance_;
   428     for (
int i = 0; i < kPredicateWordCount; ++i) {
   438     for (
int i = 0; i < kPredicateWordCount; ++i) {
   439       predicates_[i] = 0xffffffff;
   447     for (
int i = 0; i < kPredicateWordCount; ++i) {
   448       predicates_[i] = mask[i];
   457     for (
int i = 0; i < kPredicateWordCount; ++i) {
   458       mask[i] = predicates_[i];
   468       iteration_contiguous_ * ThreadMap::ThreadAccessShape::kStrided + 
   469       iteration_strided_ * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided;
   471     int word_idx = pred_idx / kPredicatesPerWord;
   472     int residual = pred_idx % kPredicatesPerWord;
   473     int byte_idx = residual / kPredicatesPerByte;
   474     int bit_idx = residual % kPredicatesPerByte;
   476     bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0;
   491 template <
typename Shape_, 
typename Element_, 
int AdvanceRank,
   492           typename ThreadMap_, 
typename AccessType_>
   494                                    AdvanceRank, ThreadMap_, AccessType_> {
   497       AdvanceRank == 0 || AdvanceRank == 1,
   498       "Specialization for pitch-linear iterator may along advance along the "   499       "contiguous(rank=0) or strided(rank=1) dimension.");
   502   using Element = Element_;
   504   static int const kAdvanceRank = AdvanceRank;
   505   using ThreadMap = ThreadMap_;
   523   using Mask = 
typename UnderlyingIterator::Mask;
   531     typename UnderlyingIterator::Params params_;
   542         : params_(layout::PitchLinear(layout.stride(0))){};
   559       Params 
const ¶ms,
   568       : iterator_(params.params_, pointer,
   569                   layout::PitchLinearCoord(extent.row(), extent.column()),
   571                   layout::PitchLinearCoord(threadblock_offset.row(),
   572                                            threadblock_offset.column())) {}
   577       Params 
const ¶ms,  
   592     iterator_.add_pointer_offset(pointer_offset);
   599     iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
   605     return reinterpret_cast<AccessType *
>(iterator_.get());
   652     return iterator_.valid();
   665 template <
typename Shape_, 
typename Element_, 
int AdvanceRank,
   666           typename ThreadMap_, 
typename AccessType_>
   668                                    AdvanceRank, ThreadMap_, AccessType_> {
   671       AdvanceRank == 0 || AdvanceRank == 1,
   672       "Specialization for pitch-linear iterator may along advance along the "   673       "contiguous(rank=0) or strided(rank=1) dimension.");
   676   using Element = Element_;
   678   static int const kAdvanceRank = AdvanceRank;
   679   using ThreadMap = ThreadMap_;
   697   using Mask = 
typename UnderlyingIterator::Mask;
   705     typename UnderlyingIterator::Params params_;
   716         : params_(layout::PitchLinear(layout.stride(0))){};
   733       Params 
const ¶ms,
   742       : iterator_(params.params_, pointer,
   743                   layout::PitchLinearCoord(extent.column(), extent.row()),
   745                   layout::PitchLinearCoord(threadblock_offset.column(),
   746                                            threadblock_offset.row())) {}
   751       Params 
const ¶ms,  
   766     iterator_.add_pointer_offset(pointer_offset);
   773     iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
   779     return reinterpret_cast<AccessType *
>(iterator_.get());
   826     return iterator_.valid();
 int64_t LongIndex
Long index type used for offsets. 
Definition: layout/matrix.h:62
Definition: aligned_buffer.h:35
Coordinate in pitch-linear space. 
Definition: pitch_linear.h:52
Defines a structure containing strides, bounds, and a pointer to tensor data. 
Mapping function for pitch-linear memory. 
Definition: pitch_linear.h:163
A Coord is a coordinate of arbitrary rank into a tensor or matrix. 
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 2-element coordinate. 
Definition: coord.h:387
int64_t LongIndex
Long index type used for offsets. 
Definition: layout/matrix.h:154
Defines a structure containing strides and a pointer to tensor data. 
Mapping function for column-major matrices. 
Definition: layout/matrix.h:142
Template defining a shape used by pitch-linear operators. 
Definition: pitch_linear.h:43
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Defines container classes and iterators for managing a statically sized vector of boolean predicates...
int32_t Index
Index type used for coordinates. 
Definition: layout/matrix.h:59
CUTLASS_HOST_DEVICE half_t & operator++(half_t &lhs)
Definition: half.h:694
int64_t LongIndex
Long index type used for offsets. 
Definition: pitch_linear.h:175
Defines a Shape template for matrix tiles. 
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
int32_t Index
Index type used for coordinates. 
Definition: pitch_linear.h:172
Mapping function for row-major matrices. 
Definition: layout/matrix.h:50
Defines layout functions used by TensorRef and derived classes. 
Defines layout functions used by TensorRef and derived classes for pitch-linear memory. 
int32_t Index
Index type used for coordinates. 
Definition: layout/matrix.h:151
Basic include for CUTLASS. 
Definition: matrix_coord.h:39