/*
 * Decompiled with CFR 0.152.
 */
package org.bouncycastle.pqc.crypto.sphincsplus;

import java.util.LinkedList;
import org.bouncycastle.pqc.crypto.sphincsplus.ADRS;
import org.bouncycastle.pqc.crypto.sphincsplus.NodeEntry;
import org.bouncycastle.pqc.crypto.sphincsplus.SIG_XMSS;
import org.bouncycastle.pqc.crypto.sphincsplus.SPHINCSPlusEngine;
import org.bouncycastle.pqc.crypto.sphincsplus.WotsPlus;
import org.bouncycastle.util.Arrays;

/**
 * Hypertree (HT) component of SPHINCS+: {@code engine.D} layers of XMSS trees, each of
 * height {@code engine.H_PRIME}. On every layer above 0, a leaf signs the root that was
 * recomputed for the layer below it.
 */
class HT {
    private final byte[] skSeed;
    private final byte[] pkSeed;
    /** Parameter set and tweakable-hash provider. */
    SPHINCSPlusEngine engine;
    /** WOTS+ instance used for the leaves of every XMSS tree. */
    WotsPlus wots;
    /** Root of the top-layer tree (layer D-1, tree address 0); null when skSeed is null. */
    final byte[] htPubKey;

    /**
     * Creates a hypertree bound to an engine and a pair of seeds.
     *
     * @param engine parameter set / hash provider
     * @param skSeed secret seed; may be null, in which case {@link #htPubKey} is left null
     *               ({@link #verify} does not read it)
     * @param pkSeed public seed
     */
    public HT(SPHINCSPlusEngine engine, byte[] skSeed, byte[] pkSeed) {
        this.skSeed = skSeed;
        this.pkSeed = pkSeed;
        this.engine = engine;
        this.wots = new WotsPlus(engine);
        ADRS adrs = new ADRS();
        adrs.setLayerAddress(engine.D - 1);
        adrs.setTreeAddress(0L);
        this.htPubKey = skSeed != null ? this.xmss_PKgen(skSeed, pkSeed, adrs) : null;
    }

    /**
     * Produces a hypertree signature on {@code M}.
     *
     * <p>Layer 0 signs {@code M} with leaf {@code idx_leaf} of tree {@code idx_tree}. Each
     * higher layer j signs the root recomputed on layer j-1. The leaf index for layer j is
     * the low {@code H_PRIME} bits of the running tree index, and the tree index is then
     * shifted right by {@code H_PRIME}.
     *
     * @param M        message (on layer 0) to sign
     * @param idx_tree tree index on layer 0
     * @param idx_leaf leaf index within that tree
     * @return the D XMSS signatures concatenated in layer order, each serialised as its
     *         WOTS signature followed by its authentication-path nodes
     */
    byte[] sign(byte[] M, long idx_tree, int idx_leaf) {
        ADRS adrs = new ADRS();
        adrs.setLayerAddress(0);
        adrs.setTreeAddress(idx_tree);
        SIG_XMSS sigTmp = this.xmss_sign(M, this.skSeed, idx_leaf, this.pkSeed, adrs);

        SIG_XMSS[] sigHt = new SIG_XMSS[this.engine.D];
        sigHt[0] = sigTmp;

        adrs.setLayerAddress(0);
        adrs.setTreeAddress(idx_tree);
        byte[] root = this.xmss_pkFromSig(idx_leaf, sigTmp, M, this.pkSeed, adrs);

        // Built in long arithmetic so the shift cannot wrap as an int would.
        final long leafMask = (1L << this.engine.H_PRIME) - 1;
        for (int j = 1; j < this.engine.D; j++) {
            idx_leaf = (int)(idx_tree & leafMask);
            idx_tree >>>= this.engine.H_PRIME;
            adrs.setLayerAddress(j);
            adrs.setTreeAddress(idx_tree);
            sigTmp = this.xmss_sign(root, this.skSeed, idx_leaf, this.pkSeed, adrs);
            sigHt[j] = sigTmp;
            // No later layer consumes the top layer's root, so skip recomputing it.
            if (j < this.engine.D - 1) {
                root = this.xmss_pkFromSig(idx_leaf, sigTmp, root, this.pkSeed, adrs);
            }
        }

        byte[][] serialized = new byte[sigHt.length][];
        for (int i = 0; i < sigHt.length; i++) {
            serialized[i] = Arrays.concatenate(sigHt[i].sig, Arrays.concatenate(sigHt[i].auth));
        }
        return Arrays.concatenate(serialized);
    }

