package com.github.tjake.jlama.math;

/* loaded from: input_file:com/github/tjake/jlama/math/FloatConversions.class */
/**
 * Bit-level helpers for 16-bit floating-point formats. Every 16-bit value is passed around as a
 * raw bit pattern stored in a {@code short}.
 *
 * <ul>
 *   <li><b>bfloat16</b>: 1 sign, 8 exponent, 7 mantissa bits (the upper half of a float32).
 *   <li><b>IEEE 754 binary16</b> ("half"): 1 sign, 5 exponent, 10 mantissa bits.
 * </ul>
 *
 * <p>The half-precision add/sub/mul helpers drop the bits shifted out during alignment and
 * normalization (they truncate) rather than rounding to nearest.
 */
public class FloatConversions {
    /**
     * bfloat16 bit pattern 0x7F81 (all-ones exponent, non-zero mantissa, i.e. a NaN) returned by
     * {@link #float32ToBFloat16(float)} for every float32 NaN input.
     */
    // NOTE(review): mutable and package-visible; presumably intended as a constant.
    static short bFloat16NaN = 32641;

    /** Sign bit of a 16-bit float (0x8000). */
    private static final short SIGN_MASK = Short.MIN_VALUE;
    /** Exponent field of a binary16 value (0x7C00); all ones means Inf or NaN. */
    private static final short EXP_MASK = 0x7C00;
    /** Canonical binary16 NaN returned by the arithmetic helpers (0x7FFF); also the "clear sign" mask. */
    private static final short NAN_VALUE = Short.MAX_VALUE;

    /**
     * Widens a bfloat16 bit pattern to float32. This is exact: bfloat16 is the top 16 bits of a
     * float32, so the low 16 bits are simply filled with zeros.
     *
     * @param s bfloat16 bit pattern
     * @return the equivalent float
     */
    public static float bFloat16ToFloat32(short s) {
        return Float.intBitsToFloat(s << 16);
    }

    /**
     * Narrows a float32 to a bfloat16 bit pattern, rounding the mantissa to nearest, ties to even.
     *
     * @param f value to convert
     * @return bfloat16 bits; infinities keep their sign, every NaN becomes {@link #bFloat16NaN}
     */
    public static short float32ToBFloat16(float f) {
        int bits = Float.floatToRawIntBits(f);
        int sign = (bits >>> 16) & 0x8000;
        int exponent = (bits >>> 16) & 0x7F80;
        int mantissa = bits & 0x007FFFFF;
        if (exponent != 0x7F80) {
            // Finite input: round the 23-bit mantissa down to 7 bits. A carry out of the mantissa
            // ripples into the exponent field, bumping the exponent (or producing infinity).
            return (short) (sign | (exponent + round(mantissa, 16)));
        }
        // Exponent all ones: infinity keeps its bits, NaN collapses to the shared NaN pattern.
        return mantissa != 0 ? bFloat16NaN : (short) (bits >>> 16);
    }

    /**
     * Shifts {@code value} right by {@code shift} bits, rounding to nearest with ties to even.
     * Requires {@code shift >= 1}.
     *
     * @param value non-negative value to shift
     * @param shift number of low bits to discard
     * @return the rounded quotient {@code value / 2^shift}
     */
    static int round(int value, int shift) {
        int truncated = value >> shift;
        // Positive: discarded bits exceed one half; negative: below one half; zero: exact tie.
        int excess = (value & ((1 << shift) - 1)) - (1 << (shift - 1));
        if (excess > 0) {
            return truncated + 1;
        }
        if (excess < 0) {
            return truncated;
        }
        return (truncated & 1) != 0 ? truncated + 1 : truncated;
    }

