AndroidPytorch

1. 模型转化

import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile

# Export a pretrained MobileNetV3-Small for the PyTorch Android lite runtime.
net = torchvision.models.mobilenet_v3_small(pretrained=True)
# net.load_state_dict(torch.load(model_pth))  # optionally load custom weights
net.eval()  # inference mode before tracing

# Trace with a dummy input of the shape the tracer should record.
dummy_input = torch.rand(1, 3, 224, 224)
traced = torch.jit.trace(model=net, example_inputs=dummy_input)  # convert the model to TorchScript
mobile_model = optimize_for_mobile(traced)

# Write the regular TorchScript file ...
mobile_model.save("model.pt")
# ... and the lite-interpreter file that is packaged into the app assets.
mobile_model._save_for_lite_interpreter("app/src/main/assets/model.pt")

2. Gradle 依赖

implementation 'org.pytorch:pytorch_android_lite:1.9.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'

3. 封装函数

public class MainActivity extends AppCompatActivity {

  @Override
  protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);
    setContentView(R.layout.activity_main);

    Bitmap bitmap = null;
    Module module = null;
    try {
      // creating bitmap from packaged into app android asset 'image.jpg',
      // app/src/main/assets/image.jpg
      bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
      // loading serialized torchscript module from packaged into app android asset model.pt,
      // app/src/model/assets/model.pt
      module = LiteModuleLoader.load(assetFilePath(this, "model.pt"));
    } catch (IOException e) {
      Log.e("PytorchHelloWorld", "Error reading assets", e);
      finish();
    }

    // showing image on UI
    ImageView imageView = findViewById(R.id.image);
    imageView.setImageBitmap(bitmap);

    // preparing input tensor
    final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
        TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB, MemoryFormat.CHANNELS_LAST);

    // running the model
    final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();

    // getting tensor content as java array of floats
    final float[] scores = outputTensor.getDataAsFloatArray();

    // searching for the index with maximum score
    float maxScore = -Float.MAX_VALUE;
    int maxScoreIdx = -1;
    for (int i = 0; i < scores.length; i++) {
      if (scores[i] > maxScore) {
        maxScore = scores[i];
        maxScoreIdx = i;
      }
    }

    String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];

    // showing className on UI
    TextView textView = findViewById(R.id.text);
    textView.setText(className);
  }

  /**
   * Copies specified asset to the file in /files app directory and returns this file absolute path.
   *
   * @return absolute file path
   */
  public static String assetFilePath(Context context, String assetName) throws IOException {
    File file = new File(context.getFilesDir(), assetName);
    if (file.exists() && file.length() > 0) {
      return file.getAbsolutePath();
    }

    try (InputStream is = context.getAssets().open(assetName)) {
      try (OutputStream os = new FileOutputStream(file)) {
        byte[] buffer = new byte[4 * 1024];
        int read;
        while ((read = is.read(buffer)) != -1) {
          os.write(buffer, 0, read);
        }
        os.flush();
      }
      return file.getAbsolutePath();
    }
  }
}
/**
 * Process-wide singleton that loads the lite model named by {@code Constant.MODEL_PATH} and maps
 * a {@code FlexWindow} sample to an entry of {@code Constant.Gesture_CHAR_CLASSES}.
 *
 * <p>{@link #initializeModel(Context)} must be called before any of the recognize methods.
 */
public class RecognizeTorch {
    /** Number of floats handed to the model; equals the element count of the {1, 5, 5} input. */
    private static final int INPUT_SIZE = 5 * 5;
    /** Upper bound on how many ranked entries are printed for debugging. */
    private static final int MAX_LOGGED_SCORES = 10;

    private Module module = null;

    private RecognizeTorch() {
    }

    /** Holder class: the instance is created when this nested class is first initialized. */
    private static class Inner {
        private static final RecognizeTorch instance = new RecognizeTorch();
    }

    /** Returns the single shared instance. */
    public static RecognizeTorch getSingleton() {
        return RecognizeTorch.Inner.instance;
    }

    /**
     * Copies the model asset to the files directory if needed and loads it.
     *
     * @param context context used to resolve assets and the files directory
     * @return {@code true} when the loader returned a module, {@code false} otherwise
     * @throws IOException if the model asset cannot be copied
     */
    public Boolean initializeModel(Context context) throws IOException {
        module = LiteModuleLoader.load(assetFilePath(context, Constant.MODEL_PATH));
        return module != null;
    }

