package hex.word2vec;

import hex.Model;
import hex.ModelMetrics;
import hex.word2vec.Word2Vec;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Random;
import java.util.concurrent.atomic.AtomicLong;
import water.DKV;
import water.Futures;
import water.H2O;
import water.Iced;
import water.Key;
import water.fvec.AppendableVec;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.nbhm.NonBlockingHashMap;
import water.parser.ValueString;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.RandomUtils;

/**
 * Word2Vec model: holds the training state ({@link Word2VecModelInfo}) and, once
 * {@link #buildModelOutput()} has run, an output {@link Frame} whose first column ("Word") is the
 * vocabulary and whose remaining {@code _vecSize} columns ("V0", "V1", ...) are the word vectors.
 * Provides word-to-vector lookup ({@link #transform(String)}) and cosine-similarity synonym search.
 */
public class Word2VecModel extends Model<Word2VecModel, Word2VecParameters, Word2VecOutput> {
    private volatile Word2VecModelInfo _modelInfo;
    /** Key of the output frame written by {@link #buildModelOutput()}; {@code null} until then. */
    private Key<Frame> _w2vKey;

    /**
     * Mutable training state: the two weight arrays plus the structure used by the selected
     * normalization model (a Huffman tree for {@code HSM}, otherwise a unigram sampling table).
     */
    public static class Word2VecModelInfo extends Iced {
        static final int UNIGRAM_TABLE_SIZE = 10000000;
        static final float UNIGRAM_POWER = 0.75f;
        static final int MAX_CODE_LENGTH = 40;
        /** The decayed learning rate never drops below this fraction of the initial rate. */
        private static final float MIN_LEARNING_RATE_FRACTION = 1.0E-4f;
        /** Count value given to not-yet-built inner nodes so the leaf is always preferred. */
        private static final long UNBUILT_NODE_COUNT = 1000000000000000L;

        /** Number of rows across all string columns of the training frame. */
        long _trainFrameSize;
        /** Number of rows in the vocabulary frame. */
        int _vocabSize;
        float _curLearningRate;
        /** Word vectors, row-major: word {@code w}, dimension {@code d} at {@code w * vecSize + d}. */
        float[] _syn0;
        /** Second weight array, same size as {@code _syn0}, zero-initialized (presumably output-layer weights). */
        float[] _syn1;
        /** Unigram sampling table (built only when the norm model is not HSM). */
        int[] _uniTable = null;
        /** Per-word Huffman codes, root-to-leaf (built only for HSM). */
        int[][] _HBWTCode = null;
        /** Per-word inner-node indices (offset by -vocabSize), root first (built only for HSM). */
        int[][] _HBWTPoint = null;
        private Word2VecParameters _parameters;

        // Word counters are static, so they are per-JVM and not serialized with this Iced object.
        // AtomicLong makes updates thread-safe across *all* instances (instance-level `synchronized`
        // locked a different monitor per instance) and avoids truncating long deltas to int.
        private static final AtomicLong _localWordCnt = new AtomicLong();
        private static final AtomicLong _globalWordCnt = new AtomicLong();

        public final Word2VecParameters getParams() {
            return _parameters;
        }

        /** No-arg constructor required for Iced deserialization. */
        public Word2VecModelInfo() {
        }

        /**
         * Initializes training state from the parameters: builds the vocabulary if needed, randomly
         * initializes {@code _syn0}, zero-initializes {@code _syn1}, and builds the HSM tree or the
         * unigram table depending on {@code _normModel}.
         *
         * @throws IllegalArgumentException if {@code vecSize * vocabSize} does not fit in a Java array
         */
        public Word2VecModelInfo(Word2VecParameters word2VecParameters) {
            _parameters = word2VecParameters;
            if (_parameters._vocabKey == null) {
                _parameters._vocabKey = ((WordCountTask) new WordCountTask(_parameters._minWordFreq).doAll(_parameters.train()))._wordCountKey;
            }
            _vocabSize = (int) _parameters._vocabKey.get().numRows();
            _trainFrameSize = getTrainFrameSize(_parameters.train());

            final int vecSize = _parameters._vecSize;
            final long weightCount = (long) vecSize * _vocabSize;
            if (weightCount > Integer.MAX_VALUE) {
                throw new IllegalArgumentException("Word2Vec model too large: vecSize (" + vecSize + ") * vocabSize ("
                        + _vocabSize + ") = " + weightCount + " exceeds the maximum array length.");
            }
            final int n = (int) weightCount;
            Random rng = RandomUtils.getRNG(new long[]{912559L, 55930L});
            _syn1 = new float[n];
            _syn0 = new float[n];
            for (int i = 0; i < n; i++) {
                _syn0[i] = (rng.nextFloat() - 0.5f) / vecSize;
            }
            if (_parameters._normModel == Word2Vec.NormModel.HSM) {
                buildHuffmanBinaryWordTree();
            } else {
                buildUnigramTable();
            }
        }

