/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.nlp.generate;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;

/**
 * Single-step decoding helpers for text generation: contrastive search, greedy search and beam
 * search.
 *
 * <p>This is a static utility class; it holds no state.
 */
public final class StepGeneration {

    /** Not instantiable: all members are static. */
    private StepGeneration() {
    }

    /**
     * Selects the next token for every batch row using contrastive search.
     *
     * <p>Each candidate is scored as {@code (1 - alpha) * confidence - alpha * maxSimilarity}.
     * {@code confidence} is the softmax of {@code logits} along axis 1, gathered at the
     * candidate ids. {@code maxSimilarity} is the largest cosine similarity between the
     * candidate's hidden state and any unmasked context hidden state. The highest-scoring
     * candidate of each row is selected.
     *
     * @param topKIds candidate token ids; axis 0 is the batch and axis 1 the candidates
     *     ({@code [batch, topK]})
     * @param logits next-token logits; softmax and gather are taken along axis 1, so presumably
     *     {@code [batch, vocabSize]} (TODO confirm against callers)
     * @param contextHiddenStates hidden states of the context; must be rank 3 (it is transposed
     *     with {@code (0, 2, 1)}), assumed {@code [batch, pastSeq, hiddenDim]}
     * @param topkHiddenStates candidate hidden states; its last dimension is taken as
     *     {@code hiddenDim} and it is reshaped to {@code [batch, topK, hiddenDim]}
     * @param offSets one entry per batch row; entry {@code i} masks the context positions
     *     {@code [0, offSets[i])} of row {@code i} (presumably left padding). It is converted to
     *     INT64 before use, so any numeric dtype is accepted.
     * @param alpha weight of the similarity penalty relative to the model confidence
     * @return {@code [outputIds, select]}: the chosen token ids reshaped to {@code [batch, 1]}
     *     and the index of the chosen candidate per row
     */
    public static NDList constrastiveStepGeneration(
            NDArray topKIds,
            NDArray logits,
            NDArray contextHiddenStates,
            NDArray topkHiddenStates,
            NDArray offSets,
            float alpha) {
        long batch = topKIds.getShape().get(0);
        long topK = topKIds.getShape().get(1);
        long hiddenDim = topkHiddenStates.getShape().getLastDimension();

        // L2-normalize along axis 2 so the batched dot product below is a cosine similarity:
        // [batch, topK, dim] x [batch, dim, pastSeq] -> [batch, topK, pastSeq]
        NDArray candidates = topkHiddenStates.reshape(batch, topK, hiddenDim).normalize(2, 2);
        NDArray context = contextHiddenStates.normalize(2, 2);
        NDArray cosSimilarity = candidates.batchMatMul(context.transpose(0, 2, 1));

        // Mask the leading offsets[i] context positions of row i with -1 (the smallest possible
        // cosine similarity) so they never win the max below. toLongArray() rejects any dtype
        // other than INT64, hence the conversion; copy=false requests no copy if already INT64.
        long[] offsets = offSets.toType(DataType.INT64, false).toLongArray();
        for (int i = 0; i < offsets.length; i++) {
            if (offsets[i] != 0) { // a 0:0 slice has nothing to mask
                cosSimilarity.set(new NDIndex("{}, :, {}:{}", i, 0, offsets[i]), -1);
            }
        }

        // [batch, topK, pastSeq] -> [batch, topK]
        NDArray maxSimilarity = cosSimilarity.max(new int[] {2});
        assert maxSimilarity.getShape().getShape().length == 2 : "Wrong output size";
        // Model confidence per candidate, gathered from the softmax at the candidate ids.
        NDArray confidence = logits.softmax(1).gather(topKIds, 1);
        // Both operands are freshly computed above, so in-place ops are safe here.
        NDArray score = confidence.muli(1 - alpha).subi(maxSimilarity.muli(alpha));

        // Best candidate per row, then pick its id from topKIds.
        NDArray select = score.argMax(1);
        NDIndex selectIndex =
                new NDIndex(
                        "{}, {}, ...",
                        logits.getManager().arange(0f, batch, 1f, DataType.INT64),
                        select);
        NDArray outputIds = topKIds.get(selectIndex).reshape(-1, 1);
        return new NDList(outputIds, select);
    }

