48 namespace threadblock {
    74   using Element = Element_;
    76   static int const kAdvanceRank = AdvanceRank;
    77   using ThreadMap = ThreadMap_;
    78   static int const kAlignment = Alignment;
    86   using Fragment = Array<Element, ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kCount>;
    89     "Advance rank may only be along the contiguous or strided dimensions.");
   110   Index increment_strided_;
   113   Index increment_advance_;
   127     TensorCoord t = ThreadMap::initial_offset(thread_idx);
   128     long int offset = t[0] * interleave + t[1] * ref.
stride()[0]/interleave;
   129     pointer_ = 
reinterpret_cast<uint8_t *
>(ref.
data() + offset);
   131     stride_ = ref.
stride()[0] / interleave;
   148     for (
int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
   153       for (
int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
   155           int idx = c + s * ThreadMap::Iterations::kContiguous;
   156            frag_ptr[idx] = access_ptr[c * ThreadMap::Delta::kContiguous / ThreadMap::ThreadAccessShape::kStrided];
   159       if (s + 1 < ThreadMap::Iterations::kStrided) {
   160         byte_pointer += increment_strided_;
   168     load_with_pointer_offset(
   170       tile_offset.contiguous() * Shape::kContiguous / ThreadMap::kElementsPerAccess + 
   171         tile_offset.strided() * Shape::kStrided * stride_
   178     load_with_pointer_offset(frag, 0);
   189     for (
int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
   194       for (
int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
   196           int idx = c + s * ThreadMap::Iterations::kContiguous;
   197           access_ptr[c * ThreadMap::Delta::kContiguous / ThreadMap::ThreadAccessShape::kStrided] = frag_ptr[idx];
   200       if (s + 1 < ThreadMap::Iterations::kStrided) {
   201         byte_pointer += increment_strided_;
   209     store_with_pointer_offset(
   211       tile_offset.contiguous() * Shape::kContiguous + tile_offset.strided() * Shape::kStrided * stride_
   218     store_with_pointer_offset(frag, 0);
   224     pointer_ += increment_advance_;
   231     pointer_ -= increment_advance_;
   238     pointer_ += pointer_offset;
   245         (coord.contiguous() * Shape::kContiguous + coord.strided() * Shape::kStrided * stride_) / 8;
   246     add_pointer_offset(offset);
   265   using Element = Element_;
   267   static int const kAdvanceRank = AdvanceRank;
   268   using ThreadMap = ThreadMap_;
   269   static int const kAlignment = Alignment;
   277   using Fragment = Array<Element, ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kCount>;
   283     (kAdvanceRank == 0 ? 1 : 0),
   289     "Advance rank may only be along the row or column dimensions.");
   305     iterator_({ref.
data(), ref.
stride()}, thread_idx, 4) {
   312     iterator_.load_with_pointer_offset(frag, pointer_offset);
   318     iterator_.load_with_pointer_offset(frag, {tile_offset.column(), tile_offset.row()});
   324     iterator_.load_with_pointer_offset(frag, 0);
   330     iterator_.store_with_pointer_offset(frag, pointer_offset);
   336     iterator_.store_with_pointer_offset(frag, {tile_offset.column(), tile_offset.row()});
   342     iterator_.store_with_pointer_offset(frag, 0);
   362     iterator_.add_pointer_offset(pointer_offset);
   368     iterator_.add_tile_offset({coord.column(), coord.row()});
   387   using Element = Element_;
   389   static int const kAdvanceRank = AdvanceRank;
   390   using ThreadMap = ThreadMap_;
   391   static int const kAlignment = Alignment;
   399   using Fragment = Array<Element, ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kCount>;
   401                                   ThreadMap::kThreads, ThreadMap::ThreadAccessShape::kCount >;
   408     (kAdvanceRank == 0 ? 0 : 1),
   413     "Advance rank may only be along the row or column dimensions.");
   429     iterator_({ref.
data(), ref.
stride()}, thread_idx, 4) {
   436     iterator_.load_with_pointer_offset(frag, pointer_offset);
   442     iterator_.load_with_pointer_offset(frag, {tile_offset.row(), tile_offset.column()});
   448     iterator_.load_with_pointer_offset(frag, 0);
   454     iterator_.store_with_pointer_offset(frag, pointer_offset);
   460     iterator_.store_with_pointer_offset(frag, {tile_offset.row(), tile_offset.column()});
   466     iterator_.store_with_pointer_offset(frag, 0);
   486     iterator_.add_pointer_offset(pointer_offset);
   492     iterator_.add_tile_offset({coord.row(), coord.column()});
 
int64_t LongIndex
Long index type used for offsets. 
Definition: layout/matrix.h:355
Definition: aligned_buffer.h:35
Coordinate in pitch-linear space. 
Definition: pitch_linear.h:52
static int const value
Definition: numeric_types.h:43
Defines a structure containing strides, bounds, and a pointer to tensor data. 
CUTLASS_HOST_DEVICE Element * data() const 
Returns the pointer to referenced data. 
Definition: tensor_ref.h:254
int64_t LongIndex
Long index type used for offsets. 
Definition: layout/matrix.h:249
Mapping function for pitch-linear memory. 
Definition: pitch_linear.h:163
int32_t Index
Index type used for coordinates. 
Definition: layout/matrix.h:352
Aligned array type. 
Definition: array.h:511
int32_t Index
Index type used for coordinates. 
Definition: layout/matrix.h:246
Template defining a shape used by pitch-linear operators. 
Definition: pitch_linear.h:43
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
int64_t LongIndex
Long index type used for offsets. 
Definition: pitch_linear.h:175
CUTLASS_HOST_DEVICE Stride stride() const 
Returns the layout object's stride vector. 
Definition: tensor_ref.h:277
Defines the size of an element in bits. 
Definition: numeric_types.h:42
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
int32_t Index
Index type used for coordinates. 
Definition: pitch_linear.h:172
Templates implementing storing of tiles from pitch-linear rank=2 tensors. 
Defines layout functions used by TensorRef and derived classes. 
Defines layout functions used by TensorRef and derived classes for pitch-linear memory. 
Basic include for CUTLASS. 
Definition: matrix_coord.h:39