import { Fingerprint } from '../../dataService';
import { normalize_values_average } from './utils';

/** Header expected at the start of every V2 embedding vector: `[~0xdead, 1]` (i.e. `[-57006, 1]`). */
const embeddingV2Magic = new Float32Array([~0xdead, 1]);

/**
 * Reports whether `vec` is laid out as a V2 embedding.
 *
 * A V2 embedding is the `embeddingV2Magic` header followed by a non-empty
 * payload whose length is a multiple of 128.
 *
 * @param vec - Candidate fingerprint vector.
 * @returns `true` when the length and header both match, otherwise `false`.
 */
export const checkIsV2Embedding = (vec: Fingerprint): boolean => {
    const headerLen = embeddingV2Magic.length;
    const payloadLen = vec.length - headerLen;

    // Payload must be non-empty and sized in whole 128-element units
    // (known model widths such as 256, 384 and 512 all satisfy this).
    if (payloadLen <= 0 || payloadLen % 128 !== 0) {
        return false;
    }

    // Every header slot must match the magic value exactly.
    return embeddingV2Magic.every((magic, idx) => vec[idx] === magic);
};

/**
 * Returns the payload of a V2 embedding (everything after the magic header),
 * copied into a new Float32Array.
 *
 * No validation is done here; callers are expected to check the vector with
 * `checkIsV2Embedding` first. A vector shorter than the header produces a
 * negative length and the Float32Array allocation throws a RangeError.
 *
 * @param vec - A V2-formatted fingerprint.
 * @returns A fresh Float32Array holding the payload values.
 */
export const unpackV2Embedding = (vec: Fingerprint): Float32Array => {
    const headerLen = embeddingV2Magic.length;
    const payload = new Float32Array(vec.length - headerLen);

    for (let dst = 0, src = headerLen; dst < payload.length; dst++, src++) {
        payload[dst] = vec[src];
    }

    return payload;
};

/**
 * Cosine similarity of two Float32Array vectors.
 *
 * The dot product and both squared magnitudes are accumulated in a single
 * pass. The main loop is unrolled by four, and the trailing `len % 4` elements
 * are handled by a second, scalar loop.
 *
 * @param a - First vector; its length decides how many elements are compared.
 * @param b - Second vector; should be the same length as `a`. Extra elements
 *   are ignored. If `b` is shorter, the missing elements read as `undefined`
 *   and the result is `NaN`.
 * @returns The cosine of the angle between `a` and `b`, or 0 when either
 *   vector has zero magnitude (including empty input).
 */
const cosineSimilarityF32 = (a: Float32Array, b: Float32Array): number => {
    const len = a.length;
    // The unrolled loop must stop at the last full group of four. Running it up
    // to `len` would read past the end of the arrays on a partial final group
    // (undefined -> NaN). It would also count those elements a second time in
    // the tail loop below.
    const unrolledEnd = len - (len % 4);

    let dotProduct = 0;
    let magnitudeA = 0;
    let magnitudeB = 0;

    for (let i = 0; i < unrolledEnd; i += 4) {
        const ai0 = a[i];
        const ai1 = a[i + 1];
        const ai2 = a[i + 2];
        const ai3 = a[i + 3];

        const bi0 = b[i];
        const bi1 = b[i + 1];
        const bi2 = b[i + 2];
        const bi3 = b[i + 3];

        dotProduct += ai0 * bi0 + ai1 * bi1 + ai2 * bi2 + ai3 * bi3;
        magnitudeA += ai0 * ai0 + ai1 * ai1 + ai2 * ai2 + ai3 * ai3;
        magnitudeB += bi0 * bi0 + bi1 * bi1 + bi2 * bi2 + bi3 * bi3;
    }

    // Tail: the 0-3 elements left over after the unrolled loop.
    for (let i = unrolledEnd; i < len; i++) {
        const ai = a[i];
        const bi = b[i];

        dotProduct += ai * bi;
        magnitudeA += ai * ai;
        magnitudeB += bi * bi;
    }

    const sqrtMagnitudeA = Math.sqrt(magnitudeA);
    const sqrtMagnitudeB = Math.sqrt(magnitudeB);

    // Similarity is undefined for a zero vector; report 0 instead of NaN.
    if (sqrtMagnitudeA === 0 || sqrtMagnitudeB === 0) {
        return 0;
    }

    return dotProduct / (sqrtMagnitudeA * sqrtMagnitudeB);
};

/**
 * Scores how close `referenceFingerprint` is to a set of positive examples
 * relative to a set of negative examples.
 *
 * Each group is reduced to one score: every fingerprint's cosine similarity
 * to the reference is computed, then the list is passed through
 * `normalize_values_average`. The negative score is subtracted from the
 * positive score, halved, and shifted by 0.5, so equal scores give 0.5.
 *
 * NOTE(review): the exact output range depends on what
 * `normalize_values_average` returns (including for an empty list) — confirm
 * in ./utils.
 *
 * @param referenceFingerprint - Embedding being scored.
 * @param positiveFingerprints - Embeddings that should pull the score up.
 * @param negativeFingerprints - Embeddings that should pull the score down.
 * @returns `0.5 + (positiveScore - negativeScore) / 2`.
 */
export function calculateV2EmbeddingSimilarity(
    referenceFingerprint: Float32Array,
    positiveFingerprints: Float32Array[] = [],
    negativeFingerprints: Float32Array[] = [],
): number {
    const scoreAgainstReference = (group: Float32Array[]) =>
        normalize_values_average(
            group.map((fp) => cosineSimilarityF32(fp, referenceFingerprint)),
        );

    const positiveScore = scoreAgainstReference(positiveFingerprints);
    const negativeScore = scoreAgainstReference(negativeFingerprints);

    return 0.5 + (positiveScore - negativeScore) / 2;
}