        public void addLocallyProcessed(long j) {
            _localWordCnt.addAndGet(j);
        }

        public long getLocallyProcessed() {
            return _localWordCnt.get();
        }

        public void setLocallyProcessed(int i) {
            _localWordCnt.set(i);
        }

        public void addGloballyProcessed(long j) {
            _globalWordCnt.addAndGet(j);
        }

        public long getGloballyProcessed() {
            return _globalWordCnt.get();
        }

        public long getTotalProcessed() {
            return _globalWordCnt.get() + _localWordCnt.get();
        }

        /**
         * Adds another instance's weights element-wise into this one and adds its local word count.
         * NOTE(review): the counters are static (per-JVM), so {@code other.getLocallyProcessed()}
         * reads the same counter as this instance; this call therefore doubles the local count.
         * Confirm this is what the caller expects.
         */
        public void add(Word2VecModelInfo word2VecModelInfo) {
            ArrayUtils.add(_syn0, word2VecModelInfo._syn0);
            ArrayUtils.add(_syn1, word2VecModelInfo._syn1);
            addLocallyProcessed(word2VecModelInfo.getLocallyProcessed());
        }

        /** Divides both weight arrays by {@code f}; a no-op when {@code f <= 1}. */
        public void div(float f) {
            if (f > 1.0f) {
                ArrayUtils.div(_syn0, f);
                ArrayUtils.div(_syn1, f);
            }
        }

        /**
         * Linearly decays the learning rate with the fraction of words processed out of
         * {@code epochs * trainFrameSize}. The rate is floored at
         * {@code MIN_LEARNING_RATE_FRACTION * initLearningRate}.
         */
        public void updateLearningRate() {
            float progress = ((float) getTotalProcessed()) / ((float) ((_parameters._epochs * _trainFrameSize) + 1));
            float floor = _parameters._initLearningRate * MIN_LEARNING_RATE_FRACTION;
            _curLearningRate = Math.max(_parameters._initLearningRate * (1.0f - progress), floor);
        }

        /**
         * Fills {@code _uniTable} so that word {@code w} occupies a share of slots proportional to
         * {@code count(w)^UNIGRAM_POWER}. This follows the reference word2vec
         * {@code InitUnigramTable}: the sum is accumulated in double, and the index is clamped at the
         * last word instead of wrapping.
         */
        private void buildUnigramTable() {
            _uniTable = new int[UNIGRAM_TABLE_SIZE];
            if (_vocabSize == 0) {
                return;
            }
            Vec counts = _parameters._vocabKey.get().vec(1);
            double totalPow = 0;
            for (int w = 0; w < _vocabSize; w++) {
                totalPow += Math.pow(counts.at8(w), UNIGRAM_POWER);
            }
            int word = 0;
            // cumulative share of the table covered by words 0..word inclusive
            double cumulative = Math.pow(counts.at8(0), UNIGRAM_POWER) / totalPow;
            for (int slot = 0; slot < UNIGRAM_TABLE_SIZE; slot++) {
                _uniTable[slot] = word;
                if ((double) slot / UNIGRAM_TABLE_SIZE > cumulative && word < _vocabSize - 1) {
                    word++;
                    cumulative += Math.pow(counts.at8(word), UNIGRAM_POWER) / totalPow;
                }
            }
        }

