57 namespace threadblock {
    66     typename WarpMmaOperator_,
    70     typename OutputTileIterator_,
    72     typename AccumulatorFragmentIterator_,
    78     bool IsBetaZero = 
false>
   104   using TensorRef = 
typename OutputTileIterator::TensorRef;
   115                                  OutputTileIterator::kElementsPerAccess>;
   119       Array<ElementAccumulator, OutputTileIterator::kElementsPerAccess>;
   124                       Shape::kN / WarpMmaOperator::Shape::kN, kPartitionsK>;
   128                 "This must not be zero.");
   131                   OutputTileIterator::kElementsPerAccess),
   160     if (IsBetaZero && output_op.is_source_needed())
   163     typename OutputTileIterator::Fragment source_fragment;
   166       if (!output_op.is_source_needed()) {
   167         source_iterator.clear_mask();
   171     source_fragment.clear();
   184     for (
int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
   190         source_iterator.set_iteration_index(iter);
   191         source_iterator.load(source_fragment);
   199       typename AccumulatorFragmentIterator::Fragment accum_fragment;
   201       accum_fragment_iterator.load(accum_fragment);
   202       ++accum_fragment_iterator;
   208       typename OutputTileIterator::Fragment output_fragment;
   209       apply_output_operator_(output_op, output_fragment, accum_fragment, source_fragment);
   215       destination_iterator.set_iteration_index(iter);
   216       destination_iterator.store(output_fragment);
   217       ++destination_iterator;
   224   void apply_output_operator_(
   226       typename OutputTileIterator::Fragment &output_fragment,
   227       typename AccumulatorFragmentIterator::Fragment 
const   228           &aligned_accum_fragment,
   229       typename OutputTileIterator::Fragment 
const &source_fragment) {
   235             &aligned_accum_fragment);
   240     int const kOutputOpIterations = OutputTileIterator::Fragment::kElements /
   241                                     OutputTileIterator::kElementsPerAccess;
   244     for (
int i = 0; i < kOutputOpIterations; ++i) {
   246       output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
 Shape_ Shape
Definition: interleaved_epilogue.h:81
Definition: aligned_buffer.h:35
typename AccumulatorTile::Element ElementAccumulator
Accumulator element. 
Definition: interleaved_epilogue.h:95
Templates implementing how threads are mapped to a given tile. 
CUTLASS_DEVICE InterleavedEpilogue(SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)
Constructor. 
Definition: interleaved_epilogue.h:141
typename AccumulatorFragmentIterator::AccumulatorTile AccumulatorTile
The complete warp-level accumulator tile. 
Definition: interleaved_epilogue.h:92
Epilogue for threadblock scoped GEMMs using Tensor Ops. 
Array< ElementAccumulator, OutputTileIterator::kElementsPerAccess > AccumulatorAccessType
Array type used by output functor. 
Definition: interleaved_epilogue.h:119
Epilogue operator without splitk. 
Definition: interleaved_epilogue.h:79
Array< typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess > OutputAccessType
Array type used to output. 
Definition: interleaved_epilogue.h:115
OutputOp_ OutputOp
Definition: interleaved_epilogue.h:86
Defines common types used for all GEMM-like operators. 
typename OutputTileIterator::ConstTensorRef ConstTensorRef
Const tensor reference to source tensor. 
Definition: interleaved_epilogue.h:111
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Defines layout functions used by TensorRef and derived classes for common 4-D and 5-D tensor formats...
typename OutputTileIterator::TensorRef TensorRef
Tensor reference to destination tensor. 
Definition: interleaved_epilogue.h:104
Shared storage allocation needed by the epilogue. 
Definition: interleaved_epilogue.h:135
WarpMmaOperator_ WarpMmaOperator
Definition: interleaved_epilogue.h:82
typename cutlass::TensorRef< int, cutlass::layout::PackedVectorLayout > SyncTensorRef
Tensor reference to sync tensor. 
Definition: interleaved_epilogue.h:108
Definition: tensor_ref.h:146
Defines a canonical coordinate for rank=4 tensors offering named indices. 
AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared me...
OutputTileIterator_ OutputTileIterator
Definition: interleaved_epilogue.h:85
Top-level include for all CUTLASS numeric types. 
Shape of a matrix multiply-add operation. 
Definition: include/cutlass/gemm/gemm.h:57
CUTLASS_DEVICE void operator()(OutputOp const &output_op, OutputTileIterator destination_iterator, AccumulatorTile const &accumulators, OutputTileIterator source_iterator)
Streams the result to global memory. 
Definition: interleaved_epilogue.h:150
static int const kElementsPerAccess
Output access size. 
Definition: interleaved_epilogue.h:101
Defines layout functions used for rank=1 vectors. 
Templates implementing storing of tiles from pitch-linear rank=2 tensors. 
Epilogue for threadblock scoped GEMMs using Tensor Ops. 
Definition: layout/matrix.h:343
static int const kPartitionsK
Definition: interleaved_epilogue.h:83
AccumulatorFragmentIterator_ AccumulatorFragmentIterator
Definition: interleaved_epilogue.h:84
typename OutputTileIterator::Element ElementOutput
Output element. 
Definition: interleaved_epilogue.h:98
Basic include for CUTLASS.