56   typename OperatorShape,     
    57   typename OperatorElementC,  
    58   typename OperatorFragmentC, 
    68   typename OperatorShape_,     
    69   typename OperatorElementC_,  
    70   typename OperatorFragmentC_  
    76   using OperatorShape = OperatorShape_;
    77   using OperatorElementC = OperatorElementC_;
    78   using OperatorFragmentC = OperatorFragmentC_;
    86     Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>;
    91     OperatorFragmentC::kElements * Policy::OperatorCount::kRow * Policy::OperatorCount::kColumn>;
    96   static int const kIterations = Policy::kIterations;
   101   using AccessType = Array<OperatorElementC, Policy::kElementsPerAccess>;
   110   AccessType 
const *accumulators_;
   120     accumulators_(reinterpret_cast<AccessType const *>(&accum)), 
   142     int index = index_ + index_offset;
   144     AccessType *frag_ptr = 
reinterpret_cast<AccessType *
>(&frag);
   147     for (
int n = 0; n < Policy::OperatorCount::kColumn; ++n) {
   149       int accumulator_access_offset = 
   150         index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess;
   152       frag_ptr[n] = accumulators_[accumulator_access_offset];
   164     typename OperatorShape_,
   166     typename OperatorElementC_,
   168     typename OperatorFragmentC_,
   172                                layout::ColumnMajorInterleaved<InterleavedK>> {
   175   using OperatorShape = OperatorShape_;
   176   using OperatorElementC = OperatorElementC_;
   177   using OperatorFragmentC = OperatorFragmentC_;
   178   static int const kInterleavedK = InterleavedK;
   185       Array<OperatorElementC,
   186             Policy::kElementsPerAccess * InterleavedK / OperatorShape::kN>;
   190       Array<OperatorElementC, OperatorFragmentC::kElements *
   191                                   Policy::OperatorCount::kRow *
   192                                   Policy::OperatorCount::kColumn>;
   195   static int const kIterations = Policy::kIterations;
   200       Array<OperatorElementC, Policy::kElementsPerAccess>;
   208   AccessType 
const *accumulators_;
   217       : accumulators_(reinterpret_cast<AccessType const *>(&accum)),
   237     int index = index_ + index_offset;
   239     AccessType *frag_ptr = 
reinterpret_cast<AccessType *
>(&frag);
   242     for (
int n = 0; n < (InterleavedK / OperatorShape::kN); ++n) {
   243       int index_m = index % (Policy::OperatorCount::kRow *
   244                              Policy::kIterationsPerInstruction);
   245       int index_n = index / (Policy::OperatorCount::kRow *
   246                              Policy::kIterationsPerInstruction);
   247       int accumulator_access_offset =
   248           (index_m / Policy::kIterationsPerInstruction) *
   249               (Policy::OperatorCount::kColumn *
   250                Policy::kIterationsPerInstruction) +
   251           (index_m % Policy::kIterationsPerInstruction) +
   252           index_n * (InterleavedK / OperatorShape::kN) *
   253               Policy::kIterationsPerInstruction +
   254           n * Policy::kIterationsPerInstruction;
   256       frag_ptr[n] = accumulators_[accumulator_access_offset];
 WarpShape_ WarpShape
Definition: fragment_iterator_tensor_op.h:75
CUTLASS_HOST_DEVICE void load(Fragment &frag, int index_offset=0) const 
Loads a fragment from the referenced part of the accumulator tile. 
Definition: fragment_iterator_tensor_op.h:140
Definition: aligned_buffer.h:35
Defines basic structures needed for implementing the warp-scoped phase of the epilogue. These quantities assume a 'column-major' arrangement of TensorOp instructions, of which a row-oriented slice is visible per iteration. 
AccumulatorTile OutputAccumulatorTile
Definition: fragment_iterator_tensor_op.h:93
CUTLASS_HOST_DEVICE void load(Fragment &frag, int index_offset=0) const 
Loads a fragment from the referenced part of the accumulator tile. 
Definition: fragment_iterator_tensor_op.h:236
CUTLASS_HOST_DEVICE FragmentIteratorTensorOp & operator--()
Decrements. 
Definition: fragment_iterator_tensor_op.h:229
WarpShape_ WarpShape
Definition: fragment_iterator_tensor_op.h:174
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Array< OperatorElementC, OperatorFragmentC::kElements *Policy::OperatorCount::kRow *Policy::OperatorCount::kColumn > AccumulatorTile
This is the complete warp-level accumulator tile. 
Definition: fragment_iterator_tensor_op.h:91
Array< OperatorElementC, Policy::OperatorCount::kColumn *Policy::kElementsPerAccess > Fragment
This is the fragment size produced by one access of the iterator. 
Definition: fragment_iterator_tensor_op.h:86
Policy details related to the epilogue. 
Definition: tensor_op_policy.h:50
CUTLASS_HOST_DEVICE FragmentIteratorTensorOp & operator++()
Increments. 
Definition: fragment_iterator_tensor_op.h:222
CUTLASS_HOST_DEVICE FragmentIteratorTensorOp(AccumulatorTile const &accum)
Constructs an iterator. 
Definition: fragment_iterator_tensor_op.h:119
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
CUTLASS_HOST_DEVICE FragmentIteratorTensorOp(AccumulatorTile const &accum)
Constructs an iterator. 
Definition: fragment_iterator_tensor_op.h:216
Definition: fragment_iterator_tensor_op.h:61
Array< OperatorElementC, Policy::kElementsPerAccess *InterleavedK/OperatorShape::kN > Fragment
This is the fragment size produced by one access of the iterator. 
Definition: fragment_iterator_tensor_op.h:186
Mapping function for row-major matrices. 
Definition: layout/matrix.h:50
CUTLASS_HOST_DEVICE FragmentIteratorTensorOp & operator++()
Increments. 
Definition: fragment_iterator_tensor_op.h:126
Defines layout functions used by TensorRef and derived classes. 
Definition: layout/matrix.h:343
Array< OperatorElementC, OperatorFragmentC::kElements *Policy::OperatorCount::kRow *Policy::OperatorCount::kColumn > AccumulatorTile
This is the complete warp-level accumulator tile. 
Definition: fragment_iterator_tensor_op.h:192
CUTLASS_HOST_DEVICE FragmentIteratorTensorOp & operator--()
Decrements. 
Definition: fragment_iterator_tensor_op.h:133