    /**
     * Computes the root of the full-height ({@code H_PRIME}) XMSS tree addressed by
     * {@code adrs}, starting from leaf 0.
     */
    byte[] xmss_PKgen(byte[] skSeed, byte[] pkSeed, ADRS adrs) {
        return this.treehash(skSeed, 0, this.engine.H_PRIME, pkSeed, adrs);
    }

    /**
     * Recomputes an XMSS tree root from a signature.
     *
     * <p>The WOTS public key for leaf {@code idx} is derived from the signature over
     * {@code M}. It is then hashed up the tree with the authentication path, one level
     * per iteration.
     *
     * @param idx       leaf index within the tree
     * @param sig_xmss  WOTS signature plus authentication path
     * @param M         message the WOTS signature covers
     * @param pkSeed    public seed
     * @param paramAdrs address of the tree (layer / tree address); a copy is modified
     * @return the recomputed root node
     */
    byte[] xmss_pkFromSig(int idx, SIG_XMSS sig_xmss, byte[] M, byte[] pkSeed, ADRS paramAdrs) {
        ADRS adrs = new ADRS(paramAdrs);
        // Type 0: the address type used for WOTS operations here (presumably WOTS hash; confirm in ADRS).
        adrs.setTypeAndClear(0);
        adrs.setKeyPairAddress(idx);
        byte[] sig = sig_xmss.getWOTSSig();
        byte[][] AUTH = sig_xmss.getXMSSAUTH();
        byte[] node = this.wots.pkFromSig(sig, M, pkSeed, adrs);

        // Type 2: the address type used for internal tree nodes (presumably TREE; confirm in ADRS).
        adrs.setTypeAndClear(2);
        adrs.setTreeIndex(idx);
        for (int k = 0; k < this.engine.H_PRIME; k++) {
            adrs.setTreeHeight(k + 1);
            if ((idx / (1 << k)) % 2 == 0) {
                // Current node is a left child: its sibling AUTH[k] is hashed on the right.
                adrs.setTreeIndex(adrs.getTreeIndex() / 2);
                node = this.engine.H(pkSeed, adrs, node, AUTH[k]);
            } else {
                // Current node is a right child: its sibling AUTH[k] is hashed on the left.
                adrs.setTreeIndex((adrs.getTreeIndex() - 1) / 2);
                node = this.engine.H(pkSeed, adrs, AUTH[k], node);
            }
        }
        return node;
    }

    /**
     * Signs {@code M} with leaf {@code idx} of the XMSS tree addressed by {@code paramAdrs}.
     *
     * <p>For each height j, the authentication-path entry is the root of the sibling
     * subtree of height j that contains the ancestor of leaf {@code idx}.
     *
     * @return the WOTS signature on {@code M} together with the authentication path
     */
    SIG_XMSS xmss_sign(byte[] M, byte[] skSeed, int idx, byte[] pkSeed, ADRS paramAdrs) {
        byte[][] AUTH = new byte[this.engine.H_PRIME][];
        ADRS treeAdrs = new ADRS(paramAdrs);
        treeAdrs.setTypeAndClear(2);
        treeAdrs.setLayerAddress(paramAdrs.getLayerAddress());
        treeAdrs.setTreeAddress(paramAdrs.getTreeAddress());
        for (int j = 0; j < this.engine.H_PRIME; j++) {
            // Index (at height j) of the sibling of leaf idx's ancestor.
            int sibling = (idx >>> j) ^ 1;
            AUTH[j] = this.treehash(skSeed, sibling << j, j, pkSeed, treeAdrs);
        }

        ADRS wotsAdrs = new ADRS(paramAdrs);
        wotsAdrs.setTypeAndClear(0);
        wotsAdrs.setKeyPairAddress(idx);
        byte[] sig = this.wots.sign(M, skSeed, pkSeed, wotsAdrs);
        return new SIG_XMSS(sig, AUTH);
    }