        /**
         * Builds the Huffman tree over word counts and fills {@code _HBWTCode} / {@code _HBWTPoint},
         * following the reference word2vec {@code CreateBinaryTree}.
         *
         * <p>Leaves are nodes {@code 0..V-1}; inner nodes are {@code V..2V-2}, and the root is
         * {@code 2V-2}. Each word's code/point path runs from the root's child down to the leaf and
         * excludes the root itself. The two-pointer merge assumes the vocab frame is sorted by
         * descending count (TODO confirm against WordCountTask).
         *
         * @throws IllegalStateException if a code would exceed {@code MAX_CODE_LENGTH} bits
         */
        private void buildHuffmanBinaryWordTree() {
            Vec counts = _parameters._vocabKey.get().vec(1);
            assert _vocabSize == counts.length();
            _HBWTCode = new int[_vocabSize][];
            _HBWTPoint = new int[_vocabSize][];
            if (_vocabSize == 0) {
                return;
            }
            final int nodeCount = 2 * _vocabSize - 1;
            final int root = nodeCount - 1;
            long[] count = new long[nodeCount];
            int[] binary = new int[nodeCount];
            int[] parent = new int[nodeCount];
            for (int w = 0; w < _vocabSize; w++) {
                count[w] = counts.at8(w);
            }
            for (int node = _vocabSize; node < nodeCount; node++) {
                count[node] = UNBUILT_NODE_COUNT;
            }

            // Merge the two smallest available nodes V-1 times. Leaves are consumed from the end,
            // and built inner nodes in creation order.
            int leafPos = _vocabSize - 1;
            int innerPos = _vocabSize;
            for (int a = 0; a < _vocabSize - 1; a++) {
                int min1;
                if (leafPos >= 0 && count[leafPos] < count[innerPos]) {
                    min1 = leafPos--;
                } else {
                    min1 = innerPos++;
                }
                int min2;
                if (leafPos >= 0 && count[leafPos] < count[innerPos]) {
                    min2 = leafPos--;
                } else {
                    min2 = innerPos++;
                }
                int node = _vocabSize + a;
                count[node] = count[min1] + count[min2];
                parent[min1] = node;
                parent[min2] = node;
                binary[min2] = 1;
            }

            int[] pathCode = new int[MAX_CODE_LENGTH];
            int[] pathNode = new int[MAX_CODE_LENGTH];
            for (int w = 0; w < _vocabSize; w++) {
                int len = 0;
                int node = w;
                // Walk leaf -> root, recording every node below the root.
                do {
                    if (len == MAX_CODE_LENGTH) {
                        throw new IllegalStateException("Huffman code for word " + w
                                + " exceeds MAX_CODE_LENGTH (" + MAX_CODE_LENGTH + ")");
                    }
                    pathCode[len] = binary[node];
                    pathNode[len] = node;
                    len++;
                    node = parent[node];
                } while (node != root);

                int[] code = new int[len];
                int[] point = new int[len + 1];
                point[0] = _vocabSize - 2; // root's inner-node index
                for (int k = 0; k < len; k++) {
                    code[len - k - 1] = pathCode[k];
                    point[len - k] = pathNode[k] - _vocabSize;
                }
                _HBWTCode[w] = code;
                _HBWTPoint[w] = point;
            }
        }

        /** Sums the lengths of all string columns in {@code frame}. */
        private long getTrainFrameSize(Frame frame) {
            long total = 0;
            for (Vec vec : frame.vecs()) {
                if (vec.isString()) {
                    total += vec.length();
                }
            }
            return total;
        }
    }

    /** Model output; the training settings fields are populated elsewhere. */
    public static class Word2VecOutput extends Model.Output {
        public Word2Vec.WordModel _wordModel;
        public Word2Vec.NormModel _normModel;
        public int _minWordFreq;
        public int _vecSize;
        public int _windowSize;
        public int _epochs;
        public int _negSampleCnt;
        public float _initLearningRate;
        public float _sentSampleRate;

        public Word2VecOutput(Word2Vec word2Vec) {
            super(word2Vec);
        }

        /** Word2Vec is unsupervised and fits no standard category. */
        public Model.ModelCategory getModelCategory() {
            return Model.ModelCategory.Unknown;
        }
    }

    /** User-facing Word2Vec training parameters with their defaults. */
    public static class Word2VecParameters extends Model.Parameters {
        static final int MAX_VEC_SIZE = 10000;
        /** Vocabulary frame (column 0: word, column 1: count); built on demand when null. */
        public Key<Frame> _vocabKey;
        public Word2Vec.WordModel _wordModel = Word2Vec.WordModel.SkipGram;
        public Word2Vec.NormModel _normModel = Word2Vec.NormModel.HSM;
        public int _minWordFreq = 5;
        public int _vecSize = 100;
        public int _windowSize = 5;
        public int _epochs = 5;
        public int _negSampleCnt = 5;
        public float _initLearningRate = 0.05f;
        public float _sentSampleRate = 0.001f;
    }