    /**
     * Picks the highest-scoring token at the last position of axis 1.
     *
     * @param logits rank-3 logits; assumed {@code [batch, seq, vocabSize]} (TODO confirm)
     * @return the argmax along the last axis for the last position of axis 1, with a unit axis
     *     inserted at position 1 ({@code [batch, 1]})
     */
    public static NDArray greedyStepGen(NDArray logits) {
        assert logits.getShape().getShape().length == 3 : "unexpected input";
        // Keep only the last index along axis 1; the other two axes are kept whole.
        NDArray lastLogits = logits.get(":, -1, :");
        return lastLogits.argMax(-1).expandDims(1);
    }

    /**
     * Advances beam search by one step.
     *
     * <p>The steps are:
     *
     * <ol>
     *   <li>For every source beam, the {@code numBeam} most likely next tokens are taken.
     *   <li>Their probabilities are multiplied by the beam's running probability.
     *   <li>Per batch row, the {@code numBeam} best (source beam, token) pairs out of
     *       {@code numBeam * numBeam} survive.
     *   <li>The surviving probabilities are L1-normalized along axis 1.
     * </ol>
     *
     * @param lastProbs running probability of each beam; it is reshaped to
     *     {@code [numBatch, numBeam, 1]}, so it must hold {@code numBatch * numBeam} elements
     * @param logits rank-3 logits of which only the last position of axis 1 is used; its first
     *     axis is assumed to hold {@code numBatch * numBeam} rows (TODO confirm)
     * @param numBatch number of batch rows
     * @param numBeam beam width; must fit in an {@code int}
     * @return {@code [outputIds, newProbs, sourceBeamSelected]} with shapes
     *     {@code [numBatch, numBeam, 1]}, {@code [numBatch, numBeam]} and
     *     {@code [numBatch, numBeam]}; {@code sourceBeamSelected} holds, for each surviving beam,
     *     the index of the source beam it extends
     */
    public static NDList beamStepGeneration(
            NDArray lastProbs, NDArray logits, long numBatch, long numBeam) {
        int beamWidth = Math.toIntExact(numBeam);
        long numCandidates = numBeam * numBeam;

        // Probabilities at the last position, grouped per source beam.
        NDArray allProbs = logits.get(":, -1, :").softmax(1).reshape(numBatch, numBeam, -1);

        // Best beamWidth children of every source beam: [numBatch, numBeam, numBeam].
        NDList childTopK = allProbs.topK(beamWidth, -1, true, false);
        NDArray childIds = childTopK.get(1);
        NDArray childProbs = childTopK.get(0);

        // Chain with each source beam's running probability (broadcast over its children).
        NDArray chainedProbs = childProbs.muli(lastProbs.reshape(numBatch, numBeam, 1));

        // Keep the best beamWidth (source, child) pairs per batch row.
        NDList beamTopK =
                chainedProbs.reshape(numBatch, numCandidates).topK(beamWidth, -1, true, false);
        NDArray select = beamTopK.get(1);
        assert select.getDataType() == DataType.INT64 : "Wrong output! Expect integer division";
        assert select.getShape().getShape().length == 2 : "Wrong size. Expect [batch, beamNew]";

        NDIndex selectIndex =
                new NDIndex(
                        "{}, {}, ...",
                        logits.getManager()
                                .arange(0f, numBatch, 1f, DataType.INT64)
                                .expandDims(1)
                                .repeat(1, numBeam),
                        select);
        NDArray outputIds =
                childIds.reshape(numBatch, numCandidates).get(selectIndex).expandDims(2);
        // The topK values already equal chainedProbs at `select`; no need to gather them again.
        NDArray newProbs = beamTopK.get(0).normalize(1, 1);

        // Flattened candidates are laid out source-major, so flat index / numBeam = source beam.
        long[] sourceBeams = select.toLongArray();
        for (int i = 0; i < sourceBeams.length; i++) {
            sourceBeams[i] = Math.floorDiv(sourceBeams[i], numBeam);
        }
        NDArray sourceBeamSelected =
                logits.getManager().create(sourceBeams, new Shape(numBatch, numBeam));
        return new NDList(outputIds, newProbs, sourceBeamSelected);
    }
}

