58 namespace threadblock {
    65   typename WarpMmaOperator_,                
    67   typename OutputTileIterator_,             
    68   typename AccumulatorFragmentIterator_,    
    69   typename WarpTileIterator_,               
    70   typename SharedLoadIterator_,             
    79     AccumulatorFragmentIterator_, 
    89     AccumulatorFragmentIterator_, 
   121   using TensorRef = 
typename OutputTileIterator::TensorRef;
   131     typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
   142   static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
   143     "Mismatch between shared load iterator and output tile iterator.");
   145   static_assert(OutputTileIterator::kElementsPerAccess, 
"OutputTileIterator::kElementsPerAccess must not be zero.");
   147   static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), 
   165     Base(shared_storage, thread_idx, warp_idx, lane_idx),
   166     shared_load_iterator_(shared_storage.reference(), thread_idx) { }
   177     typename OutputTileIterator::Fragment source_fragment;
   179     if (!output_op.is_source_needed()) {
   180       source_iterator.clear_mask();
   183     source_fragment.clear();
   196     for (
int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
   202       source_iterator.load(source_fragment);
   211       typename AccumulatorFragmentIterator::Fragment accum_fragment;
   213       accum_fragment_iterator.load(accum_fragment);
   214       ++accum_fragment_iterator;
   226       shared_load_iterator_.load(aligned_accum_fragment[0]);
   229       if (kPartitionsK > 1)
   232         const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK;
   236           shared_load_iterator_.add_tile_offset({tile_row_offset , 0});
   237           shared_load_iterator_.load(aligned_accum_fragment[i]);
   238           aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
   241         shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0});
   248       typename OutputTileIterator::Fragment output_fragment;
   250       apply_output_operator_(output_fragment, output_op, aligned_accum_fragment[0], source_fragment);
   257       destination_iterator.store(output_fragment);      
   258       ++destination_iterator;
   267   void apply_output_operator_(
   268     typename OutputTileIterator::Fragment &output_fragment,
   271     typename OutputTileIterator::Fragment 
const &source_fragment) {
   282     int const kOutputOpIterations = 
   283       OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
   286     for (
int i = 0; i < kOutputOpIterations; ++i) {
   289       output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
 int64_t LongIndex
Long index type used for offsets. 
Definition: layout/matrix.h:62
typename Layout::LongIndex LongIndex
Definition: epilogue.h:105
typename Base::WarpCount WarpCount
Number of warps. 
Definition: epilogue.h:137
Array< Element, ThreadMap::Iterations::kColumn *ThreadMap::Iterations::kRow *ThreadMap::Iterations::kGroup *ThreadMap::Iterations::kCluster *ThreadMap::kElementsPerAccess > Fragment
Fragment object. 
Definition: shared_load_iterator.h:91
Definition: aligned_buffer.h:35
WarpTileIterator warp_tile_iterator_
Stores a warp's fragment of accumulators to SMEM. 
Definition: epilogue_base.h:176
Templates implementing how threads are mapped to a given tile. 
Shared storage allocation needed by the epilogue. 
Definition: epilogue_base.h:97
CUTLASS_DEVICE void operator()(OutputOp const &output_op, OutputTileIterator destination_iterator, AccumulatorTile const &accumulators, OutputTileIterator source_iterator)
Streams the result to global memory. 
Definition: epilogue.h:170
OutputTileIterator_ OutputTileIterator
Definition: epilogue.h:96
Epilogue for threadblock scoped GEMMs using Tensor Ops. 
Defines common types used for all GEMM-like operators. 
CUTLASS_DEVICE Epilogue(typename Base::SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)
Constructor. 
Definition: epilogue.h:159
Shape_ Shape
Definition: epilogue.h:93
typename OutputTileIterator::TensorRef TensorRef
Tensor reference to destination tensor. 
Definition: epilogue.h:121
gemm::GemmShape< Shape::kM/WarpMmaOperator::Shape::kM, Shape::kN/WarpMmaOperator::Shape::kN, kPartitionsK > WarpCount
Number of warps. 
Definition: epilogue_base.h:92
Definition: functional.h:46
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe to use in a union.
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Defines layout functions used by TensorRef and derived classes for common 4-D and 5-D tensor formats.
static int const kPartitionsK
Definition: epilogue.h:95
OutputOp_ OutputOp
Definition: epilogue.h:100
Definition: tensor_ref.h:146
Padding_ Padding
Definition: epilogue.h:101
Defines a canonical coordinate for rank=4 tensors offering named indices. 
AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared memory.
AccumulatorFragmentIterator_ AccumulatorFragmentIterator
Definition: epilogue.h:97
Top-level include for all CUTLASS numeric types. 
typename OutputTileIterator::ConstTensorRef ConstTensorRef
Const tensor reference to source tensor. 
Definition: epilogue.h:127
WarpTileIterator_ WarpTileIterator
Definition: epilogue.h:98
SharedLoadIterator_ SharedLoadIterator
Definition: epilogue.h:99
Mapping function for row-major matrices. 
Definition: layout/matrix.h:50
Epilogue operator without splitk. 
Definition: epilogue.h:74
typename WarpTileIterator::Element ElementAccumulator
Accumulator element. 
Definition: epilogue.h:111
Defines layout functions used for rank=1 vectors. 
Templates implementing storing of tiles from pitch-linear rank=2 tensors. 
Epilogue for threadblock scoped GEMMs using Tensor Ops. 
Base class for epilogues defining warp-level. 
Definition: epilogue_base.h:67
WarpMmaOperator_ WarpMmaOperator
Definition: epilogue.h:94
typename Base::AccumulatorTile AccumulatorTile
The complete warp-level accumulator tile. 
Definition: epilogue.h:108
Array< typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess > OutputAccessType
Array type used to output. 
Definition: epilogue.h:131
static int const kElementsPerAccess
Output access size. 
Definition: epilogue.h:118
typename AccumulatorFragmentIterator::AccumulatorTile AccumulatorTile
The complete warp-level accumulator tile. 
Definition: epilogue_base.h:81
typename OutputTileIterator::Element ElementOutput
Output element. 
Definition: epilogue.h:115
Basic include for CUTLASS. 
typename cutlass::TensorRef< int, cutlass::layout::PackedVectorLayout > SyncTensorRef
Tensor reference to sync tensor. 
Definition: epilogue.h:124
Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible.
Array< typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess > AccumulatorAccessType
Array type used by output functor. 
Definition: epilogue.h:134