template<typename ElementAlphaBeta, bool BetaIsZero>
struct GemvBatchedStridedEpilogueScaling
{
  ElementAlphaBeta const& alpha;
  ElementAlphaBeta const& beta;

  CUTLASS_DEVICE
  GemvBatchedStridedEpilogueScaling(ElementAlphaBeta& alpha_, ElementAlphaBeta& beta_) :
    alpha(alpha_), beta(beta_)
  { }

  template<typename FragmentCD, typename FragmentAccumulator>
  CUTLASS_DEVICE
  void operator()(FragmentAccumulator& accumulators,
                  FragmentCD const& fragment_C,
                  FragmentCD& fragment_D) const
  {
    using AccType = typename FragmentAccumulator::value_type;
    using CDType = typename FragmentCD::value_type;

    static_assert(FragmentCD::kElements == FragmentAccumulator::kElements,
                  "Mismatch in fragment sizes.");

    for (int i = 0; i < FragmentCD::kElements; ++i)
    {
      if (BetaIsZero) {
        fragment_D[i] = CDType(accumulators[i] * AccType(alpha));
      }
      else {
        fragment_D[i] = CDType(accumulators[i] * AccType(alpha)
                               + AccType(fragment_C[i]) * AccType(beta));
      }
    }
  }
};
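
// Per element, the functor above computes
//   fragment_D[i] = alpha * accumulators[i]                         (BetaIsZero)
//   fragment_D[i] = alpha * accumulators[i] + beta * fragment_C[i]  (otherwise)
// in the accumulator type before converting to the C/D element type. Since
// BetaIsZero is a compile-time constant, the dead branch is compiled out.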

template <typename GemvKernel, typename ElementAlphaBeta, bool BetaIsZero=false>
CUTLASS_DEVICE void GemvBatchedStridedDevice(
  cutlass::gemm::BatchedGemmCoord problem_size,
  ElementAlphaBeta alpha,
  ElementAlphaBeta beta,
  typename GemvKernel::IteratorA::TensorRef ref_A,
  typename GemvKernel::IteratorA::TensorRef::LongIndex lda,
  typename GemvKernel::IteratorB::TensorRef ref_B,
  typename GemvKernel::IteratorB::TensorRef::LongIndex ldb,
  typename GemvKernel::IteratorCD::TensorRef ref_C,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldc,
  typename GemvKernel::IteratorCD::TensorRef ref_D,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd)
{
  using ThreadBlockGemv = typename GemvKernel::ThreadBlockGemv;
  using ThreadBlockSwizzle = typename GemvKernel::ThreadBlockSwizzle;
  using EpilogueScale = GemvBatchedStridedEpilogueScaling<ElementAlphaBeta, BetaIsZero>;

  ThreadBlockSwizzle swizzler;

  // Locate this threadblock's tile and batch item in logical coordinates
  cutlass::gemm::BatchedGemmCoord tb_offset = swizzler.get_tile_offset();
  int const batch_idx = swizzler.get_batch_idx();

  ref_A.add_pointer_offset(batch_idx*lda);
  ref_B.add_pointer_offset(batch_idx*ldb);
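
  // The pointer offsets above use lda/ldb as batch strides: each threadblock
  // advances ref_A and ref_B to the operands of the batch item it computes.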

  // Construct iterators to the A and B operands
  typename GemvKernel::IteratorA::Params params_A(ref_A.layout());
  typename GemvKernel::IteratorA iterator_A(
      params_A,
      ref_A.data(),
      { 1, problem_size.k() },
      0,
      { 0, 0 });

  typename GemvKernel::IteratorB::Params params_B(ref_B.layout());
  typename GemvKernel::IteratorB iterator_B(
      params_B,
      ref_B.data(),
      { problem_size.k(), problem_size.n() },
      threadIdx.x,
      { 0, tb_offset.n()*ThreadBlockGemv::Shape::kN });
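
  // iterator_A covers a 1 x K extent (the vector operand); iterator_B covers
  // the K x N matrix operand, offset to this threadblock's N-tile.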

  // Threadblock-scoped matrix multiply
  ThreadBlockGemv mma;

  typename ThreadBlockGemv::FragmentC accumulators;
  accumulators.clear();

  mma(problem_size.mnk(), accumulators, iterator_A, iterator_B, accumulators);
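
  // 'accumulators' is passed as both the source and destination fragment,
  // so the multiply-accumulate updates it in place.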

  // Epilogue: load C first (skipped entirely when beta is statically zero)
  typename GemvKernel::FragmentCD fragment_CD;

  if (!BetaIsZero)
  {
    tb_offset = swizzler.get_tile_offset();
    ref_C.add_pointer_offset(batch_idx*ldc);
    typename GemvKernel::IteratorCD::Params params_C(ref_C.layout());
    typename GemvKernel::IteratorCD iterator_C(
        params_C,
        ref_C.data(),
        { 1, problem_size.n() },
        threadIdx.x,
        { 0, tb_offset.n()*ThreadBlockGemv::Shape::kN });
    iterator_C.load(fragment_CD);
  }

  // Scale by alpha/beta and write the result into fragment_CD
  EpilogueScale epilogue_scale(alpha, beta);
  epilogue_scale(accumulators, fragment_CD, fragment_CD);

  // Store D
  tb_offset = swizzler.get_tile_offset();
  ref_D.add_pointer_offset(batch_idx*ldd);
  typename GemvKernel::IteratorCD::Params params_D(ref_D.layout());
  typename GemvKernel::IteratorCD iterator_D(
      params_D,
      ref_D.data(),
      { 1, problem_size.n() },
      threadIdx.x,
      { 0, tb_offset.n()*ThreadBlockGemv::Shape::kN });
  iterator_D.store(fragment_CD);
}
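
// The __global__ entry points below are thin wrappers over
// GemvBatchedStridedDevice: the first takes both alpha and beta, the second
// fixes beta = 0, and the third fixes alpha = 1 and beta = 0 (the beta-free
// forms select the BetaIsZero specialization).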

template <typename GemvKernel, typename ElementAlphaBeta, bool BetaIsZero>
__global__ void GemvBatchedStrided(
  cutlass::gemm::BatchedGemmCoord problem_size,
  ElementAlphaBeta alpha,
  ElementAlphaBeta beta,
  typename GemvKernel::IteratorA::TensorRef ref_A,
  typename GemvKernel::IteratorA::TensorRef::LongIndex lda,
  typename GemvKernel::IteratorB::TensorRef ref_B,
  typename GemvKernel::IteratorB::TensorRef::LongIndex ldb,
  typename GemvKernel::IteratorCD::TensorRef ref_C,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldc,
  typename GemvKernel::IteratorCD::TensorRef ref_D,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd)
{
  GemvBatchedStridedDevice<GemvKernel, ElementAlphaBeta, BetaIsZero>(
    problem_size, alpha, beta, ref_A, lda, ref_B, ldb, ref_C, ldc, ref_D, ldd
  );
}

template <typename GemvKernel, typename ElementAlphaBeta>
__global__ void GemvBatchedStrided(
  cutlass::gemm::BatchedGemmCoord problem_size,
  ElementAlphaBeta alpha,
  typename GemvKernel::IteratorA::TensorRef ref_A,
  typename GemvKernel::IteratorA::TensorRef::LongIndex lda,
  typename GemvKernel::IteratorB::TensorRef ref_B,
  typename GemvKernel::IteratorB::TensorRef::LongIndex ldb,
  typename GemvKernel::IteratorCD::TensorRef ref_D,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd)
{
  GemvBatchedStridedDevice<GemvKernel, ElementAlphaBeta, true>(
    problem_size, alpha, ElementAlphaBeta(0), ref_A, lda, ref_B, ldb, ref_D, ldd, ref_D, ldd
  );
}
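
// When BetaIsZero is true the device function never reads C, so the beta-free
// overloads pass ref_D and ldd in the C argument positions purely as
// placeholders.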

template <typename GemvKernel>
__global__ void GemvBatchedStrided(
  cutlass::gemm::BatchedGemmCoord problem_size,
  typename GemvKernel::IteratorA::TensorRef ref_A,
  typename GemvKernel::IteratorA::TensorRef::LongIndex lda,
  typename GemvKernel::IteratorB::TensorRef ref_B,
  typename GemvKernel::IteratorB::TensorRef::LongIndex ldb,
  typename GemvKernel::IteratorCD::TensorRef ref_D,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd)
{
  using ElementAlphaBeta = typename GemvKernel::IteratorCD::Element;
  GemvBatchedStridedDevice<GemvKernel, ElementAlphaBeta, true>(
    problem_size, ElementAlphaBeta(1), ElementAlphaBeta(0), ref_A, lda, ref_B, ldb, ref_D, ldd, ref_D, ldd
  );
}
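
// Example launch (illustrative sketch only, not part of this header). The
// composition of 'MyGemvKernel' and the grid/block dimensions are hypothetical
// and must match the kernel's ThreadBlockSwizzle and threadblock shape:
//
//   cutlass::gemm::BatchedGemmCoord problem_size(1, n, k, batch_count);
//   dim3 grid(/* N-tiles */, /* batch slices */);  // per the swizzle
//   dim3 block(/* threads per threadblock */);
//   GemvBatchedStrided<MyGemvKernel><<<grid, block>>>(
//       problem_size, ref_A, lda, ref_B, ldb, ref_D, ldd);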