package ai.onnxruntime.reactnative;

import ai.onnxruntime.OnnxJavaType;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.OrtUtil;
import ai.onnxruntime.TensorInfo;
import com.facebook.react.bridge.Arguments;
import com.facebook.react.bridge.ReadableArray;
import com.facebook.react.bridge.ReadableMap;
import com.facebook.react.bridge.WritableArray;
import com.facebook.react.bridge.WritableMap;
import com.facebook.react.modules.blob.BlobModule;
import com.tencent.android.tpush.common.Constants;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Collections;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/* loaded from: classes.dex */
public class TensorHelper {
    public static final String JsTensorTypeBool = "bool";
    public static final String JsTensorTypeByte = "int8";
    public static final String JsTensorTypeDouble = "float64";
    public static final String JsTensorTypeFloat = "float32";
    public static final String JsTensorTypeInt = "int32";
    public static final String JsTensorTypeLong = "int64";
    public static final String JsTensorTypeShort = "int16";
    public static final String JsTensorTypeString = "string";
    private static final Map<String, TensorInfo.OnnxTensorType> JsTensorTypeToOnnxTensorTypeMap;
    public static final String JsTensorTypeUnsignedByte = "uint8";
    private static final Map<TensorInfo.OnnxTensorType, String> OnnxTensorTypeToJsTensorTypeMap;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: ai.onnxruntime.reactnative.TensorHelper$1, reason: invalid class name */
    /* loaded from: classes.dex */
    public static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$ai$onnxruntime$TensorInfo$OnnxTensorType;

