/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.benchmark.jmh;

import java.io.Closeable;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Warmup;

@BenchmarkMode(value={Mode.Throughput})
@OutputTimeUnit(value=TimeUnit.MICROSECONDS)
@State(value=Scope.Benchmark)
@Warmup(iterations=4, time=1)
@Measurement(iterations=5, time=1)
@Fork(value=3, jvmArgsAppend={"-Xmx2g", "-Xms2g", "-XX:+AlwaysPreTouch"})
public class VectorScorerBenchmark {
    private static final float EPSILON = 1.0E-4f;
    @Param(value={"1", "128", "207", "256", "300", "512", "702", "1024"})
    public int size;
    @Param(value={"0", "1", "4", "64"})
    public int padBytes;
    Directory dir;
    IndexInput bytesIn;
    IndexInput floatsIn;
    KnnVectorValues byteVectorValues;
    KnnVectorValues floatVectorValues;
    byte[] vec1;
    byte[] vec2;
    float[] floatsA;
    float[] floatsB;
    float expectedBytes;
    float expectedFloats;
    UpdateableRandomVectorScorer byteScorer;
    UpdateableRandomVectorScorer floatScorer;

    @Setup(value=Level.Iteration)
    public void init() throws IOException {
        ThreadLocalRandom random = ThreadLocalRandom.current();
        this.vec1 = new byte[this.size];
        this.vec2 = new byte[this.size];
        random.nextBytes(this.vec1);
        random.nextBytes(this.vec2);
        this.expectedBytes = VectorSimilarityFunction.DOT_PRODUCT.compare(this.vec1, this.vec2);
        this.floatsA = new float[this.size];
        this.floatsB = new float[this.size];
        for (int i = 0; i < this.size; ++i) {
            this.floatsA[i] = ((Random)random).nextFloat();
            this.floatsB[i] = ((Random)random).nextFloat();
        }
        this.expectedFloats = VectorSimilarityFunction.DOT_PRODUCT.compare(this.floatsA, this.floatsB);
        this.dir = new MMapDirectory(Files.createTempDirectory("VectorScorerBenchmark", new FileAttribute[0]));
        try (IndexOutput out = this.dir.createOutput("byteVector.data", IOContext.DEFAULT);){
            out.writeBytes(new byte[this.padBytes], 0, this.padBytes);
            out.writeBytes(this.vec1, 0, this.vec1.length);
            out.writeBytes(this.vec2, 0, this.vec2.length);
        }
        out = this.dir.createOutput("floatVector.data", IOContext.DEFAULT);
        try {
            out.writeBytes(new byte[this.padBytes], 0, this.padBytes);
            byte[] buffer = new byte[this.size * 4];
            ByteBuffer.wrap(buffer).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().put(this.floatsA);
            out.writeBytes(buffer, 0, buffer.length);
            ByteBuffer.wrap(buffer).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().put(this.floatsB);
            out.writeBytes(buffer, 0, buffer.length);
        }
        finally {
            if (out != null) {
                out.close();
            }
        }
        this.bytesIn = this.dir.openInput("byteVector.data", IOContext.DEFAULT);
        this.byteVectorValues = this.byteVectorValues(VectorSimilarityFunction.DOT_PRODUCT);
        this.byteScorer = FlatVectorScorerUtil.getLucene99FlatVectorsScorer().getRandomVectorScorerSupplier(VectorSimilarityFunction.DOT_PRODUCT, this.byteVectorValues).scorer();
        this.byteScorer.setScoringOrdinal(0);
        this.floatsIn = this.dir.openInput("floatVector.data", IOContext.DEFAULT);
        this.floatVectorValues = this.floatVectorValues(VectorSimilarityFunction.DOT_PRODUCT);
        this.floatScorer = FlatVectorScorerUtil.getLucene99FlatVectorsScorer().getRandomVectorScorerSupplier(VectorSimilarityFunction.DOT_PRODUCT, this.floatVectorValues).scorer();
        this.floatScorer.setScoringOrdinal(0);
    }

    @TearDown
    public void teardown() throws IOException {
        IOUtils.close((Closeable[])new Closeable[]{this.dir, this.bytesIn});
    }

    @Benchmark
    public float binaryDotProductDefault() throws IOException {
        float result = this.byteScorer.score(1);
        if (Math.abs(result - this.expectedBytes) > 1.0E-4f) {
            throw new RuntimeException("Expected " + result + " but got " + this.expectedBytes);
        }
        return result;
    }

    @Benchmark
    @Fork(jvmArgsPrepend={"--add-modules=jdk.incubator.vector"})
    public float binaryDotProductMemSeg() throws IOException {
        float result = this.byteScorer.score(1);
        if (Math.abs(result - this.expectedBytes) > 1.0E-4f) {
            throw new RuntimeException("Expected " + result + " but got " + this.expectedBytes);
        }
        return result;
    }

    @Benchmark
    public float floatDotProductDefault() throws IOException {
        float result = this.floatScorer.score(1);
        if (Math.abs(result - this.expectedFloats) > 1.0E-4f) {
            throw new RuntimeException("Expected " + result + " but got " + this.expectedFloats);
        }
        return result;
    }

    @Benchmark
    @Fork(jvmArgsPrepend={"--add-modules=jdk.incubator.vector"})
    public float floatDotProductMemSeg() throws IOException {
        float result = this.floatScorer.score(1);
        if (Math.abs(result - this.expectedFloats) > 1.0E-4f) {
            throw new RuntimeException("Expected " + result + " but got " + this.expectedFloats);
        }
        return result;
    }

    KnnVectorValues byteVectorValues(VectorSimilarityFunction sim) throws IOException {
        return new OffHeapByteVectorValues.DenseOffHeapVectorValues(this.size, 2, this.bytesIn.slice("test", (long)this.padBytes, (long)this.size * 2L), this.size, (FlatVectorsScorer)new ThrowingFlatVectorScorer(), sim);
    }

    KnnVectorValues floatVectorValues(VectorSimilarityFunction sim) throws IOException {
        int byteSize = this.size * 4;
        return new OffHeapFloatVectorValues.DenseOffHeapVectorValues(this.size, 2, this.floatsIn.slice("test", (long)this.padBytes, (long)byteSize * 2L), byteSize, (FlatVectorsScorer)new ThrowingFlatVectorScorer(), sim);
    }

    static final class ThrowingFlatVectorScorer
    implements FlatVectorsScorer {
        ThrowingFlatVectorScorer() {
        }

        public RandomVectorScorerSupplier getRandomVectorScorerSupplier(VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) {
            throw new UnsupportedOperationException();
        }

        public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target) {
            throw new UnsupportedOperationException();
        }

        public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target) {
            throw new UnsupportedOperationException();
        }
    }
}

