/*
 * Decompiled with CFR 0.152.
 */
package com.mindspore;

import com.mindspore.Graph;
import com.mindspore.MSTensor;
import com.mindspore.config.MSContext;
import com.mindspore.config.MindsporeLite;
import com.mindspore.config.TrainCfg;
import java.nio.MappedByteBuffer;
import java.util.ArrayList;
import java.util.List;

/**
 * Java-side handle for a native model object.
 *
 * <p>Every public method forwards to a {@code native} counterpart, passing the
 * native handle stored in {@link #modelPtr}. A handle value of {@code 0} means
 * "no model": it is the initial value, the value left behind by a failed
 * {@code build}, and the value set by {@link #free()}.
 *
 * <p>This class performs no synchronization of its own.
 */
public class Model {
    /** Native model handle; {@code 0} while no model is built or after {@link #free()}. */
    private long modelPtr = 0L;

    /**
     * Builds the model from a graph.
     *
     * <p>NOTE(review): like every {@code build} overload, this overwrites the stored
     * handle without releasing a previously built model; call {@link #free()} first
     * when rebuilding on the same instance.
     *
     * @param graph   graph to build from; must not be {@code null}
     * @param context context to build with; must not be {@code null}
     * @param cfg     optional train configuration; when {@code null}, {@code 0} is passed
     *                to native code in place of its pointer
     * @return {@code true} if native code returned a non-zero handle; {@code false} on a
     *         {@code null} graph/context or a zero handle
     */
    public boolean build(Graph graph, MSContext context, TrainCfg cfg) {
        if (graph == null || context == null) {
            return false;
        }
        long cfgPtr = cfg != null ? cfg.getTrainCfgPtr() : 0L;
        this.modelPtr = this.buildByGraph(graph.getGraphPtr(), context.getMSContextPtr(), cfgPtr);
        return this.modelPtr != 0L;
    }

    /**
     * Builds the model from a mapped buffer, forwarding decryption-related arguments.
     *
     * @param buffer          buffer holding the model; must not be {@code null}
     * @param modelType       model type code, forwarded unchanged to native code
     * @param context         context to build with; must not be {@code null}
     * @param dec_key         presumably the decryption key (forwarded as-is); must not be {@code null}
     * @param dec_mode        presumably the decryption mode (forwarded as-is); must not be {@code null}
     * @param cropto_lib_path presumably the path of a crypto library (forwarded as-is);
     *                        not null-checked here
     * @return {@code true} if native code returned a non-zero handle; {@code false} on a
     *         {@code null} required argument or a zero handle
     */
    public boolean build(MappedByteBuffer buffer, int modelType, MSContext context, char[] dec_key, String dec_mode, String cropto_lib_path) {
        if (context == null || buffer == null || dec_key == null || dec_mode == null) {
            return false;
        }
        this.modelPtr = this.buildByBuffer(buffer, modelType, context.getMSContextPtr(), dec_key, dec_mode, cropto_lib_path);
        return this.modelPtr != 0L;
    }

    /**
     * Builds the model from a mapped buffer, passing a {@code null} key and empty
     * strings for the decryption mode and library path.
     *
     * @param buffer    buffer holding the model; must not be {@code null}
     * @param modelType model type code, forwarded unchanged to native code
     * @param context   context to build with; must not be {@code null}
     * @return {@code true} if native code returned a non-zero handle; {@code false} on a
     *         {@code null} argument or a zero handle
     */
    public boolean build(MappedByteBuffer buffer, int modelType, MSContext context) {
        if (context == null || buffer == null) {
            return false;
        }
        this.modelPtr = this.buildByBuffer(buffer, modelType, context.getMSContextPtr(), null, "", "");
        return this.modelPtr != 0L;
    }

