/*
 * Decompiled with CFR 0.152.
 */
package org.bouncycastle.pqc.crypto.mlkem;

import org.bouncycastle.pqc.crypto.mlkem.CBD;
import org.bouncycastle.pqc.crypto.mlkem.MLKEMEngine;
import org.bouncycastle.pqc.crypto.mlkem.Ntt;
import org.bouncycastle.pqc.crypto.mlkem.Reduce;
import org.bouncycastle.pqc.crypto.mlkem.Symmetric;

/**
 * A polynomial with {@value #N} coefficients modulo q = {@value #Q}, as used by ML-KEM,
 * together with the NTT helpers, (de)serialization, (de)compression, message encoding and
 * noise sampling that operate on it.
 *
 * <p>Instances are mutable and unsynchronized. Note that several encoding methods
 * ({@link #compressPoly()}, {@link #toBytes()}, {@link #toMsg()}) first call
 * {@link #conditionalSubQ()}, which rewrites this polynomial's coefficients in place.
 */
class Poly {
    /** Number of coefficients per polynomial. */
    private static final int N = 256;
    /** The modulus q. */
    private static final int Q = 3329;
    /** Message size handled by {@link #toMsg()} / {@link #fromMsg(byte[])}: one bit per coefficient. */
    private static final int MSG_BYTES = N / 8;
    /** Size of the {@link #toBytes()} encoding: 12 bits per coefficient. */
    private static final int POLY_BYTES = 12 * N / 8;
    /**
     * 2^32 mod q (== 1353). Multiplying by this and then applying
     * {@code Reduce.montgomeryReduce} converts a coefficient into Montgomery form,
     * assuming montgomeryReduce multiplies by 2^-16 mod q (TODO confirm in Reduce).
     */
    private static final int TWO_POW_32_MOD_Q = 1353;
    /** (q + 1) / 2 == 1665 == 0x681: the coefficient that encodes a message bit of 1. */
    private static final int HALF_Q_ROUNDED_UP = (Q + 1) / 2;
    /** Error text shared by compressPoly/decompressPoly; kept byte-identical to the original. */
    private static final String BAD_COMPRESSED_SIZE = "PolyCompressedBytes is neither 128 or 160!";

    private short[] coeffs = new short[N];
    private final int polyCompressedBytes;
    private final int eta1;
    private final int eta2;
    private final Symmetric symmetric;

    /**
     * Creates a zero polynomial whose compression size, noise parameters and PRF are taken
     * from {@code engine}.
     *
     * @param engine the ML-KEM engine supplying the parameter set
     */
    public Poly(MLKEMEngine engine) {
        this.polyCompressedBytes = engine.getKyberPolyCompressedBytes();
        this.eta1 = engine.getKyberEta1();
        this.eta2 = MLKEMEngine.getKyberEta2();
        this.symmetric = engine.getSymmetric();
    }

    /** Returns coefficient {@code i}. */
    public short getCoeffIndex(int i) {
        return this.coeffs[i];
    }

    /** Returns the live coefficient array (not a copy); writes through it modify this polynomial. */
    public short[] getCoeffs() {
        return this.coeffs;
    }

    /** Sets coefficient {@code i} to {@code val}. */
    public void setCoeffIndex(int i, short val) {
        this.coeffs[i] = val;
    }

    /** Adopts {@code coeffs} as the backing array (no copy is made). */
    public void setCoeffs(short[] coeffs) {
        this.coeffs = coeffs;
    }

    /** Applies the forward NTT ({@code Ntt.ntt}) and then Barrett-reduces every coefficient. */
    public void polyNtt() {
        this.coeffs = Ntt.ntt(this.coeffs);
        this.reduce();
    }

    /**
     * Applies the inverse NTT ({@code Ntt.invNtt}). Per the method name the result is left in
     * Montgomery form — confirm against {@code Ntt.invNtt}.
     */
    public void polyInverseNttToMont() {
        this.coeffs = Ntt.invNtt(this.coeffs);
    }

    /** Replaces every coefficient with {@code Reduce.barretReduce} of itself. */
    public void reduce() {
        for (int i = 0; i < N; i++) {
            this.coeffs[i] = Reduce.barretReduce(this.coeffs[i]);
        }
    }

