42 namespace threadblock {
    63     Shape::kM / Operator::Shape::kM,
    64     Shape::kN / Operator::Shape::kN,
    69     "Direct epilogue cannot be used with when the threadblock tile is partitioned along the K dimension.");
   113       typename OutputOp::Params output_op_,
   114       typename ConvertOp::Params convert_op_
   116       destination_ref(destination_ref_),
   117       source_ref(source_ref_),
   118       output_op(output_op_),
   119       convert_op(convert_op_) {
   128       typename OutputOp::Params output_op_
   162     output_op(params.output_op),
   163     convert_op(params.convert_op),
   179       warp_m * Operator::Shape::kM, 
   180       warp_n * Operator::Shape::kN
   195       MatrixCoord{tb_tile_coord.
m() * Shape::kM, tb_tile_coord.
n() * Shape::kN} + warp_origin_;
   199       Operator::Shape::kM / Operator::Policy::Operator::Shape::kM,
   200       Operator::Shape::kN / Operator::Policy::Operator::Shape::kN
   206     int const kElementsPerAccess = Operator::Policy::Operator::Shape::kN / 4;
   207     int const kRowsPerTile = 8;
   208     int const kAccumulatorRows = Operator::Policy::Operator::Shape::kM / kRowsPerTile;
   211     for (
int mma_n = 0; mma_n < MmaIterations::kN; ++mma_n) {
   213       for (
int mma_m = 0; mma_m < MmaIterations::kM; ++mma_m) {
   215         int mma_accum_start = kAccumulatorRows * kElementsPerAccess * 
   216           (mma_m * MmaIterations::kN + mma_n);
   219         for (
int row = 0; row < kAccumulatorRows; ++row) {
   221           for (
int col = 0; col < kElementsPerAccess; ++col) {
   223             int accum_m = mma_m * Operator::Policy::Operator::Shape::kM + row * kRowsPerTile;
   224             int accum_n = mma_n * Operator::Policy::Operator::Shape::kN + col;
   225             int idx = mma_accum_start + row * kElementsPerAccess + col;
   229             MatrixCoord thread_coord = thread_origin + accum_coord;
   231             if (thread_coord < 
MatrixCoord{problem_size.
m(), problem_size.
n()}) {
   233               typename ConvertOp::result_type converted_accum = 
convert_op(accumulators[idx]);
   235               typename OutputOp::result_type output = 
output_op(converted_accum, source_ref_.
at(accum_coord));
   237               destination_ref_.
at(accum_coord) = output;
 Epilogue operator. 
Definition: direct_epilogue_tensor_op.h:55
static int const kM
Definition: include/cutlass/gemm/gemm.h:58
Describes the size of a matrix tile. 
Definition: matrix_shape.h:42
CUTLASS_HOST_DEVICE Params(TensorRef destination_ref_, TensorRef source_ref_, typename OutputOp::Params output_op_)
Constructs a Params object. 
Definition: direct_epilogue_tensor_op.h:125
Parameters structure for host-constructible state. 
Definition: direct_epilogue_tensor_op.h:92
Definition: aligned_buffer.h:35
CUTLASS_DEVICE void operator()(gemm::GemmCoord problem_size, gemm::GemmCoord tb_tile_coord, FragmentC const &accumulators)
Streams the result to global memory. 
Definition: direct_epilogue_tensor_op.h:189
TensorRef destination_ref
Definition: direct_epilogue_tensor_op.h:98
Definition: include/cutlass/gemm/gemm.h:94
Defines common types used for all GEMM-like operators. 
CUTLASS_HOST_DEVICE Index const & n() const 
Returns the GEMM N coordinate. 
Definition: include/cutlass/gemm/gemm.h:137
TensorRef< Element, Layout::kRank, Layout > TensorRef
Reference to source and destination tensors. 
Definition: direct_epilogue_tensor_op.h:87
CUTLASS_HOST_DEVICE TensorRef & add_coord_offset(TensorCoord const &coord)
Adds an offset to each pointer. 
Definition: tensor_ref.h:326
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
TensorRef source_ref
Definition: direct_epilogue_tensor_op.h:99
CUTLASS_DEVICE DirectEpilogueTensorOp(Params const &params, SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)
Constructor. 
Definition: direct_epilogue_tensor_op.h:155
OutputOp::Params output_op
Definition: direct_epilogue_tensor_op.h:101
ConvertOp::Params convert_op
Definition: direct_epilogue_tensor_op.h:102
CUTLASS_HOST_DEVICE Params(TensorRef destination_ref_, TensorRef source_ref_, typename OutputOp::Params output_op_, typename ConvertOp::Params convert_op_)
Constructs a Params object. 
Definition: direct_epilogue_tensor_op.h:110
Shared storage allocation needed by the epilogue. 
Definition: direct_epilogue_tensor_op.h:139
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types. 
Shape of a matrix multiply-add operation. 
Definition: include/cutlass/gemm/gemm.h:57
typename Operator::FragmentC FragmentC
Accumulator tile is really the warp-scoped tile. 
Definition: direct_epilogue_tensor_op.h:72
Mapping function for row-major matrices. 
Definition: layout/matrix.h:50
Operator_ Operator
Definition: direct_epilogue_tensor_op.h:59
CUTLASS_HOST_DEVICE Reference at(TensorCoord const &coord) const 
Returns a reference to the element at a given Coord. 
Definition: tensor_ref.h:307
OutputOp_ OutputOp
Function operator computing final output. 
Definition: direct_epilogue_tensor_op.h:81
CUTLASS_HOST_DEVICE Index const & m() const 
Returns the GEMM M coordinate. 
Definition: include/cutlass/gemm/gemm.h:129
ConvertOp_ ConvertOp
Conversion operator applied to each accumulator element before it is passed to the output operator. 
Definition: direct_epilogue_tensor_op.h:84
Basic include for CUTLASS. 
Definition: matrix_coord.h:39
Element_ Element
Data type of output tensor. 
Definition: direct_epilogue_tensor_op.h:75
Shape_ Shape
Definition: direct_epilogue_tensor_op.h:58
static int const kN
Definition: include/cutlass/gemm/gemm.h:59