    /**
     * Builds the model from a file path, forwarding decryption-related arguments.
     *
     * @param modelPath       path of the model; must not be {@code null}
     * @param modelType       model type code, forwarded unchanged to native code
     * @param context         context to build with; must not be {@code null}
     * @param dec_key         presumably the decryption key (forwarded as-is); must not be {@code null}
     * @param dec_mode        presumably the decryption mode (forwarded as-is); must not be {@code null}
     * @param cropto_lib_path presumably the path of a crypto library (forwarded as-is);
     *                        not null-checked here
     * @return {@code true} if native code returned a non-zero handle; {@code false} on a
     *         {@code null} required argument or a zero handle
     */
    public boolean build(String modelPath, int modelType, MSContext context, char[] dec_key, String dec_mode, String cropto_lib_path) {
        if (context == null || modelPath == null || dec_key == null || dec_mode == null) {
            return false;
        }
        this.modelPtr = this.buildByPath(modelPath, modelType, context.getMSContextPtr(), dec_key, dec_mode, cropto_lib_path);
        return this.modelPtr != 0L;
    }

    /**
     * Builds the model from a file path, passing a {@code null} key and empty
     * strings for the decryption mode and library path.
     *
     * @param modelPath path of the model; must not be {@code null}
     * @param modelType model type code, forwarded unchanged to native code
     * @param context   context to build with; must not be {@code null}
     * @return {@code true} if native code returned a non-zero handle; {@code false} on a
     *         {@code null} argument or a zero handle
     */
    public boolean build(String modelPath, int modelType, MSContext context) {
        if (context == null || modelPath == null) {
            return false;
        }
        this.modelPtr = this.buildByPath(modelPath, modelType, context.getMSContextPtr(), null, "", "");
        return this.modelPtr != 0L;
    }

    /**
     * Runs one native step on the model. Identical to {@link #runStep()}: both
     * forward to the same native method.
     *
     * @return the result reported by native code
     */
    public boolean predict() {
        return this.runStep(this.modelPtr);
    }

    /**
     * Runs one native step on the model. Identical to {@link #predict()}.
     *
     * @return the result reported by native code
     */
    public boolean runStep() {
        return this.runStep(this.modelPtr);
    }

    /**
     * Forwards a resize request for the given input tensors to native code.
     *
     * @param inputs tensors to resize; must not be {@code null} or contain {@code null}
     * @param dims   dimension arrays forwarded to native code (presumably one per
     *               input; not validated here); must not be {@code null}
     * @return {@code false} if an argument is {@code null} or {@code inputs} contains a
     *         {@code null} element; otherwise the result reported by native code
     */
    public boolean resize(List<MSTensor> inputs, int[][] dims) {
        if (inputs == null || dims == null) {
            return false;
        }
        long[] inputPtrs = toPointerArray(inputs);
        if (inputPtrs == null) {
            // A null element is an invalid argument; report it like the other null checks
            // instead of failing with a NullPointerException.
            return false;
        }
        return this.resize(this.modelPtr, inputPtrs, dims);
    }

    /**
     * Returns the model's input tensors.
     *
     * @return a new list wrapping each native tensor address in an {@link MSTensor}
     */
    public List<MSTensor> getInputs() {
        return wrapTensors(this.getInputs(this.modelPtr));
    }

    /**
     * Returns the model's output tensors.
     *
     * @return a new list wrapping each native tensor address in an {@link MSTensor}
     */
    public List<MSTensor> getOutputs() {
        return wrapTensors(this.getOutputs(this.modelPtr));
    }

    /**
     * Looks up an input tensor by name.
     *
     * <p>NOTE(review): the address returned by native code is wrapped without being
     * checked, so if native code reports "not found" as {@code 0} the result wraps a
     * zero address rather than being {@code null} -- confirm against the native side.
     *
     * @param tensorName name of the tensor
     * @return {@code null} if {@code tensorName} is {@code null}; otherwise an
     *         {@link MSTensor} wrapping the address returned by native code
     */
    public MSTensor getInputByTensorName(String tensorName) {
        if (tensorName == null) {
            return null;
        }
        long tensorAddr = this.getInputByTensorName(this.modelPtr, tensorName);
        return new MSTensor(tensorAddr);
    }

