/*
 * Decompiled with CFR 0.152.
 */
package org.bouncycastle.pqc.crypto.slhdsa;

import org.bouncycastle.pqc.crypto.slhdsa.ADRS;
import org.bouncycastle.pqc.crypto.slhdsa.SLHDSAEngine;
import org.bouncycastle.util.Arrays;
import org.bouncycastle.util.Pack;

/**
 * WOTS+ one-time signature operations used by SLH-DSA: public-key generation, signing, and
 * recovery of the public key from a signature.
 */
class WotsPlus {
    // ADRS type values as defined in FIPS 205, Table 1 (the decompiler inlined them as literals).
    private static final int ADRS_WOTS_HASH = 0;
    private static final int ADRS_WOTS_PK = 1;
    private static final int ADRS_WOTS_PRF = 5;

    private final SLHDSAEngine engine;
    private final int w;

    WotsPlus(SLHDSAEngine engine) {
        this.engine = engine;
        this.w = this.engine.WOTS_W;
    }

    /**
     * Generates a WOTS+ public key.
     *
     * <p>Each of the WOTS_LEN chains is seeded from PRF and hashed w - 1 times. The chain ends
     * are concatenated and compressed with T_l under a WOTS_PK address.
     *
     * @param skSeed secret seed
     * @param pkSeed public seed
     * @param paramAdrs address supplying layer/tree/key-pair fields; it is copied, not modified
     * @return the compressed WOTS+ public key
     */
    byte[] pkGen(byte[] skSeed, byte[] pkSeed, ADRS paramAdrs) {
        ADRS wotspkADRS = new ADRS(paramAdrs);
        byte[][] tmp = new byte[this.engine.WOTS_LEN][];
        for (int i = 0; i < this.engine.WOTS_LEN; ++i) {
            ADRS adrs = new ADRS(paramAdrs);
            byte[] sk = this.chainSecret(skSeed, pkSeed, paramAdrs, adrs, i);
            tmp[i] = this.chain(sk, 0, this.w - 1, pkSeed, adrs);
        }
        wotspkADRS.setTypeAndClear(ADRS_WOTS_PK);
        wotspkADRS.setKeyPairAddress(paramAdrs.getKeyPairAddress());
        return this.engine.T_l(pkSeed, wotspkADRS, Arrays.concatenate(tmp));
    }

    /**
     * Applies F to {@code X} {@code s} times, using hash addresses i, i+1, ..., i+s-1.
     *
     * @param X chain input value (not modified)
     * @param i index of the first step
     * @param s number of steps
     * @param pkSeed public seed
     * @param adrs address whose hash-address field is overwritten at each step
     * @return a copy of X when s == 0; null when i + s exceeds w - 1; otherwise the chain output
     */
    byte[] chain(byte[] X, int i, int s, byte[] pkSeed, ADRS adrs) {
        if (s == 0) {
            return Arrays.clone(X);
        }
        if (i + s > this.w - 1) {
            return null;
        }
        byte[] result = X;
        for (int j = 0; j < s; ++j) {
            adrs.setHashAddress(i + j);
            result = this.engine.F(pkSeed, adrs, result);
        }
        return result;
    }

    /**
     * Produces a WOTS+ signature on the message digest {@code M}.
     *
     * @param M message digest, split into WOTS_LEN1 base-w digits
     * @param skSeed secret seed
     * @param pkSeed public seed
     * @param paramAdrs address supplying the key-pair field; it is copied, not modified
     * @return the concatenation of WOTS_LEN chain values
     */
    public byte[] sign(byte[] M, byte[] skSeed, byte[] pkSeed, ADRS paramAdrs) {
        ADRS adrs = new ADRS(paramAdrs);
        int[] msg = this.messageDigits(M);
        byte[][] sig = new byte[this.engine.WOTS_LEN][];
        for (int i = 0; i < this.engine.WOTS_LEN; ++i) {
            byte[] sk = this.chainSecret(skSeed, pkSeed, paramAdrs, adrs, i);
            sig[i] = this.chain(sk, 0, msg[i], pkSeed, adrs);
        }
        return Arrays.concatenate(sig);
    }