    /**
     * Writes into {@code r} the NTT-domain product of {@code a} and {@code b}.
     *
     * <p>The 256 coefficients are processed as 128 degree-1 pairs via {@code Ntt.baseMult}.
     * Each group of four coefficients {@code 4i..4i+3} uses {@code zeta = Ntt.nttZetas[64 + i]}
     * for its first pair and {@code -zeta} for its second pair.
     *
     * @param r destination polynomial (written by {@code Ntt.baseMult})
     * @param a left operand, NTT domain
     * @param b right operand, NTT domain
     */
    public static void baseMultMontgomery(Poly r, Poly a, Poly b) {
        for (int i = 0; i < N / 4; i++) {
            int k = 4 * i;
            Ntt.baseMult(r, k,
                    a.coeffs[k], a.coeffs[k + 1],
                    b.coeffs[k], b.coeffs[k + 1],
                    Ntt.nttZetas[64 + i]);
            Ntt.baseMult(r, k + 2,
                    a.coeffs[k + 2], a.coeffs[k + 3],
                    b.coeffs[k + 2], b.coeffs[k + 3],
                    (short) -Ntt.nttZetas[64 + i]);
        }
    }

    /** Adds {@code b} coefficient-wise into this polynomial (no modular reduction; wraps as short). */
    public void addCoeffs(Poly b) {
        for (int i = 0; i < N; i++) {
            this.coeffs[i] = (short) (this.coeffs[i] + b.coeffs[i]);
        }
    }

    /** Multiplies every coefficient by 2^32 mod q and Montgomery-reduces it. */
    public void convertToMont() {
        for (int i = 0; i < N; i++) {
            this.coeffs[i] = Reduce.montgomeryReduce(this.coeffs[i] * TWO_POW_32_MOD_Q);
        }
    }

    /**
     * Compresses this polynomial to {@code polyCompressedBytes} bytes: 4 bits per coefficient
     * for a 128-byte size, 5 bits per coefficient for a 160-byte size.
     *
     * <p>Side effect: calls {@link #conditionalSubQ()} on this polynomial first.
     *
     * @return the packed compressed coefficients
     * @throws RuntimeException if the configured compressed size is neither 128 nor 160
     */
    public byte[] compressPoly() {
        byte[] r = new byte[this.polyCompressedBytes];
        this.conditionalSubQ();
        if (this.polyCompressedBytes == 128) {
            this.packCompressed4(r);
        } else if (this.polyCompressedBytes == 160) {
            this.packCompressed5(r);
        } else {
            throw new RuntimeException(BAD_COMPRESSED_SIZE);
        }
        return r;
    }

    // Packs two 4-bit compressed coefficients per byte, even index in the low nibble.
    private void packCompressed4(byte[] r) {
        for (int i = 0; i < N / 2; i++) {
            int lo = compressCoeff4(this.coeffs[2 * i]);
            int hi = compressCoeff4(this.coeffs[2 * i + 1]);
            r[i] = (byte) (lo | hi << 4);
        }
    }

    // Packs eight 5-bit compressed coefficients (40 bits) into each 5-byte group, LSB first.
    private void packCompressed5(byte[] r) {
        int[] t = new int[8];
        for (int i = 0; i < N / 8; i++) {
            for (int j = 0; j < 8; j++) {
                t[j] = compressCoeff5(this.coeffs[8 * i + j]);
            }
            int o = 5 * i;
            r[o] = (byte) (t[0] | t[1] << 5);
            r[o + 1] = (byte) (t[1] >> 3 | t[2] << 2 | t[3] << 7);
            r[o + 2] = (byte) (t[3] >> 1 | t[4] << 4);
            r[o + 3] = (byte) (t[4] >> 4 | t[5] << 1 | t[6] << 6);
            r[o + 4] = (byte) (t[6] >> 2 | t[7] << 3);
        }
    }

    // 4-bit compression of x in [0, q). It stands in for ((x << 4) + q/2) / q mod 16, using the
    // fixed-point reciprocal 80635 ~= 2^28 / q so that no division by q is executed. The 32-bit
    // product may wrap; only bits 28..31 are kept, so the wrap does not affect the result.
    private static int compressCoeff4(int x) {
        return (((x << 4) + 1665) * 80635 >> 28) & 0xF;
    }