    /**
     * Builds the model input from one sample of {@code flexWindow} and returns the recognized
     * class name.
     *
     * <p>The sample is taken at index {@code (int) (getSize() / 6 * 4)}.
     * NOTE(review): if {@code getSize()} returns an integer type this divides before multiplying
     * and truncates - confirm that is intended.
     *
     * @param flexWindow window to read the sample from
     * @return class name chosen by {@link #getRecognizeReuslt(Tensor)}
     * @throws IndexOutOfBoundsException if the sample yields fewer than 25 feature values
     * @throws IllegalStateException if the model has not been initialized
     */
    public String getRecognizeResult(FlexWindow flexWindow) {
        ArrayList<Float> samples = new ArrayList<>();
        for (Double value : flexWindow.getSingleFlexData((int) (flexWindow.getSize() / 6 * 4)).getFlexData()) {
            samples.add(value.floatValue());
        }

        float[] data = buildFeatures(samples);
        long[] shape = {1, 5, 5};
        Tensor inputTensor = Tensor.fromBlob(data, shape);
        System.out.println(inputTensor.toString());
        return getRecognizeReuslt(inputTensor);
    }

    /**
     * Produces the first {@code INPUT_SIZE} values of the sequence "for each i: v[i], followed by
     * v[i] - v[j] for every j != i".
     *
     * <p>With n samples the full sequence has n * n values. Any values beyond the first 25 are
     * ignored. NOTE(review): for n != 5 the rows of the {1, 5, 5} tensor no longer line up with
     * one sample each - presumably exactly 5 samples are expected; confirm against the caller.
     */
    private static float[] buildFeatures(ArrayList<Float> samples) {
        int n = samples.size();
        if ((long) n * n < INPUT_SIZE) {
            throw new IndexOutOfBoundsException(
                    "Need at least " + INPUT_SIZE + " feature values but " + n + " samples give only "
                            + ((long) n * n));
        }

        float[] data = new float[INPUT_SIZE];
        int next = 0;
        for (int i = 0; i < n && next < INPUT_SIZE; i++) {
            float current = samples.get(i);
            data[next++] = current;
            for (int j = 0; j < n && next < INPUT_SIZE; j++) {
                if (i != j) {
                    data[next++] = current - samples.get(j);
                }
            }
        }
        return data;
    }

    /**
     * Runs the model on {@code inputTensor} and returns the class name at the first index
     * reported by {@code ArrayHelper.Arraysort}.
     *
     * <p>The misspelled method name is kept because it is part of the public interface.
     *
     * @param inputTensor model input
     * @return the entry of {@code Constant.Gesture_CHAR_CLASSES} at the first ranked index
     * @throws IllegalStateException if {@link #initializeModel(Context)} has not loaded a module
     */
    public String getRecognizeReuslt(Tensor inputTensor) {
        if (module == null) {
            throw new IllegalStateException("initializeModel() must be called before recognition");
        }

        final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
        // getting tensor content as java array of floats
        final float[] scores = outputTensor.getDataAsFloatArray();

        // presumably returns class indices ordered best-first - verify against ArrayHelper
        final int[] ranked = ArrayHelper.Arraysort(scores);

        // Print at most MAX_LOGGED_SCORES entries, never past the end of either array.
        // NOTE(review): scores[i] pairs with ranked[i] only if Arraysort also reorders scores in
        // place; otherwise scores[ranked[i]] may be what was meant - confirm.
        final int shown = Math.min(MAX_LOGGED_SCORES, Math.min(ranked.length, scores.length));
        for (int i = 0; i < shown; i++) {
            System.out.println(ranked[i] + ":" + scores[i]);
        }
        return Constant.Gesture_CHAR_CLASSES[ranked[0]];
    }

    /**
     * Copies specified asset to the file in /files app directory and returns this file absolute path.
     *
     * <p>An already existing, non-empty file of the same name is returned as-is without being
     * compared to the asset.
     *
     * @param context context providing the files directory and the asset manager
     * @param assetName name of the asset, also used as the target file name
     * @return absolute file path
     * @throws IOException if the asset cannot be opened or the copy fails
     */
    public static String assetFilePath(Context context, String assetName) throws IOException {
        File file = new File(context.getFilesDir(), assetName);
        if (file.exists() && file.length() > 0) {
            return file.getAbsolutePath();
        }

        try (InputStream is = context.getAssets().open(assetName);
             OutputStream os = new FileOutputStream(file)) {
            byte[] buffer = new byte[4 * 1024];
            int read;
            while ((read = is.read(buffer)) != -1) {
                os.write(buffer, 0, read);
            }
            os.flush();
        }
        return file.getAbsolutePath();
    }
}

Resource

0%