/*
 * Decompiled with CFR 0.152.
 */
package org.bouncycastle.pqc.legacy.crypto.ntru;

import org.bouncycastle.crypto.Digest;
import org.bouncycastle.pqc.legacy.crypto.ntru.NTRUEncryptionParameters;
import org.bouncycastle.util.Arrays;

/**
 * Deterministic generator of indices in {@code [0, N)} derived from a seed.
 * <p>
 * Bits are produced by hashing {@code seed || counter} with the parameter set's digest, where
 * {@code counter} is written as a 4-byte big-endian integer and incremented after every hash.
 * Each candidate index is a {@code c}-bit chunk of that bit stream; candidates at or above the
 * largest multiple of {@code N} not exceeding {@code 2^c} are rejected so that the final
 * {@code i % N} reduction is unbiased.
 * <p>
 * NOTE(review): the structure (minCallsR, remLen, cThreshold) looks like the IGF-2 index
 * generation function of IEEE 1363.1 — confirm against the specification before altering the
 * bit ordering or the number of hash calls, since either changes the generated sequence.
 * <p>
 * Instances are stateful and perform no synchronization; the {@link Digest} is taken from the
 * parameters as-is and reused for every hash computation.
 */
public class IndexGenerator {
    /** Seed hashed together with the counter; stored by reference, not copied. */
    private byte[] seed;
    /** Exclusive upper bound of the generated indices. */
    private int N;
    /** Number of bits consumed per candidate index. */
    private int c;
    /** Number of hash invocations performed up front on the first call to nextIndex(). */
    private int minCallsR;
    /** Running bit total; only ever written in this class, never read. */
    private int totLen;
    /** Number of not-yet-consumed bits remaining in {@link #buf}. */
    private int remLen;
    /** Bit buffer from which candidate bits are taken. */
    private BitString buf;
    /** Counter appended to the seed for each hash invocation. */
    private int counter;
    /** Whether the initial minCallsR hashes have been computed. */
    private boolean initialized;
    /** Digest used for every hash; shared with the parameter object. */
    private Digest hashAlg;
    /** Digest output length in bytes. */
    private int hLen;

    /**
     * Creates a generator; no hashing is done until the first call to {@link #nextIndex()}.
     *
     * @param seed   seed bytes (kept by reference, not copied)
     * @param params source of {@code N}, {@code c}, {@code minCallsR} and the digest
     */
    IndexGenerator(byte[] seed, NTRUEncryptionParameters params) {
        this.seed = seed;
        this.N = params.N;
        this.c = params.c;
        this.minCallsR = params.minCallsR;
        this.totLen = 0;
        this.remLen = 0;
        this.counter = 0;
        this.hashAlg = params.hashAlg;
        this.hLen = this.hashAlg.getDigestSize();
        this.initialized = false;
    }

    /**
     * Returns the next pseudo-random index.
     *
     * @return a value in {@code [0, N)} (assuming {@code c} is small enough that
     *         {@code 1 << c} does not overflow)
     */
    int nextIndex() {
        if (!this.initialized) {
            initialize();
        }
        // Largest multiple of N that is <= 2^c; candidates at or above it are rejected so that
        // the final modular reduction is unbiased. c and N never change, so compute once.
        int limit = (1 << this.c) - (1 << this.c) % this.N;
        while (true) {
            this.totLen += this.c;
            BitString m = this.buf.getTrailing(this.remLen);
            if (this.remLen < this.c) {
                refill(m);
            } else {
                this.remLen -= this.c;
            }
            int i = m.getLeadingAsInt(this.c);
            if (i < limit) {
                return i % this.N;
            }
        }
    }

    /** Computes the initial minCallsR hashes into a fresh buffer (first call only). */
    private void initialize() {
        this.buf = new BitString();
        byte[] hash = new byte[this.hashAlg.getDigestSize()];
        while (this.counter < this.minCallsR) {
            appendHash(this.buf, hash);
            ++this.counter;
        }
        this.remLen = this.totLen = this.minCallsR * 8 * this.hLen;
        this.initialized = true;
    }

    /**
     * Appends fresh hash output to {@code m} when fewer than {@code c} unconsumed bits remain,
     * then makes the last hash output the new buffer and updates {@link #remLen}.
     */
    private void refill(BitString m) {
        int tmpLen = this.c - this.remLen;
        // NOTE(review): tmpLen is a bit count while hLen is a byte count. Kept unchanged on
        // purpose: a different number of hash calls would change the generated sequence.
        int cThreshold = this.counter + (tmpLen + this.hLen - 1) / this.hLen;
        byte[] hash = new byte[this.hashAlg.getDigestSize()];
        while (this.counter < cThreshold) {
            appendHash(m, hash);
            ++this.counter;
            if (tmpLen > 8 * this.hLen) {
                tmpLen -= 8 * this.hLen;
            }
        }
        this.remLen = 8 * this.hLen - tmpLen;
        this.buf = new BitString();
        this.buf.appendBits(hash);
    }