        static {
            int[] iArr = new int[TensorInfo.OnnxTensorType.values().length];
            $SwitchMap$ai$onnxruntime$TensorInfo$OnnxTensorType = iArr;
            try {
                iArr[TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT.ordinal()] = 1;
            } catch (NoSuchFieldError unused) {
            }
            try {
                $SwitchMap$ai$onnxruntime$TensorInfo$OnnxTensorType[TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8.ordinal()] = 2;
            } catch (NoSuchFieldError unused2) {
            }
            try {
                $SwitchMap$ai$onnxruntime$TensorInfo$OnnxTensorType[TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16.ordinal()] = 3;
            } catch (NoSuchFieldError unused3) {
            }
            try {
                $SwitchMap$ai$onnxruntime$TensorInfo$OnnxTensorType[TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32.ordinal()] = 4;
            } catch (NoSuchFieldError unused4) {
            }
            try {
                $SwitchMap$ai$onnxruntime$TensorInfo$OnnxTensorType[TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64.ordinal()] = 5;
            } catch (NoSuchFieldError unused5) {
            }
            try {
                $SwitchMap$ai$onnxruntime$TensorInfo$OnnxTensorType[TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE.ordinal()] = 6;
            } catch (NoSuchFieldError unused6) {
            }
            try {
                $SwitchMap$ai$onnxruntime$TensorInfo$OnnxTensorType[TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8.ordinal()] = 7;
            } catch (NoSuchFieldError unused7) {
            }
            try {
                $SwitchMap$ai$onnxruntime$TensorInfo$OnnxTensorType[TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL.ordinal()] = 8;
            } catch (NoSuchFieldError unused8) {
            }
            try {
                $SwitchMap$ai$onnxruntime$TensorInfo$OnnxTensorType[TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16.ordinal()] = 9;
            } catch (NoSuchFieldError unused9) {
            }
            try {
                $SwitchMap$ai$onnxruntime$TensorInfo$OnnxTensorType[TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16.ordinal()] = 10;
            } catch (NoSuchFieldError unused10) {
            }
            try {
                $SwitchMap$ai$onnxruntime$TensorInfo$OnnxTensorType[TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32.ordinal()] = 11;
            } catch (NoSuchFieldError unused11) {
            }
            try {
                $SwitchMap$ai$onnxruntime$TensorInfo$OnnxTensorType[TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64.ordinal()] = 12;
            } catch (NoSuchFieldError unused12) {
            }
            try {
                $SwitchMap$ai$onnxruntime$TensorInfo$OnnxTensorType[TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING.ordinal()] = 13;
            } catch (NoSuchFieldError unused13) {
            }
        }
    }

    static {
        TensorInfo.OnnxTensorType onnxTensorType = TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
        TensorInfo.OnnxTensorType onnxTensorType2 = TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
        TensorInfo.OnnxTensorType onnxTensorType3 = TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
        TensorInfo.OnnxTensorType onnxTensorType4 = TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
        TensorInfo.OnnxTensorType onnxTensorType5 = TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
        TensorInfo.OnnxTensorType onnxTensorType6 = TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
        TensorInfo.OnnxTensorType onnxTensorType7 = TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
        TensorInfo.OnnxTensorType onnxTensorType8 = TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
        TensorInfo.OnnxTensorType onnxTensorType9 = TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
        JsTensorTypeToOnnxTensorTypeMap = (Map) Stream.of((Object[]) new Object[][]{new Object[]{JsTensorTypeFloat, onnxTensorType}, new Object[]{JsTensorTypeByte, onnxTensorType2}, new Object[]{JsTensorTypeUnsignedByte, onnxTensorType3}, new Object[]{JsTensorTypeShort, onnxTensorType4}, new Object[]{JsTensorTypeInt, onnxTensorType5}, new Object[]{JsTensorTypeLong, onnxTensorType6}, new Object[]{"string", onnxTensorType7}, new Object[]{JsTensorTypeBool, onnxTensorType8}, new Object[]{JsTensorTypeDouble, onnxTensorType9}}).collect(Collectors.toMap(new Function() { // from class: ai.onnxruntime.reactnative.VWDM爩鱅E鷙
            @Override // java.util.function.Function
            public final Object apply(Object obj) {
                String lambda$static$0;
                lambda$static$0 = TensorHelper.lambda$static$0((Object[]) obj);
                return lambda$static$0;
            }
        }, new Function() { // from class: ai.onnxruntime.reactnative.簾WX齇鷙M鱅XT
            @Override // java.util.function.Function
            public final Object apply(Object obj) {
                TensorInfo.OnnxTensorType lambda$static$1;
                lambda$static$1 = TensorHelper.lambda$static$1((Object[]) obj);
                return lambda$static$1;
            }
        }));
        OnnxTensorTypeToJsTensorTypeMap = (Map) Stream.of((Object[]) new Object[][]{new Object[]{onnxTensorType, JsTensorTypeFloat}, new Object[]{onnxTensorType2, JsTensorTypeByte}, new Object[]{onnxTensorType3, JsTensorTypeUnsignedByte}, new Object[]{onnxTensorType4, JsTensorTypeShort}, new Object[]{onnxTensorType5, JsTensorTypeInt}, new Object[]{onnxTensorType6, JsTensorTypeLong}, new Object[]{onnxTensorType7, "string"}, new Object[]{onnxTensorType8, JsTensorTypeBool}, new Object[]{onnxTensorType9, JsTensorTypeDouble}}).collect(Collectors.toMap(new Function() { // from class: ai.onnxruntime.reactnative.鱅齇矡竈GV齇鱅V矡HQ齇貜
            @Override // java.util.function.Function
            public final Object apply(Object obj) {
                TensorInfo.OnnxTensorType lambda$static$2;
                lambda$static$2 = TensorHelper.lambda$static$2((Object[]) obj);
                return lambda$static$2;
            }
        }, new Function() { // from class: ai.onnxruntime.reactnative.糴爩貜鱅FRR鬚JV
            @Override // java.util.function.Function
            public final Object apply(Object obj) {
                String lambda$static$3;
                lambda$static$3 = TensorHelper.lambda$static$3((Object[]) obj);
                return lambda$static$3;
            }
        }));
    }

    private static OnnxTensor createInputTensor(TensorInfo.OnnxTensorType onnxTensorType, long[] jArr, ByteBuffer byteBuffer, OrtEnvironment ortEnvironment) throws Exception {
        switch (AnonymousClass1.$SwitchMap$ai$onnxruntime$TensorInfo$OnnxTensorType[onnxTensorType.ordinal()]) {
            case 1:
                return OnnxTensor.createTensor(ortEnvironment, byteBuffer.asFloatBuffer(), jArr);
            case 2:
                return OnnxTensor.createTensor(ortEnvironment, byteBuffer, jArr);
            case 3:
                return OnnxTensor.createTensor(ortEnvironment, byteBuffer.asShortBuffer(), jArr);
            case 4:
                return OnnxTensor.createTensor(ortEnvironment, byteBuffer.asIntBuffer(), jArr);
            case 5:
                return OnnxTensor.createTensor(ortEnvironment, byteBuffer.asLongBuffer(), jArr);
            case 6:
                return OnnxTensor.createTensor(ortEnvironment, byteBuffer.asDoubleBuffer(), jArr);
            case 7:
                return OnnxTensor.createTensor(ortEnvironment, byteBuffer, jArr, OnnxJavaType.UINT8);
            case 8:
                return OnnxTensor.createTensor(ortEnvironment, byteBuffer, jArr, OnnxJavaType.BOOL);
            default:
                throw new IllegalStateException("Unexpected value: " + onnxTensorType.toString());
        }
    }

    public static OnnxTensor createInputTensor(BlobModule blobModule, ReadableMap readableMap, OrtEnvironment ortEnvironment) throws Exception {
        ReadableArray array = readableMap.getArray("dims");
        long[] jArr = new long[array.size()];
        for (int i = 0; i < array.size(); i++) {
            jArr[i] = array.getInt(i);
        }
        TensorInfo.OnnxTensorType onnxTensorType = getOnnxTensorType(readableMap.getString("type"));
        if (onnxTensorType != TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
            ReadableMap map = readableMap.getMap("data");
            String string = map.getString("blobId");
            byte[] resolve = blobModule.resolve(string, map.getInt(Constants.FLAG_TAG_OFFSET), map.getInt("size"));
            blobModule.remove(string);
            return createInputTensor(onnxTensorType, jArr, ByteBuffer.wrap(resolve).order(ByteOrder.nativeOrder()), ortEnvironment);
        }
        ReadableArray array2 = readableMap.getArray("data");
        String[] strArr = new String[array2.size()];
        for (int i2 = 0; i2 < array2.size(); i2++) {
            strArr[i2] = array2.getString(i2);
        }
        return OnnxTensor.createTensor(ortEnvironment, strArr, jArr);
    }

    public static WritableMap createOutputTensor(BlobModule blobModule, OrtSession.Result result) throws Exception {
        WritableMap createMap = Arguments.createMap();
        Iterator<Map.Entry<String, OnnxValue>> it2 = result.iterator();
        while (it2.hasNext()) {
            Map.Entry<String, OnnxValue> next = it2.next();
            String key = next.getKey();
            OnnxValue value = next.getValue();
            if (value.getType() != OnnxValue.OnnxValueType.ONNX_TYPE_TENSOR) {
                throw new Exception("Not supported type: " + value.getType().toString());
            }
            OnnxTensor onnxTensor = (OnnxTensor) value;
            WritableMap createMap2 = Arguments.createMap();
            WritableArray createArray = Arguments.createArray();
            for (long j : onnxTensor.getInfo().getShape()) {
                createArray.pushInt((int) j);
            }
            createMap2.putArray("dims", createArray);
            createMap2.putString("type", getJsTensorType(onnxTensor.getInfo().onnxType));
            if (onnxTensor.getInfo().onnxType == TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
                String[] strArr = (String[]) onnxTensor.getValue();
                WritableArray createArray2 = Arguments.createArray();
                for (String str : strArr) {
                    createArray2.pushString(str);
                }
                createMap2.putArray("data", createArray2);
            } else {
                byte[] createOutputTensor = createOutputTensor(onnxTensor);
                WritableMap createMap3 = Arguments.createMap();
                createMap3.putString("blobId", blobModule.store(createOutputTensor));
                createMap3.putInt(Constants.FLAG_TAG_OFFSET, 0);
                createMap3.putInt("size", createOutputTensor.length);
                createMap2.putMap("data", createMap3);
            }
            createMap.putMap(key, createMap2);
        }
        return createMap;
    }

    private static byte[] createOutputTensor(OnnxTensor onnxTensor) throws Exception {
        ByteBuffer order;
        TensorInfo info = onnxTensor.getInfo();
        int elementCount = (int) OrtUtil.elementCount(onnxTensor.getInfo().getShape());
        switch (AnonymousClass1.$SwitchMap$ai$onnxruntime$TensorInfo$OnnxTensorType[info.onnxType.ordinal()]) {
            case 1:
                order = ByteBuffer.allocate(elementCount * 4).order(ByteOrder.nativeOrder());
                order.asFloatBuffer().put(onnxTensor.getFloatBuffer());
                break;
            case 2:
            case 8:
                order = ByteBuffer.allocate(elementCount).order(ByteOrder.nativeOrder());
                order.put(onnxTensor.getByteBuffer());
                break;
            case 3:
                order = ByteBuffer.allocate(elementCount * 2).order(ByteOrder.nativeOrder());
                order.asShortBuffer().put(onnxTensor.getShortBuffer());
                break;
            case 4:
                order = ByteBuffer.allocate(elementCount * 4).order(ByteOrder.nativeOrder());
                order.asIntBuffer().put(onnxTensor.getIntBuffer());
                break;
            case 5:
                order = ByteBuffer.allocate(elementCount * 8).order(ByteOrder.nativeOrder());
                order.asLongBuffer().put(onnxTensor.getLongBuffer());
                break;
            case 6:
                order = ByteBuffer.allocate(elementCount * 8).order(ByteOrder.nativeOrder());
                order.asDoubleBuffer().put(onnxTensor.getDoubleBuffer());
                break;
            case 7:
                order = ByteBuffer.allocate(elementCount).order(ByteOrder.nativeOrder());
                order.put(onnxTensor.getByteBuffer());
                break;
            default:
                throw new IllegalStateException("Unexpected type: " + info.onnxType.toString());
        }
        return order.array();
    }

    private static String getJsTensorType(TensorInfo.OnnxTensorType onnxTensorType) {
        Map<TensorInfo.OnnxTensorType, String> map = OnnxTensorTypeToJsTensorTypeMap;
        return map.containsKey(onnxTensorType) ? map.get(onnxTensorType) : "undefined";
    }

    private static TensorInfo.OnnxTensorType getOnnxTensorType(String str) {
        Map<String, TensorInfo.OnnxTensorType> map = JsTensorTypeToOnnxTensorTypeMap;
        return map.containsKey(str) ? map.get(str) : TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
    }

    /* JADX INFO: Access modifiers changed from: private */
    public static /* synthetic */ String lambda$static$0(Object[] objArr) {
        return (String) objArr[0];
    }

    /* JADX INFO: Access modifiers changed from: private */
    public static /* synthetic */ TensorInfo.OnnxTensorType lambda$static$1(Object[] objArr) {
        return (TensorInfo.OnnxTensorType) objArr[1];
    }

    /* JADX INFO: Access modifiers changed from: private */
    public static /* synthetic */ TensorInfo.OnnxTensorType lambda$static$2(Object[] objArr) {
        return (TensorInfo.OnnxTensorType) objArr[0];
    }

    /* JADX INFO: Access modifiers changed from: private */
    public static /* synthetic */ String lambda$static$3(Object[] objArr) {
        return (String) objArr[1];
    }
}
