/*
 * Decompiled with CFR 0.152.
 */
package htsjdk.samtools.cram.compression.rans.ransnx16;

import htsjdk.samtools.cram.CRAMException;
import htsjdk.samtools.cram.compression.CompressionUtils;
import htsjdk.samtools.cram.compression.rans.ArithmeticDecoder;
import htsjdk.samtools.cram.compression.rans.RANSDecode;
import htsjdk.samtools.cram.compression.rans.RANSDecodingSymbol;
import htsjdk.samtools.cram.compression.rans.Utils;
import htsjdk.samtools.cram.compression.rans.ransnx16.RANSNx16Params;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;

/**
 * Decoder for the CRAM rANS Nx16 entropy codec.
 *
 * <p>Each stream starts with a one-byte format-flags field, interpreted by {@link RANSNx16Params}.
 * It selects:
 * <ul>
 *   <li>the entropy order (0 or 1);</li>
 *   <li>the number of interleaved rANS states;</li>
 *   <li>the optional STRIPE, NOSZ, PACK, RLE and CAT transforms.</li>
 * </ul>
 * After the entropy-coded (or CAT) payload is decoded, RLE is expanded first and bit-packing
 * is undone last.
 *
 * <p>NOTE(review): decoding writes into frequency tables obtained from {@code getD()} and
 * {@code getDecodingSymbols()}. These are presumably per-instance state of {@link RANSDecode},
 * so an instance should not be shared between threads. Confirm against the superclass.
 */