    // 5-bit compression of x in [0, q), analogous to compressCoeff4 with 40318 ~= 2^27 / q.
    private static int compressCoeff5(int x) {
        return (((x << 5) + 1664) * 40318 >> 27) & 0x1F;
    }

    /**
     * Inverse of {@link #compressPoly()}: unpacks 4-bit (128-byte) or 5-bit (160-byte) values
     * and maps each value y to round(y * q / 2^d) into this polynomial.
     *
     * @param compressedPolyCipherText packed compressed coefficients
     * @throws RuntimeException if the configured compressed size is neither 128 nor 160
     */
    public void decompressPoly(byte[] compressedPolyCipherText) {
        if (this.polyCompressedBytes == 128) {
            this.unpackCompressed4(compressedPolyCipherText);
        } else if (this.polyCompressedBytes == 160) {
            this.unpackCompressed5(compressedPolyCipherText);
        } else {
            throw new RuntimeException(BAD_COMPRESSED_SIZE);
        }
    }

    // Each byte holds two 4-bit values; y -> (y * q + 8) >> 4 == round(y * q / 16).
    private void unpackCompressed4(byte[] in) {
        for (int i = 0; i < N / 2; i++) {
            int b = in[i] & 0xFF;
            this.coeffs[2 * i] = (short) (((b & 0xF) * Q + 8) >> 4);
            this.coeffs[2 * i + 1] = (short) (((b >> 4) * Q + 8) >> 4);
        }
    }

    // Each 5-byte group holds eight 5-bit values; y -> (y * q + 16) >> 5 == round(y * q / 32).
    private void unpackCompressed5(byte[] in) {
        int[] t = new int[8];
        for (int i = 0; i < N / 8; i++) {
            int o = 5 * i;
            int b0 = in[o] & 0xFF;
            int b1 = in[o + 1] & 0xFF;
            int b2 = in[o + 2] & 0xFF;
            int b3 = in[o + 3] & 0xFF;
            int b4 = in[o + 4] & 0xFF;
            t[0] = b0;
            t[1] = b0 >> 5 | b1 << 3;
            t[2] = b1 >> 2;
            t[3] = b1 >> 7 | b2 << 1;
            t[4] = b2 >> 4 | b3 << 4;
            t[5] = b3 >> 1;
            t[6] = b3 >> 6 | b4 << 2;
            t[7] = b4 >> 3;
            for (int j = 0; j < 8; j++) {
                this.coeffs[8 * i + j] = (short) (((t[j] & 0x1F) * Q + 16) >> 5);
            }
        }
    }

    /**
     * Serializes the coefficients as packed 12-bit little-endian values: every pair of
     * coefficients occupies 3 bytes. Calls {@link #conditionalSubQ()} on this polynomial first.
     *
     * @return a {@value #POLY_BYTES}-byte encoding
     */
    public byte[] toBytes() {
        byte[] r = new byte[POLY_BYTES];
        this.conditionalSubQ();
        for (int i = 0; i < N / 2; i++) {
            short t0 = this.coeffs[2 * i];
            short t1 = this.coeffs[2 * i + 1];
            r[3 * i] = (byte) t0;
            r[3 * i + 1] = (byte) (t0 >> 8 | t1 << 4);
            r[3 * i + 2] = (byte) (t1 >> 4);
        }
        return r;
    }

    /**
     * Inverse of {@link #toBytes()}: reads 3 bytes per coefficient pair and masks each value to
     * 12 bits. No check that a decoded value is below q is performed here.
     *
     * @param inpBytes at least {@value #POLY_BYTES} bytes of packed coefficients
     */
    public void fromBytes(byte[] inpBytes) {
        for (int i = 0; i < N / 2; i++) {
            int b0 = inpBytes[3 * i] & 0xFF;
            int b1 = inpBytes[3 * i + 1] & 0xFF;
            this.coeffs[2 * i] = (short) ((b0 | b1 << 8) & 0xFFF);
            int b2 = inpBytes[3 * i + 2] & 0xFF;
            this.coeffs[2 * i + 1] = (short) ((b1 >> 4 | b2 << 4) & 0xFFF);
        }
    }