    /**
     * Looks up an output tensor by name.
     *
     * <p>NOTE(review): the address returned by native code is wrapped without being
     * checked, so if native code reports "not found" as {@code 0} the result wraps a
     * zero address rather than being {@code null} -- confirm against the native side.
     *
     * @param tensorName name of the tensor
     * @return {@code null} if {@code tensorName} is {@code null}; otherwise an
     *         {@link MSTensor} wrapping the address returned by native code
     */
    public MSTensor getOutputByTensorName(String tensorName) {
        if (tensorName == null) {
            return null;
        }
        long tensorAddr = this.getOutputByTensorName(this.modelPtr, tensorName);
        return new MSTensor(tensorAddr);
    }

    /**
     * Returns the output tensors associated with a node name.
     *
     * @param nodeName name of the node
     * @return {@code null} if {@code nodeName} is {@code null} (kept for backward
     *         compatibility); otherwise a new list wrapping each native tensor address
     */
    public List<MSTensor> getOutputsByNodeName(String nodeName) {
        if (nodeName == null) {
            return null;
        }
        return wrapTensors(this.getOutputsByNodeName(this.modelPtr, nodeName));
    }

    /**
     * Returns the output tensor names reported by native code.
     *
     * @return the list produced by native code, returned as-is
     */
    public List<String> getOutputTensorNames() {
        return this.getOutputTensorNames(this.modelPtr);
    }

    /**
     * Forwards an export request to native code.
     *
     * @param fileName          target file name; must not be {@code null}
     * @param quantizationType  quantization type code, forwarded unchanged
     * @param isOnlyExportInfer flag forwarded unchanged to native code
     * @param outputTensorNames optional tensor names; when {@code null}, a {@code null}
     *                          array is passed to native code
     * @return {@code false} if {@code fileName} is {@code null}; otherwise the result
     *         reported by native code
     */
    public boolean export(String fileName, int quantizationType, boolean isOnlyExportInfer, List<String> outputTensorNames) {
        if (fileName == null) {
            return false;
        }
        String[] outputTensorArray = outputTensorNames != null
                ? outputTensorNames.toArray(new String[0])
                : null;
        return this.export(this.modelPtr, fileName, quantizationType, isOnlyExportInfer, outputTensorArray);
    }

    /**
     * Returns the model's feature maps.
     *
     * @return a new list wrapping each native tensor address in an {@link MSTensor}
     */
    public List<MSTensor> getFeatureMaps() {
        return wrapTensors(this.getFeatureMaps(this.modelPtr));
    }

    /**
     * Forwards the given feature tensors to the native feature-map update.
     *
     * @param features tensors to pass; must not be {@code null} or contain {@code null}
     * @return {@code false} if {@code features} is {@code null} or contains a
     *         {@code null} element; otherwise the result reported by native code
     */
    public boolean updateFeatureMaps(List<MSTensor> features) {
        if (features == null) {
            return false;
        }
        long[] featurePtrs = toPointerArray(features);
        if (featurePtrs == null) {
            // Same policy as resize(): a null element is reported as failure.
            return false;
        }
        return this.updateFeatureMaps(this.modelPtr, featurePtrs);
    }

    /**
     * Forwards the train-mode flag to native code.
     *
     * @param isTrain flag value to set
     * @return the result reported by native code
     */
    public boolean setTrainMode(boolean isTrain) {
        return this.setTrainMode(this.modelPtr, isTrain);
    }

    /**
     * Queries the train-mode flag from native code.
     *
     * @return the value reported by native code
     */
    public boolean getTrainMode() {
        return this.getTrainMode(this.modelPtr);
    }

    /**
     * Forwards a learning rate to native code.
     *
     * @param learning_rate learning rate to set
     * @return the result reported by native code
     */
    public boolean setLearningRate(float learning_rate) {
        return this.setLearningRate(this.modelPtr, learning_rate);
    }