    /**
     * Converts an IEEE 754 binary16 bit pattern to float32 without branching on the exponent
     * (same bit trick as {@code fp16_ieee_to_fp32_bits} in the FP16 library). Subnormals are
     * renormalized with a leading-zero count, Inf/NaN map to float32 Inf/NaN, and signed zeros
     * keep their sign.
     *
     * @param s binary16 bit pattern
     * @return the equivalent float
     */
    public static float float16ToFloat32Alt(short s) {
        int word = s << 16;
        int sign = word & 0x80000000;
        int nonSign = word & 0x7FFFFFFF;
        // Extra left shift needed so a subnormal's leading 1 lands where a normal's implicit bit is.
        int renormShift = Integer.numberOfLeadingZeros(nonSign);
        renormShift = renormShift > 5 ? renormShift - 5 : 0;
        // All-ones float32 exponent when the half exponent is all ones (Inf/NaN), else 0.
        int infNanMask = ((nonSign + 0x04000000) >> 8) & 0x7F800000;
        // -1 when the magnitude is zero (forces a signed zero result), else 0.
        int zeroMask = (nonSign - 1) >> 31;
        // Re-bias the exponent from 15 to 127 (0x70 = 112), minus the renormalization shift.
        int magnitude = ((nonSign << renormShift >>> 3) + ((0x70 - renormShift) << 23)) | infNanMask;
        return Float.intBitsToFloat(sign | (magnitude & ~zeroMask));
    }

    /** True for +0 and -0. */
    private static boolean IS_ZERO(short s) {
        return (s & NAN_VALUE) == 0;
    }

    /** True when the exponent is all ones, i.e. the value is Inf or NaN. */
    private static boolean IS_INVALID(short s) {
        return (s & EXP_MASK) == EXP_MASK;
    }

    /** True for any NaN (all-ones exponent, non-zero mantissa). */
    private static boolean IS_NAN(short s) {
        return (s & NAN_VALUE) > EXP_MASK;
    }

    /** True for +Inf and -Inf. */
    private static boolean IS_INF(short s) {
        return (s & NAN_VALUE) == EXP_MASK;
    }

    /** 10-bit mantissa with the implicit leading 1 (0x400) added when the exponent is non-zero. */
    private static short MANTISSA(short s) {
        return (short) ((s & 1023) | ((s & EXP_MASK) == 0 ? 0 : 1024));
    }

    /** Biased 5-bit exponent field. */
    private static short EXPONENT(short s) {
        return (short) ((s & EXP_MASK) >> 10);
    }

    /** Infinity carrying the sign bit of {@code s}. */
    private static short SIGNED_INF_VALUE(short s) {
        return (short) ((s & SIGN_MASK) | EXP_MASK);
    }

    /**
     * Computes {@code a - b} on binary16 bit patterns (result truncated, not rounded).
     *
     * @param a minuend bits
     * @param b subtrahend bits
     * @return difference bits; {@code 0x7FFF} for NaN operands or Inf - Inf
     */
    public static short subIeeeFloat16(short a, short b) {
        if (((a ^ b) & 0x8000) != 0) {
            // Opposite signs: a - b == a + (-b).
            return addIeeeFloat16(a, (short) (b ^ 0x8000));
        }
        int sign = a & 0x8000;
        // Shift out the sign bit (exponent now in bits 11..15, mantissa in bits 1..10) and keep
        // the values in the unsigned 16-bit range so comparisons with 0xF800 work.
        int x = (a << 1) & 0xFFFF;
        int y = (b << 1) & 0xFFFF;
        if (x < y) {
            // Make x the larger magnitude; the result then takes the opposite sign.
            int tmp = x;
            x = y;
            y = tmp;
            sign ^= 0x8000;
        }
        if (x >= 0xF800 || y >= 0xF800) {
            if (x > 0xF800 || y > 0xF800 || x == y) {
                return NAN_VALUE; // NaN operand, or Inf - Inf
            }
            // Since x >= y, x is the infinity and y is finite.
            return (short) (sign | 0x7C00);
        }
        int xExp = x & 0xF800;
        int yExp = y & 0xF800;
        int alignedY; // y's significand aligned to x's exponent
        if (xExp != yExp) {
            int shift = (xExp - yExp) >> 11;
            // Subnormal y has no implicit bit and an effective exponent of 1, hence shift - 1.
            alignedY = yExp != 0 ? ((y & 0x7FF) | 0x800) >> shift : y >> (shift - 1);
        } else {
            if (yExp == 0) {
                // Both subnormal (or zero): plain integer subtraction is exact.
                int diff = (x - y) >> 1;
                return (short) (diff == 0 ? 0 : diff | sign);
            }
            alignedY = (y & 0x7FF) | 0x800;
        }
        int result = x - alignedY;
        if ((result & 0xF800) == xExp) {
            // No borrow out of the mantissa: exponent unchanged.
            return (short) ((result >> 1) | sign);
        }
        // Borrow: recompute on the explicit significand, then renormalize.
        int mantissa = ((x & 0x7FF) | 0x800) - alignedY;
        if (mantissa == 0) {
            return (short) 0;
        }
        int exp = xExp;
        while (exp != 0 && (mantissa & 0x800) == 0) {
            exp -= 0x800;
            if (exp != 0) {
                mantissa <<= 1; // at exponent 0 the value is subnormal and needs no shift
            }
        }
        return (short) ((((mantissa & 0x7FF) | exp) >> 1) | sign);
    }