    public void setModelInfo(Word2VecModelInfo word2VecModelInfo) {
        _modelInfo = word2VecModelInfo;
    }

    public final Word2VecModelInfo getModelInfo() {
        return _modelInfo;
    }

    /** Creates the model and eagerly initializes its training state from the parameters. */
    public Word2VecModel(Key key, Word2VecParameters word2VecParameters, Word2VecOutput word2VecOutput) {
        super(key, word2VecParameters, word2VecOutput);
        _modelInfo = new Word2VecModelInfo(word2VecParameters);
        assert Arrays.equals(_key._kb, key._kb);
    }

    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] strArr) {
        throw H2O.unimpl("No Model Metrics for Word2Vec.");
    }

    public double[] score0(Chunk[] chunkArr, int i, double[] dArr, double[] dArr2) {
        throw H2O.unimpl();
    }

    protected double[] score0(double[] dArr, double[] dArr2) {
        throw H2O.unimpl();
    }

    /**
     * Returns the vector for {@code str}, or {@code null} (with a warning) if it is not in the vocabulary.
     *
     * @throws IllegalStateException if {@link #buildModelOutput()} has not been run
     */
    public float[] transform(String str) {
        return transform(new ValueString(str), buildVocabHashMap(), outputFrame().vecs());
    }

    /** Looks up {@code word} in {@code vocab}; returns its vector from columns 1.. of {@code vecs}, or null. */
    private float[] transform(ValueString word, NonBlockingHashMap<ValueString, Integer> vocab, Vec[] vecs) {
        Integer row = vocab.get(word);
        if (row == null) {
            Log.warn("Target word " + word + " isn't in vocabulary.");
            return null;
        }
        int vecSize = vecs.length - 1;
        float[] result = new float[vecSize];
        for (int d = 0; d < vecSize; d++) {
            result[d] = (float) vecs[d + 1].at(row);
        }
        return result;
    }

    /**
     * Returns up to {@code i} words most cosine-similar to {@code str}, iterated in descending score order.
     *
     * @return word to similarity map, or {@code null} if {@code i <= 0} or {@code str} is not in the vocabulary
     * @throws IllegalStateException if {@link #buildModelOutput()} has not been run
     */
    public HashMap<String, Float> findSynonyms(String str, int i) {
        if (i <= 0) {
            Log.err("Synonym count must be greater than 0.");
            return null;
        }
        NonBlockingHashMap<ValueString, Integer> vocab = buildVocabHashMap();
        Vec[] vecs = outputFrame().vecs();
        float[] target = transform(new ValueString(str), vocab, vecs);
        if (target == null) {
            return null; // transform already logged the missing word
        }
        return findSynonyms(target, i, vecs);
    }

    /**
     * Returns up to {@code i} words most cosine-similar to vector {@code fArr}.
     * (Previously returned {@code void} and discarded the result.)
     *
     * @return word to similarity map, or {@code null} if {@code i <= 0}, {@code fArr} is null, or its length is wrong
     */
    public HashMap<String, Float> findSynonyms(float[] fArr, int i) {
        if (i <= 0) {
            Log.err("Synonym count must be greater than 0.");
            return null;
        }
        if (fArr == null) {
            Log.warn("Target vector is null.");
            return null;
        }
        return findSynonyms(fArr, i, outputFrame().vecs());
    }

    /**
     * Keeps the top-{@code cnt} rows by cosine similarity to {@code target}. Rows are considered only
     * when the score is in {@code (0, 0.999999)}: this excludes the target itself, non-positive
     * matches, and NaN from zero vectors. Ties keep the earlier row.
     */
    private HashMap<String, Float> findSynonyms(float[] target, int cnt, Vec[] vecs) {
        int vecSize = vecs.length - 1;
        if (target.length != vecSize) {
            Log.warn("Target vector length differs from the vocab's vector length.");
            return null;
        }
        int vocabSize = (int) vecs[0].length();
        int[] matches = new int[cnt];
        float[] scores = new float[cnt];
        int found = 0;
        float[] candidate = new float[vecSize];
        for (int row = 0; row < vocabSize; row++) {
            for (int d = 0; d < vecSize; d++) {
                candidate[d] = (float) vecs[d + 1].at(row);
            }
            float score = cosineSimilarity(target, candidate);
            if (!(score > 0f && score < 0.999999d)) {
                continue; // also rejects NaN
            }
            int pos = found;
            while (pos > 0 && score > scores[pos - 1]) {
                pos--;
            }
            if (pos >= cnt) {
                continue;
            }
            int last = Math.min(found, cnt - 1); // shift, dropping the tail when full
            System.arraycopy(scores, pos, scores, pos + 1, last - pos);
            System.arraycopy(matches, pos, matches, pos + 1, last - pos);
            scores[pos] = score;
            matches[pos] = row;
            if (found < cnt) {
                found++;
            }
        }
        HashMap<String, Float> result = new LinkedHashMap<>(); // preserves rank order
        for (int k = 0; k < found; k++) {
            result.put(vecs[0].atStr(new ValueString(), matches[k]).toString(), scores[k]);
        }
        return result;
    }

    /**
     * Cosine similarity of the first {@code fArr.length} components of both vectors, accumulated in
     * double. Returns NaN when either vector has zero norm.
     */
    public float cosineSimilarity(float[] fArr, float[] fArr2) {
        double dot = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        for (int i = 0; i < fArr.length; i++) {
            double a = fArr[i];
            double b = fArr2[i];
            dot += a * b;
            normA += a * a;
            normB += b * b;
        }
        return (float) (dot / (Math.sqrt(normA) * Math.sqrt(normB)));
    }

    /** Maps each word in column 0 of the output frame to its row index. */
    private NonBlockingHashMap<ValueString, Integer> buildVocabHashMap() {
        Frame frame = outputFrame();
        Vec words = frame.vec(0);
        int numRows = (int) frame.numRows();
        NonBlockingHashMap<ValueString, Integer> vocab = new NonBlockingHashMap<>(numRows);
        for (int row = 0; row < numRows; row++) {
            vocab.put(words.atStr(new ValueString(), row), row);
        }
        return vocab;
    }

    /** Returns the output frame, failing clearly if it has not been built or was removed. */
    private Frame outputFrame() {
        if (_w2vKey == null) {
            throw new IllegalStateException("Word2Vec output frame has not been built; call buildModelOutput() first.");
        }
        Frame frame = _w2vKey.get();
        if (frame == null) {
            throw new IllegalStateException("Word2Vec output frame " + _w2vKey + " is not in the DKV.");
        }
        return frame;
    }

    /**
     * Materializes {@code _syn0} as a frame: the vocab's "Word" column followed by one single-chunk
     * column "V<d>" per dimension. The frame is stored in the DKV under key "w2v".
     * NOTE(review): the fixed key name means two models overwrite each other's output frame.
     */
    public void buildModelOutput() {
        final int vecSize = ((Word2VecParameters) _parms)._vecSize;
        Futures futures = new Futures();
        String[] names = new String[vecSize];
        Vec[] vecs = new Vec[vecSize];
        Key[] keys = Vec.VectorGroup.VG_LEN1.addVecs(vecSize);
        AppendableVec[] appendables = new AppendableVec[vecSize];
        NewChunk[] chunks = new NewChunk[vecSize];
        for (int d = 0; d < vecSize; d++) {
            appendables[d] = new AppendableVec(keys[d]);
            chunks[d] = new NewChunk(appendables[d], 0);
        }
        float[] syn0 = _modelInfo._syn0;
        for (int w = 0; w < _modelInfo._vocabSize; w++) {
            int base = w * vecSize;
            for (int d = 0; d < vecSize; d++) {
                chunks[d].addNum(syn0[base + d]);
            }
        }
        for (int d = 0; d < vecSize; d++) {
            names[d] = "V" + d;
            chunks[d].close(0, futures);
            vecs[d] = appendables[d].close(futures);
        }
        futures.blockForPending();
        _w2vKey = Key.make("w2v");
        Frame frame = new Frame(_w2vKey);
        frame.add("Word", ((Word2VecParameters) _parms)._vocabKey.get().vec(0));
        frame.add(names, vecs);
        DKV.put(_w2vKey, frame);
    }

    /** Removes the vocab frame, the output frame (if built) and this model. */
    public void delete() {
        Key<Frame> vocabKey = ((Word2VecParameters) _parms)._vocabKey;
        if (vocabKey != null) {
            vocabKey.remove();
        }
        if (_w2vKey != null) {
            _w2vKey.remove();
        }
        remove();
        super.delete();
    }
}