    /**
     * Forwards virtual-batch settings to native code.
     *
     * @param virtualBatchMultiplier multiplier, forwarded unchanged
     * @param learningRate           learning rate, forwarded unchanged
     * @param momentum               momentum, forwarded unchanged
     * @return the result reported by native code
     */
    public boolean setupVirtualBatch(int virtualBatchMultiplier, float learningRate, float momentum) {
        return this.setupVirtualBatch(this.modelPtr, virtualBatchMultiplier, learningRate, momentum);
    }

    /**
     * Releases the native model and clears the stored handle.
     *
     * <p>Calling this when no model is held (never built, build failed, or already
     * freed) does nothing, so repeated calls are safe. After this call the instance
     * passes {@code 0} to native code until a {@code build} overload succeeds again.
     */
    public void free() {
        if (this.modelPtr == 0L) {
            return;
        }
        this.free(this.modelPtr);
        // Drop the handle so a second free() or a later call cannot hand native code
        // an address that has already been released.
        this.modelPtr = 0L;
    }

    /**
     * Wraps each native tensor address in a new {@link MSTensor}.
     *
     * @param tensorAddrs addresses returned by native code; a {@code null} list is
     *                    treated as empty
     * @return a new mutable list, one tensor per address, in iteration order
     */
    private static List<MSTensor> wrapTensors(List<Long> tensorAddrs) {
        if (tensorAddrs == null) {
            return new ArrayList<MSTensor>();
        }
        ArrayList<MSTensor> tensors = new ArrayList<MSTensor>(tensorAddrs.size());
        for (Long msTensorAddr : tensorAddrs) {
            tensors.add(new MSTensor(msTensorAddr));
        }
        return tensors;
    }

    /**
     * Collects the native pointers of the given tensors, in list order.
     *
     * @param tensors non-null list of tensors
     * @return the pointer array, or {@code null} if any element is {@code null}
     */
    private static long[] toPointerArray(List<MSTensor> tensors) {
        long[] ptrs = new long[tensors.size()];
        for (int i = 0; i < ptrs.length; ++i) {
            MSTensor tensor = tensors.get(i);
            if (tensor == null) {
                return null;
            }
            ptrs[i] = tensor.getMSTensorPtr();
        }
        return ptrs;
    }

    // Native entry points. JNI binds these by method name and descriptor only, so the
    // parameter names below are purely documentary.

    private native void free(long modelPtr);

    private native long buildByGraph(long graphPtr, long contextPtr, long cfgPtr);

    private native long buildByPath(String modelPath, int modelType, long contextPtr, char[] decKey, String decMode, String cryptoLibPath);

    private native long buildByBuffer(MappedByteBuffer buffer, int modelType, long contextPtr, char[] decKey, String decMode, String cryptoLibPath);

    private native List<Long> getInputs(long modelPtr);

    private native long getInputByTensorName(long modelPtr, String tensorName);

    private native boolean runStep(long modelPtr);

    private native List<Long> getOutputs(long modelPtr);

    private native long getOutputByTensorName(long modelPtr, String tensorName);

    private native List<String> getOutputTensorNames(long modelPtr);

    private native List<Long> getOutputsByNodeName(long modelPtr, String nodeName);

    private native boolean setTrainMode(long modelPtr, boolean isTrain);

    private native boolean getTrainMode(long modelPtr);

    private native boolean resize(long modelPtr, long[] inputPtrs, int[][] dims);

    private native boolean export(long modelPtr, String fileName, int quantizationType, boolean isOnlyExportInfer, String[] outputTensorNames);

    private native List<Long> getFeatureMaps(long modelPtr);

    private native boolean updateFeatureMaps(long modelPtr, long[] featurePtrs);

    private native boolean setLearningRate(long modelPtr, float learningRate);

    private native boolean setupVirtualBatch(long modelPtr, int virtualBatchMultiplier, float learningRate, float momentum);

    static {
        // Runs once when the class is initialized, before any native method can be called.
        MindsporeLite.init();
    }
}