    /**
     * Computes {@code a + b} on binary16 bit patterns (result truncated, not rounded).
     *
     * @param a first addend bits
     * @param b second addend bits
     * @return sum bits; {@code 0x7FFF} for NaN operands or Inf + (-Inf), signed Inf on overflow
     */
    public static short addIeeeFloat16(short a, short b) {
        if (((a ^ b) & 0x8000) != 0) {
            // Opposite signs: a + b == a - (-b).
            return subIeeeFloat16(a, (short) (b ^ 0x8000));
        }
        int sign = a & 0x8000;
        int x = a & 0x7FFF;
        int y = b & 0x7FFF;
        if (x < y) {
            // Make x the larger magnitude.
            int tmp = x;
            x = y;
            y = tmp;
        }
        if (x >= 0x7C00 || y >= 0x7C00) {
            if (x > 0x7C00 || y > 0x7C00) {
                return NAN_VALUE;
            }
            return (short) (sign | 0x7C00);
        }
        int xExp = x & 0x7C00;
        int yExp = y & 0x7C00;
        int alignedY; // y's significand aligned to x's exponent
        if (xExp != yExp) {
            int shift = (xExp - yExp) >> 10;
            // Subnormal y has no implicit bit and an effective exponent of 1, hence shift - 1.
            alignedY = yExp != 0 ? ((y & 0x3FF) | 0x400) >> shift : y >> (shift - 1);
        } else {
            if (yExp == 0) {
                // Both subnormal (or zero): integer addition carries into the exponent naturally.
                return (short) ((x + y) | sign);
            }
            alignedY = (y & 0x3FF) | 0x400;
        }
        int sum = x + alignedY;
        if ((sum & 0x7C00) != xExp) {
            // Mantissa overflowed: bump the exponent and halve the significand sum.
            sum = (xExp + 0x400) | (0x3FF & ((((x & 0x3FF) | 0x400) + alignedY) >> 1));
        }
        return sum >= 0x7C00 ? (short) (sign | 0x7C00) : (short) (sum | sign);
    }

    /**
     * Computes {@code a * b} on binary16 bit patterns (result truncated, not rounded).
     *
     * @param a first factor bits
     * @param b second factor bits
     * @return product bits; {@code 0x7FFF} for NaN operands or Inf * 0, signed zero/Inf otherwise
     */
    public static short mulIeeeFloat16(short a, short b) {
        int sign = (a ^ b) & 0x8000;
        if (IS_INVALID(a) || IS_INVALID(b)) {
            // NaN * x and Inf * 0 are NaN; any other product with an infinity is a signed infinity.
            if (IS_NAN(a) || IS_NAN(b) || IS_ZERO(a) || IS_ZERO(b)) {
                return NAN_VALUE;
            }
            return (short) (sign | EXP_MASK);
        }
        if (IS_ZERO(a) || IS_ZERO(b)) {
            return (short) sign; // signed zero, e.g. -1 * 0 == -0
        }
        // Product of two 11-bit significands: at most 22 bits.
        int product = MANTISSA(a) * MANTISSA(b);
        // Subnormals behave as if their biased exponent were 1.
        int exp = Math.max(EXPONENT(a), 1) + Math.max(EXPONENT(b), 1) - 15;
        if ((product & 0x200000) != 0) {
            product >>= 11;
            exp++;
        } else if ((product & 0x100000) != 0) {
            product >>= 10;
        } else {
            // A subnormal factor: shift until the leading bit is at position 10 (or below).
            exp -= 10;
            while (product >= 0x800) {
                product >>= 1;
                exp++;
            }
        }
        if (exp <= 0) {
            // Underflow to subnormal; exp >= -23 here, so the shift stays below 32.
            product >>= 1 - exp;
            exp = 0;
        } else if (exp >= 31) {
            return SIGNED_INF_VALUE((short) sign);
        }
        return (short) (sign | (exp << 10) | (product & 0x3FF));
    }
}