    /**
     * Computes the root of the height-{@code z} subtree whose leftmost leaf is {@code s}.
     *
     * <p>Leaves (WOTS public keys) are generated left to right. Whenever the top of the
     * stack holds a node of the same height, the two are merged into their parent.
     *
     * @param s leftmost leaf index; must be a multiple of {@code 2^z}
     * @param z subtree height
     * @return the subtree root, or {@code null} if {@code s} is not aligned to {@code 2^z}
     */
    byte[] treehash(byte[] skSeed, int s, int z, byte[] pkSeed, ADRS adrsParam) {
        if (s >>> z << z != s) {
            return null;
        }
        LinkedList<NodeEntry> stack = new LinkedList<NodeEntry>();
        ADRS adrs = new ADRS(adrsParam);
        for (int idx = 0; idx < 1 << z; idx++) {
            int leafIndex = s + idx;
            adrs.setTypeAndClear(0);
            adrs.setKeyPairAddress(leafIndex);
            byte[] node = this.wots.pkGen(skSeed, pkSeed, adrs);

            adrs.setTypeAndClear(2);
            adrs.setTreeHeight(1);
            adrs.setTreeIndex(leafIndex);
            int height = 1;
            int treeIndex = leafIndex;
            // Merge with the equal-height node on top of the stack, repeatedly.
            while (!stack.isEmpty() && stack.getFirst().nodeHeight == height) {
                treeIndex = (treeIndex - 1) / 2;
                adrs.setTreeIndex(treeIndex);
                NodeEntry left = stack.removeFirst();
                node = this.engine.H(pkSeed, adrs, left.nodeValue, node);
                adrs.setTreeHeight(++height);
            }
            stack.addFirst(new NodeEntry(node, height));
        }
        return stack.getFirst().nodeValue;
    }

    /**
     * Verifies a hypertree signature.
     *
     * <p>The root is recomputed from layer 0 upward, using the same leaf/tree index
     * derivation as {@link #sign}. The result is compared with {@code PK_HT}.
     *
     * @param sig_ht per-layer XMSS signatures; a null array or one shorter than
     *               {@code engine.D} is rejected
     * @return true iff the recomputed top-layer root equals {@code PK_HT}
     */
    public boolean verify(byte[] M, SIG_XMSS[] sig_ht, byte[] pkSeed, long idx_tree, int idx_leaf, byte[] PK_HT) {
        // Fail closed on a missing or truncated signature array instead of throwing.
        if (sig_ht == null || sig_ht.length < this.engine.D) {
            return false;
        }
        ADRS adrs = new ADRS();
        adrs.setLayerAddress(0);
        adrs.setTreeAddress(idx_tree);
        byte[] node = this.xmss_pkFromSig(idx_leaf, sig_ht[0], M, pkSeed, adrs);

        final long leafMask = (1L << this.engine.H_PRIME) - 1;
        for (int j = 1; j < this.engine.D; j++) {
            idx_leaf = (int)(idx_tree & leafMask);
            idx_tree >>>= this.engine.H_PRIME;
            adrs.setLayerAddress(j);
            adrs.setTreeAddress(idx_tree);
            node = this.xmss_pkFromSig(idx_leaf, sig_ht[j], node, pkSeed, adrs);
        }
        return Arrays.areEqual(PK_HT, node);
    }
}