    /**
     * Decodes one message bit per coefficient (bit j of byte i comes from coefficient 8i+j).
     * A bit is 1 exactly when {@code q/4 < c < q - q/4}, i.e. the coefficient is closer to q/2
     * than to 0. The test is branch-free: it takes the sign bit of
     * {@code (lower - c) & (c - upper)}. Calls {@link #conditionalSubQ()} on this polynomial first.
     *
     * @return the decoded message, sized by {@code MLKEMEngine.getKyberIndCpaMsgBytes()}
     */
    public byte[] toMsg() {
        int lower = Q >>> 2;     // 832
        int upper = Q - lower;   // 2497
        byte[] outMsg = new byte[MLKEMEngine.getKyberIndCpaMsgBytes()];
        this.conditionalSubQ();
        for (int i = 0; i < MSG_BYTES; i++) {
            for (int j = 0; j < 8; j++) {
                int c = this.coeffs[8 * i + j];
                int bit = ((lower - c) & (c - upper)) >>> 31;
                outMsg[i] |= (byte) (bit << j);
            }
        }
        return outMsg;
    }

    /**
     * Encodes a {@value #MSG_BYTES}-byte message: bit j of byte i sets coefficient 8i+j to
     * (q+1)/2 when 1 and to 0 when 0, using a mask rather than a branch.
     *
     * @param msg the message
     * @throws RuntimeException if {@code msg.length != 32}
     */
    public void fromMsg(byte[] msg) {
        if (msg.length != MSG_BYTES) {
            throw new RuntimeException("KYBER_INDCPA_MSGBYTES must be equal to KYBER_N/8 bytes!");
        }
        for (int i = 0; i < MSG_BYTES; i++) {
            for (int j = 0; j < 8; j++) {
                int mask = -(((msg[i] & 0xFF) >> j) & 1); // 0 or -1 (all ones)
                this.coeffs[8 * i + j] = (short) (mask & HALF_Q_ROUNDED_UP);
            }
        }
    }

    /** Replaces every coefficient with {@code Reduce.conditionalSubQ} of itself. */
    public void conditionalSubQ() {
        for (int i = 0; i < N; i++) {
            this.coeffs[i] = Reduce.conditionalSubQ(this.coeffs[i]);
        }
    }

    /** Overwrites this polynomial with noise sampled with parameter eta1 from {@code (seed, nonce)}. */
    public void getEta1Noise(byte[] seed, byte nonce) {
        this.sampleNoise(seed, nonce, this.eta1);
    }

    /** Overwrites this polynomial with noise sampled with parameter eta2 from {@code (seed, nonce)}. */
    public void getEta2Noise(byte[] seed, byte nonce) {
        this.sampleNoise(seed, nonce, this.eta2);
    }

    // Expands (seed, nonce) with the PRF into N * eta / 4 bytes and feeds them to CBD.mlkemCBD
    // (centered-binomial sampling, judging by the name) with the given eta.
    private void sampleNoise(byte[] seed, byte nonce, int eta) {
        byte[] buf = new byte[N * eta / 4];
        this.symmetric.prf(buf, seed, nonce);
        CBD.mlkemCBD(this, buf, eta);
    }

    /**
     * Sets {@code this = b - this} coefficient-wise (no modular reduction; wraps as short).
     *
     * <p>Note the operand order: the ARGUMENT is the minuend. This is not {@code this - b}.
     *
     * @param b the polynomial subtracted from
     */
    public void polySubtract(Poly b) {
        for (int i = 0; i < N; i++) {
            this.coeffs[i] = (short) (b.coeffs[i] - this.coeffs[i]);
        }
    }

    /** Renders the coefficients as {@code [c0, c1, ..., c255]}. */
    @Override
    public String toString() {
        StringBuilder out = new StringBuilder("[");
        for (int i = 0; i < this.coeffs.length; i++) {
            if (i > 0) {
                out.append(", ");
            }
            out.append(this.coeffs[i]);
        }
        return out.append(']').toString();
    }
}