    /**
     * Splits bytes of {@code X} into {@code outLen} digits of WOTS_LOGW bits each, most
     * significant first.
     *
     * <p>Assumes WOTS_LOGW divides 8, since digits are never taken across a byte boundary. Bytes
     * are promoted with sign extension. This is harmless because each digit is masked with
     * {@code w - 1} after shifting.
     *
     * @param X source bytes
     * @param XOff offset of the first byte to read
     * @param w base (a power of two equal to 2^WOTS_LOGW)
     * @param output destination digit array
     * @param outOff first index written in {@code output}
     * @param outLen number of digits to produce
     */
    void base_w(byte[] X, int XOff, int w, int[] output, int outOff, int outLen) {
        int total = 0;
        int bits = 0;
        for (int consumed = 0; consumed < outLen; ++consumed) {
            if (bits == 0) {
                total = X[XOff++];
                bits += 8;
            }
            bits -= this.engine.WOTS_LOGW;
            output[outOff++] = (total >>> bits) & (w - 1);
        }
    }

    /**
     * Recomputes the WOTS+ public key from a signature and its message digest.
     *
     * <p>Side effect: the chain-address and hash-address fields of {@code adrs} are overwritten.
     *
     * @param sig WOTS_LEN concatenated N-byte chain values
     * @param M message digest that was signed
     * @param pkSeed public seed
     * @param adrs WOTS hash address of the key pair; modified in place
     * @return the compressed WOTS+ public key implied by the signature
     */
    public byte[] pkFromSig(byte[] sig, byte[] M, byte[] pkSeed, ADRS adrs) {
        ADRS wotspkADRS = new ADRS(adrs);
        int[] msg = this.messageDigits(M);
        byte[] sigI = new byte[this.engine.N];
        byte[][] tmp = new byte[this.engine.WOTS_LEN][];
        for (int i = 0; i < this.engine.WOTS_LEN; ++i) {
            adrs.setChainAddress(i);
            // sigI is reused across iterations. chain() never returns it by reference: for
            // s == 0 it clones, otherwise it returns the value produced by F.
            System.arraycopy(sig, i * this.engine.N, sigI, 0, this.engine.N);
            tmp[i] = this.chain(sigI, msg[i], this.w - 1 - msg[i], pkSeed, adrs);
        }
        wotspkADRS.setTypeAndClear(ADRS_WOTS_PK);
        wotspkADRS.setKeyPairAddress(adrs.getKeyPairAddress());
        return this.engine.T_l(pkSeed, wotspkADRS, Arrays.concatenate(tmp));
    }

    /**
     * Converts a message digest into the WOTS_LEN chain-length digits used by both sign and
     * pkFromSig. The result is WOTS_LEN1 base-w message digits followed by WOTS_LEN2 base-w
     * checksum digits.
     */
    private int[] messageDigits(byte[] M) {
        int len1 = this.engine.WOTS_LEN1;
        int len2 = this.engine.WOTS_LEN2;
        int logW = this.engine.WOTS_LOGW;
        int[] msg = new int[this.engine.WOTS_LEN];
        this.base_w(M, 0, this.w, msg, 0, len1);
        int csum = 0;
        for (int i = 0; i < len1; ++i) {
            csum += this.w - 1 - msg[i];
        }
        // Left-align the len2 * lg_w checksum bits within whole bytes (FIPS 205, Algs. 7/8).
        // The outer "% 8" makes the shift 0 rather than 8 when the bits already fill whole bytes.
        // Sign and verify share this one computation, so they always derive the same digits.
        csum <<= (8 - (len2 * logW) % 8) % 8;
        int len2Bytes = (len2 * logW + 7) / 8;
        byte[] csumBytes = Pack.intToBigEndian(csum);
        // csumBytes is 4 bytes big-endian; only its trailing len2Bytes carry the checksum.
        this.base_w(csumBytes, 4 - len2Bytes, this.w, msg, len1, len2);
        return msg;
    }

    /**
     * Derives the secret start value of one chain with PRF. On return, {@code adrs} is a
     * WOTS_HASH address for that chain with hash address 0, ready for {@link #chain}.
     *
     * @param skSeed secret seed
     * @param pkSeed public seed
     * @param keyPairAdrs address supplying the key-pair field (read only)
     * @param adrs address to configure; modified in place
     * @param chainIndex index of the chain within the key
     * @return the PRF output for this chain
     */
    private byte[] chainSecret(
            byte[] skSeed, byte[] pkSeed, ADRS keyPairAdrs, ADRS adrs, int chainIndex) {
        adrs.setTypeAndClear(ADRS_WOTS_PRF);
        adrs.setKeyPairAddress(keyPairAdrs.getKeyPairAddress());
        adrs.setChainAddress(chainIndex);
        adrs.setHashAddress(0);
        byte[] sk = this.engine.PRF(pkSeed, skSeed, adrs);
        adrs.setTypeAndClear(ADRS_WOTS_HASH);
        adrs.setKeyPairAddress(keyPairAdrs.getKeyPairAddress());
        adrs.setChainAddress(chainIndex);
        adrs.setHashAddress(0);
        return sk;
    }
}