public class RANSNx16Decode
extends RANSDecode {
    /** Returned for an input buffer with no remaining bytes. */
    private static final ByteBuffer EMPTY_BUFFER = CompressionUtils.allocateByteBuffer(0);
    /** Bit of the order-1 frequency-table header byte that marks an order-0 compressed table. */
    private static final int FREQ_TABLE_OPTIONALLY_COMPRESSED_MASK = 1;
    /** Bit of the RLE metadata length that marks metadata stored without compression. */
    private static final int RLE_META_OPTIONALLY_COMPRESSED_MASK = 1;
    /** Shift passed to the order-0 normalisation and cumulative-frequency helpers. */
    private static final int ORDER_0_FREQUENCY_SHIFT = 12;
    /** Format-flag bit selecting order-1 coding (CRAM codecs specification). */
    private static final int ORDER_FLAG_MASK = 0x01;
    /** Format-flag bit selecting 32 rather than 4 interleaved states (CRAM codecs specification). */
    private static final int N32_FLAG_MASK = 0x04;
    /** Size of the byte alphabet. */
    private static final int NUMBER_OF_SYMBOLS = 256;
    /** Largest number of distinct symbols the PACK transform permits. */
    private static final int MAX_PACK_SYMBOLS = 16;

    /**
     * Decompresses a complete rANS Nx16 stream.
     *
     * <p>As a side effect, the byte order of {@code inBuffer} is set to little-endian.
     *
     * @param inBuffer compressed data, read from its current position
     * @return the decompressed data; a shared empty buffer if {@code inBuffer} has no remaining bytes
     */
    @Override
    public ByteBuffer uncompress(ByteBuffer inBuffer) {
        inBuffer.order(ByteOrder.LITTLE_ENDIAN);
        return this.uncompress(inBuffer, 0);
    }

    /**
     * Decodes one rANS Nx16 stream. Also called recursively for each STRIPE sub-stream.
     *
     * @param inBuffer compressed data; its position is advanced past the bytes consumed
     * @param outSize  uncompressed size to use when the NOSZ flag is set; otherwise the size is
     *                 read from the stream and this value is ignored
     * @return the decoded data
     * @throws CRAMException on an invalid PACK symbol count, a zero uncompressed size for an
     *                       entropy-coded payload, or an unknown order
     */
    private ByteBuffer uncompress(final ByteBuffer inBuffer, final int outSize) {
        if (inBuffer.remaining() == 0) {
            return EMPTY_BUFFER;
        }
        final int formatFlags = inBuffer.get() & 0xFF;
        final RANSNx16Params ransNx16Params = new RANSNx16Params(formatFlags);
        int uncompressedSize = ransNx16Params.isNosz() ? outSize : CompressionUtils.readUint7(inBuffer);
        if (ransNx16Params.isStripe()) {
            return this.decodeStripe(inBuffer, uncompressedSize);
        }

        // PACK metadata: symbol count, symbol mapping table, then the size of the packed data.
        int packDataLength = 0;
        int numSymbols = 0;
        byte[] packMappingTable = null;
        if (ransNx16Params.isPack()) {
            packDataLength = uncompressedSize;
            numSymbols = inBuffer.get() & 0xFF;
            if (numSymbols == 0 || numSymbols > MAX_PACK_SYMBOLS) {
                throw new CRAMException("Bit Packing is not permitted when number of distinct symbols is greater than 16 or equal to 0. Number of distinct symbols: " + numSymbols);
            }
            packMappingTable = new byte[numSymbols];
            inBuffer.get(packMappingTable);
            uncompressedSize = CompressionUtils.readUint7(inBuffer);
        }

        // RLE metadata: metadata length, the size of the RLE-reduced data, then the metadata.
        int uncompressedRLEOutputLength = 0;
        int[] rleSymbols = null;
        ByteBuffer uncompressedRLEMetaData = null;
        if (ransNx16Params.isRLE()) {
            rleSymbols = new int[NUMBER_OF_SYMBOLS];
            final int uncompressedRLEMetaDataLength = CompressionUtils.readUint7(inBuffer);
            uncompressedRLEOutputLength = uncompressedSize;
            uncompressedSize = CompressionUtils.readUint7(inBuffer);
            uncompressedRLEMetaData = this.decodeRLEMeta(inBuffer, uncompressedRLEMetaDataLength, rleSymbols, ransNx16Params);
        }

        ByteBuffer outBuffer;
        if (ransNx16Params.isCAT()) {
            // CAT: the next uncompressedSize bytes are taken as-is, and inBuffer is moved past them.
            outBuffer = CompressionUtils.slice(inBuffer);
            outBuffer.limit(uncompressedSize);
            inBuffer.position(inBuffer.position() + uncompressedSize);
        } else {
            if (uncompressedSize == 0) {
                throw new CRAMException("Unexpected uncompressed size of 0 in RANSNx16 stream");
            }
            outBuffer = CompressionUtils.allocateByteBuffer(uncompressedSize);
            switch (ransNx16Params.getOrder()) {
                case ZERO:
                    this.uncompressOrder0WayN(inBuffer, outBuffer, uncompressedSize, ransNx16Params);
                    break;
                case ONE:
                    this.uncompressOrder1WayN(inBuffer, outBuffer, ransNx16Params);
                    break;
                default:
                    throw new CRAMException("Unknown rANSNx16 order: " + ransNx16Params.getOrder());
            }
        }

        // Undo the transforms in reverse order of application.
        if (ransNx16Params.isRLE()) {
            outBuffer = this.decodeRLE(outBuffer, rleSymbols, uncompressedRLEMetaData, uncompressedRLEOutputLength);
        }
        if (ransNx16Params.isPack()) {
            outBuffer = CompressionUtils.decodePack(outBuffer, packMappingTable, numSymbols, packDataLength);
        }
        return outBuffer;
    }

    /**
     * Decodes an order-0 payload using N interleaved rANS states.
     *
     * <p>Output byte {@code i} is decoded with state {@code i % N}. The first {@code N * (outSize / N)}
     * bytes are decoded in full rounds; the trailing {@code outSize % N} bytes use states 0, 1, …
     *
     * @param inBuffer       frequency table, initial states and encoded data
     * @param outBuffer      destination of at least {@code outSize} bytes; rewound on return
     * @param outSize        number of bytes to decode
     * @param ransNx16Params supplies the number of interleaved states
     */
    private void uncompressOrder0WayN(final ByteBuffer inBuffer, final ByteBuffer outBuffer, final int outSize, final RANSNx16Params ransNx16Params) {
        this.initializeRANSDecoder();
        this.readFrequencyTableOrder0(inBuffer);
        final int nway = ransNx16Params.getNumInterleavedRANSStates();
        final long[] rans = new long[nway];
        for (int r = 0; r < nway; ++r) {
            rans[r] = inBuffer.getInt();
        }
        final int interleaveSize = outSize / nway;
        final int remSize = outSize - interleaveSize * nway;
        final int outEnd = outSize - remSize;
        final ArithmeticDecoder decoder = this.getD()[0];
        final RANSDecodingSymbol[] syms = this.getDecodingSymbols()[0];

        for (int i = 0; i < outEnd; i += nway) {
            for (int r = 0; r < nway; ++r) {
                final byte decodedSymbol = decoder.reverseLookup[Utils.RANSGetCumulativeFrequency(rans[r], ORDER_0_FREQUENCY_SHIFT)];
                outBuffer.put(i + r, decodedSymbol);
                rans[r] = syms[0xFF & decodedSymbol].advanceSymbolStep(rans[r], ORDER_0_FREQUENCY_SHIFT);
                rans[r] = Utils.RANSDecodeRenormalizeNx16(rans[r], inBuffer);
            }
        }

        // Tail: each remaining byte uses a distinct state (remSize < nway). The advance is still
        // performed because advanceSymbolNx16 reads from inBuffer (presumably to renormalise),
        // so skipping it would leave inBuffer at the wrong position.
        outBuffer.position(outEnd);
        for (int r = 0; r < remSize; ++r) {
            final byte remainingSymbol = decoder.reverseLookup[Utils.RANSGetCumulativeFrequency(rans[r], ORDER_0_FREQUENCY_SHIFT)];
            rans[r] = syms[0xFF & remainingSymbol].advanceSymbolNx16(rans[r], inBuffer, ORDER_0_FREQUENCY_SHIFT);
            outBuffer.put(remainingSymbol);
        }
        outBuffer.rewind();
    }

    /**
     * Decodes an order-1 payload using N interleaved rANS states.
     *
     * <p>The output is split into N contiguous segments of {@code size / N} bytes, each decoded
     * by its own state. Every state starts in context 0. The last state also decodes whatever
     * bytes remain after its segment.
     *
     * @param inBuffer       frequency-table header, table (possibly compressed), states and data
     * @param outBuffer      destination; its remaining size is the number of bytes to decode
     * @param ransNx16Params supplies the number of interleaved states
     */
    private void uncompressOrder1WayN(final ByteBuffer inBuffer, final ByteBuffer outBuffer, final RANSNx16Params ransNx16Params) {
        // Header byte: the low bit marks an order-0 compressed table; the high nibble is the shift.
        final int frequencyTableFirstByte = inBuffer.get() & 0xFF;
        final ByteBuffer freqTableSource;
        if ((frequencyTableFirstByte & FREQ_TABLE_OPTIONALLY_COMPRESSED_MASK) != 0) {
            final int uncompressedLength = CompressionUtils.readUint7(inBuffer);
            final int compressedLength = CompressionUtils.readUint7(inBuffer);
            final byte[] compressedFreqTable = new byte[compressedLength];
            inBuffer.get(compressedFreqTable, 0, compressedLength);
            freqTableSource = CompressionUtils.allocateByteBuffer(uncompressedLength);
            // The table itself is order-0 coded with 4 states: clear the ORDER and N32 flag bits.
            this.uncompressOrder0WayN(CompressionUtils.wrap(compressedFreqTable), freqTableSource, uncompressedLength, new RANSNx16Params(~(ORDER_FLAG_MASK | N32_FLAG_MASK)));
        } else {
            freqTableSource = inBuffer;
        }
        this.initializeRANSDecoder();
        final int shift = frequencyTableFirstByte >> 4;
        this.readFrequencyTableOrder1(freqTableSource, shift);

        final int outputSize = outBuffer.remaining();
        final int nway = ransNx16Params.getNumInterleavedRANSStates();
        final long[] rans = new long[nway];
        final int[] interleaveStreamIndex = new int[nway];
        final int[] context = new int[nway]; // every state starts in context 0
        final int interleaveSize = outputSize / nway;
        for (int r = 0; r < nway; ++r) {
            rans[r] = inBuffer.getInt();
            interleaveStreamIndex[r] = r * interleaveSize;
        }
        final ArithmeticDecoder[] decoders = this.getD();
        final RANSDecodingSymbol[][] syms = this.getDecodingSymbols();

        while (interleaveStreamIndex[0] < interleaveSize) {
            for (int r = 0; r < nway; ++r) {
                final int symbol = 0xFF & decoders[context[r]].reverseLookup[Utils.RANSGetCumulativeFrequency(rans[r], shift)];
                outBuffer.put(interleaveStreamIndex[r], (byte) symbol);
                rans[r] = syms[context[r]][symbol].advanceSymbolStep(rans[r], shift);
                rans[r] = Utils.RANSDecodeRenormalizeNx16(rans[r], inBuffer);
                context[r] = symbol;
                interleaveStreamIndex[r]++;
            }
        }

        // The last state decodes the bytes left over when outputSize is not a multiple of nway.
        final int last = nway - 1;
        while (interleaveStreamIndex[last] < outputSize) {
            final int symbol = 0xFF & decoders[context[last]].reverseLookup[Utils.RANSGetCumulativeFrequency(rans[last], shift)];
            outBuffer.put(interleaveStreamIndex[last], (byte) symbol);
            rans[last] = syms[context[last]][symbol].advanceSymbolNx16(rans[last], inBuffer, shift);
            context[last] = symbol;
            interleaveStreamIndex[last]++;
        }
    }

    /**
     * Reads an order-0 frequency table and builds the decoding symbols and reverse-lookup table
     * for context 0.
     *
     * @param cp source of the table; its position is advanced past it
     */
    private void readFrequencyTableOrder0(final ByteBuffer cp) {
        final int[] alphabet = RANSNx16Decode.readAlphabet(cp);
        final ArithmeticDecoder decoder = this.getD()[0];
        for (int j = 0; j < NUMBER_OF_SYMBOLS; ++j) {
            if (alphabet[j] > 0) {
                decoder.frequencies[j] = CompressionUtils.readUint7(cp);
            }
        }
        Utils.normaliseFrequenciesOrder0Shift(decoder.frequencies, ORDER_0_FREQUENCY_SHIFT);
        final RANSDecodingSymbol[] decodingSymbols = this.getDecodingSymbols()[0];
        int cumulativeFrequency = 0;
        for (int j = 0; j < NUMBER_OF_SYMBOLS; ++j) {
            if (alphabet[j] <= 0) {
                continue;
            }
            decodingSymbols[j].set(cumulativeFrequency, decoder.frequencies[j]);
            Arrays.fill(decoder.reverseLookup, cumulativeFrequency, cumulativeFrequency + decoder.frequencies[j], (byte) j);
            cumulativeFrequency += decoder.frequencies[j];
        }
    }

    /**
     * Reads an order-1 frequency table: one order-0 style row per context symbol in the alphabet.
     *
     * <p>Within a row, a zero frequency is followed by a byte giving the number of additional
     * zero-frequency symbols to skip.
     *
     * @param cp    source of the table; its position is advanced past it
     * @param shift normalisation shift taken from the table header byte
     */
    private void readFrequencyTableOrder1(final ByteBuffer cp, final int shift) {
        final ArithmeticDecoder[] decoders = this.getD();
        final RANSDecodingSymbol[][] decodingSymbols = this.getDecodingSymbols();
        final int[] alphabet = RANSNx16Decode.readAlphabet(cp);
        for (int i = 0; i < NUMBER_OF_SYMBOLS; ++i) {
            if (alphabet[i] <= 0) {
                continue;
            }
            int run = 0;
            for (int j = 0; j < NUMBER_OF_SYMBOLS; ++j) {
                if (alphabet[j] <= 0) {
                    continue;
                }
                if (run > 0) {
                    --run;
                    continue;
                }
                decoders[i].frequencies[j] = CompressionUtils.readUint7(cp);
                if (decoders[i].frequencies[j] == 0) {
                    run = cp.get() & 0xFF;
                }
            }
            Utils.normaliseFrequenciesOrder0Shift(decoders[i].frequencies, shift);
            int cumulativeFreq = 0;
            for (int j = 0; j < NUMBER_OF_SYMBOLS; ++j) {
                decodingSymbols[i][j].set(cumulativeFreq, decoders[i].frequencies[j]);
                Arrays.fill(decoders[i].reverseLookup, cumulativeFreq, cumulativeFreq + decoders[i].frequencies[j], (byte) j);
                cumulativeFreq += decoders[i].frequencies[j];
            }
        }
    }

    /**
     * Reads the run-length encoded set of symbols present in a frequency table.
     *
     * <p>Symbols are listed in increasing order and terminated by a 0 byte. When a symbol is
     * exactly one more than the previous symbol, it is followed by a count of further consecutive
     * symbols that are implied rather than written.
     *
     * @param cp source of the alphabet; its position is advanced past it
     * @return array of 256 flags; a non-zero entry means the symbol is present
     * @throws CRAMException if an implied run goes past symbol 255
     */
    private static int[] readAlphabet(final ByteBuffer cp) {
        final int[] alphabet = new int[NUMBER_OF_SYMBOLS];
        int rle = 0;
        int symbol = cp.get() & 0xFF;
        int lastSymbol = symbol;
        do {
            alphabet[symbol] = 1;
            if (rle != 0) {
                // Inside a run: the next symbol is implied, not stored.
                --rle;
                ++symbol;
                if (symbol >= NUMBER_OF_SYMBOLS) {
                    throw new CRAMException("Invalid RANSNx16 frequency table alphabet: symbol run exceeds " + (NUMBER_OF_SYMBOLS - 1));
                }
            } else {
                symbol = cp.get() & 0xFF;
                if (symbol == lastSymbol + 1) {
                    rle = cp.get() & 0xFF;
                }
            }
            lastSymbol = symbol;
        } while (symbol != 0);
        return alphabet;
    }

    /**
     * Reads the RLE metadata and records which symbols are run-length encoded.
     *
     * <p>If the low bit of {@code uncompressedRLEMetaDataLength} is set, the metadata is stored
     * verbatim. Otherwise it is order-0 rANS coded, using the parent stream's N32 setting. Either
     * way, its decoded length is {@code uncompressedRLEMetaDataLength / 2}.
     *
     * @param inBuffer                      source; its position is advanced past the metadata
     * @param uncompressedRLEMetaDataLength encoded length field (length * 2, plus the stored-verbatim bit)
     * @param rleSymbols                    256-entry array; set to 1 for each run-length encoded symbol
     * @param ransNx16Params                parameters of the enclosing stream
     * @return the metadata buffer, positioned at the first run length
     */
    private ByteBuffer decodeRLEMeta(final ByteBuffer inBuffer, final int uncompressedRLEMetaDataLength, final int[] rleSymbols, final RANSNx16Params ransNx16Params) {
        final int metaDataLength = uncompressedRLEMetaDataLength / 2;
        final ByteBuffer uncompressedRLEMetaData;
        if ((uncompressedRLEMetaDataLength & RLE_META_OPTIONALLY_COMPRESSED_MASK) != 0) {
            final byte[] uncompressedRLEMetaDataArray = new byte[metaDataLength];
            inBuffer.get(uncompressedRLEMetaDataArray, 0, metaDataLength);
            uncompressedRLEMetaData = CompressionUtils.wrap(uncompressedRLEMetaDataArray);
        } else {
            final int compressedRLEMetaDataLength = CompressionUtils.readUint7(inBuffer);
            final byte[] compressedRLEMetaDataArray = new byte[compressedRLEMetaDataLength];
            inBuffer.get(compressedRLEMetaDataArray, 0, compressedRLEMetaDataLength);
            uncompressedRLEMetaData = CompressionUtils.allocateByteBuffer(metaDataLength);
            // Order-0 with the same number of states as the enclosing stream.
            this.uncompressOrder0WayN(CompressionUtils.wrap(compressedRLEMetaDataArray), uncompressedRLEMetaData, metaDataLength, new RANSNx16Params(ransNx16Params.getFormatFlags() & N32_FLAG_MASK));
        }
        int numRLESymbols = uncompressedRLEMetaData.get() & 0xFF;
        if (numRLESymbols == 0) {
            // A count byte of 0 stands for all 256 symbols.
            numRLESymbols = NUMBER_OF_SYMBOLS;
        }
        for (int i = 0; i < numRLESymbols; ++i) {
            rleSymbols[uncompressedRLEMetaData.get() & 0xFF] = 1;
        }
        return uncompressedRLEMetaData;
    }

    /**
     * Expands run-length encoded data.
     *
     * <p>Each occurrence of a run-length encoded symbol is followed by a uint7 run length in the
     * metadata. The symbol is emitted {@code run + 1} times. Other symbols are copied once.
     *
     * @param inBuffer                    RLE-reduced data, read with absolute indexing from 0
     * @param rleSymbols                  flags for run-length encoded symbols
     * @param uncompressedRLEMetaData     metadata positioned at the first run length
     * @param uncompressedRLEOutputLength size of the expanded output
     * @return the expanded data
     */
    private ByteBuffer decodeRLE(final ByteBuffer inBuffer, final int[] rleSymbols, final ByteBuffer uncompressedRLEMetaData, final int uncompressedRLEOutputLength) {
        final ByteBuffer rleOutBuffer = CompressionUtils.allocateByteBuffer(uncompressedRLEOutputLength);
        int outIndex = 0;
        int inIndex = 0;
        while (outIndex < uncompressedRLEOutputLength) {
            final byte sym = inBuffer.get(inIndex++);
            if (rleSymbols[sym & 0xFF] != 0) {
                final int run = CompressionUtils.readUint7(uncompressedRLEMetaData);
                for (int r = 0; r <= run; ++r) {
                    rleOutBuffer.put(outIndex++, sym);
                }
            } else {
                rleOutBuffer.put(outIndex++, sym);
            }
        }
        return rleOutBuffer;
    }

    /**
     * Decodes a STRIPE stream: N independently coded sub-streams, with byte {@code i} of
     * sub-stream {@code j} stored at output position {@code i * N + j}.
     *
     * @param inBuffer source; its position is advanced past all sub-streams
     * @param outSize  total uncompressed size
     * @return the interleaved output
     * @throws CRAMException if the stream count is 0 but {@code outSize} is non-zero
     */
    private ByteBuffer decodeStripe(final ByteBuffer inBuffer, final int outSize) {
        final int numInterleaveStreams = inBuffer.get() & 0xFF;
        if (numInterleaveStreams == 0 && outSize != 0) {
            throw new CRAMException("Invalid number of interleaved streams (0) in RANSNx16 stripe stream of uncompressed size " + outSize);
        }
        // The per-stream compressed lengths are skipped. This relies on each recursive
        // uncompress() call consuming exactly its own sub-stream from inBuffer.
        for (int j = 0; j < numInterleaveStreams; ++j) {
            CompressionUtils.readUint7(inBuffer);
        }
        final int[] uncompressedLengths = new int[numInterleaveStreams];
        final ByteBuffer[] transposedData = new ByteBuffer[numInterleaveStreams];
        for (int j = 0; j < numInterleaveStreams; ++j) {
            // The first (outSize % N) sub-streams each carry one extra byte.
            uncompressedLengths[j] = outSize / numInterleaveStreams + (outSize % numInterleaveStreams > j ? 1 : 0);
            transposedData[j] = this.uncompress(inBuffer, uncompressedLengths[j]);
        }
        final ByteBuffer outBuffer = CompressionUtils.allocateByteBuffer(outSize);
        for (int j = 0; j < numInterleaveStreams; ++j) {
            for (int i = 0; i < uncompressedLengths[j]; ++i) {
                outBuffer.put(i * numInterleaveStreams + j, transposedData[j].get(i));
            }
        }
        return outBuffer;
    }
}