    /**
     * Computes {@code Hash(seed || counter)} into {@code hash} and appends it to {@code m}.
     *
     * @param m    bit string to extend
     * @param hash scratch array of digest length that receives the hash output
     */
    private void appendHash(BitString m, byte[] hash) {
        this.hashAlg.update(this.seed, 0, this.seed.length);
        this.putInt(this.hashAlg, this.counter);
        this.hashAlg.doFinal(hash, 0);
        m.appendBits(hash);
    }

    /** Feeds {@code counter} to the digest as 4 bytes, most significant byte first. */
    private void putInt(Digest hashAlg, int counter) {
        hashAlg.update((byte)(counter >> 24));
        hashAlg.update((byte)(counter >> 16));
        hashAlg.update((byte)(counter >> 8));
        hashAlg.update((byte)counter);
    }

    /**
     * Returns a new array of length {@code len} holding the first {@code min(len, src.length)}
     * bytes of {@code src}, zero-padded.
     */
    private static byte[] copyOf(byte[] src, int len) {
        byte[] tmp = new byte[len];
        System.arraycopy(src, 0, tmp, 0, Math.min(len, src.length));
        return tmp;
    }

    /**
     * Growable bit string. Appended bits go to increasing byte indices; inside a partially
     * filled last byte, newly appended bits occupy the higher bit positions. "Trailing" refers
     * to the earliest-appended bits and "leading" to the most recently appended ones.
     */
    public static class BitString {
        /** Backing storage; its capacity may exceed {@link #numBytes}. */
        byte[] bytes = new byte[4];
        /** Bytes in use, including a partially filled last byte. */
        int numBytes;
        /** Valid bits in the last used byte (8 when it is full). */
        int lastByteBits;

        /**
         * Appends all bits of every byte in {@code bytes}, in array order.
         *
         * @param bytes bytes to append
         */
        void appendBits(byte[] bytes) {
            for (byte b : bytes) {
                this.appendBits(b);
            }
        }

        /**
         * Appends the 8 bits of {@code b} to the end of the bit string.
         *
         * @param b byte to append
         */
        public void appendBits(byte b) {
            if (this.numBytes == this.bytes.length) {
                // Grow geometrically. Math.max handles a zero-length backing array (as created
                // by getTrailing(0)); doubling 0 would leave it empty and the writes below
                // would throw ArrayIndexOutOfBoundsException.
                this.bytes = IndexGenerator.copyOf(this.bytes, Math.max(1, 2 * this.bytes.length));
            }
            if (this.numBytes == 0) {
                this.numBytes = 1;
                this.bytes[0] = b;
                this.lastByteBits = 8;
            } else if (this.lastByteBits == 8) {
                this.bytes[this.numBytes++] = b;
            } else {
                // Low bits of b fill the free high bits of the current last byte; the remaining
                // high bits of b start a new last byte holding lastByteBits bits.
                int s = 8 - this.lastByteBits;
                int n = this.numBytes - 1;
                this.bytes[n] = (byte)(this.bytes[n] | (b & 0xFF) << this.lastByteBits);
                this.bytes[this.numBytes++] = (byte)((b & 0xFF) >> s);
            }
        }

        /**
         * Returns a new bit string holding the first {@code numBits} bits appended to this one.
         * NOTE(review): assumes {@code numBits} does not exceed this string's backing capacity.
         *
         * @param numBits number of bits to copy
         * @return a new {@code BitString} of length {@code numBits}
         */
        public BitString getTrailing(int numBits) {
            BitString newStr = new BitString();
            newStr.numBytes = (numBits + 7) / 8;
            newStr.bytes = new byte[newStr.numBytes];
            System.arraycopy(this.bytes, 0, newStr.bytes, 0, newStr.numBytes);
            newStr.lastByteBits = numBits % 8;
            if (newStr.lastByteBits == 0) {
                newStr.lastByteBits = 8;
            } else {
                // Clear the bits of the last byte that lie beyond numBits.
                int s = 32 - newStr.lastByteBits;
                newStr.bytes[newStr.numBytes - 1] = (byte)(newStr.bytes[newStr.numBytes - 1] << s >>> s);
            }
            return newStr;
        }

        /**
         * Returns the {@code numBits} most recently appended bits as an int.
         * NOTE(review): assumes {@code 0 < numBits <= 32} and that the string holds at least
         * {@code numBits} bits.
         *
         * @param numBits number of bits to read
         * @return an int whose low {@code numBits} bits are the leading bits of this string
         */
        public int getLeadingAsInt(int numBits) {
            int startBit = (this.numBytes - 1) * 8 + this.lastByteBits - numBits;
            int startByte = startBit / 8;
            int startBitInStartByte = startBit % 8;
            int sum = (this.bytes[startByte] & 0xFF) >>> startBitInStartByte;
            int shift = 8 - startBitInStartByte;
            for (int i = startByte + 1; i < this.numBytes; ++i) {
                sum |= (this.bytes[i] & 0xFF) << shift;
                shift += 8;
            }
            return sum;
        }

        /**
         * Returns a copy of the entire backing array, including any unused capacity beyond
         * {@link #numBytes}.
         */
        public byte[] getBytes() {
            return Arrays.clone(this.bytes);
        }
    }
}

