PyTorch Mobile Demo on Android

HelloWorld

HelloWorld is a simple image classification application that demonstrates how to use the PyTorch Android API.
The application runs a TorchScript-serialized, pretrained TorchVision MobileNet v3 model on a static image that is packaged inside the app as an Android asset.

1. Model Preparation

Let’s start with model preparation. If you are familiar with PyTorch, you probably already know how to train and save your model. In case you don’t, we are going to use a pre-trained image classification model (MobileNet v3) that comes packaged in TorchVision.
To install PyTorch and TorchVision, run the command below:

pip install torch torchvision

To serialize and optimize the model for Android, you can use the Python script in the root folder of the HelloWorld app:

import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile

model = torchvision.models.mobilenet_v3_small(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
optimized_traced_model = optimize_for_mobile(traced_script_module)
optimized_traced_model._save_for_lite_interpreter("app/src/main/assets/model.ptl")

If everything works well, we should have our scripted and optimized model, model.ptl, generated in the assets folder of the Android application.
It will be packaged inside the Android application as an asset and can be used on the device.

By using the new MobileNet v3 model instead of the old ResNet18 model, and by calling the optimize_for_mobile method on the traced model, the model inference time on a Pixel 3 drops from over 230 ms to about 40 ms.

You can find more details about TorchScript in the tutorials on pytorch.org.

2. Cloning from GitHub

git clone https://github.com/pytorch/android-demo-app.git
cd android-demo-app/HelloWorldApp

If the Android SDK and Android NDK are already installed, you can install this application on a connected Android device or emulator with:

./gradlew installDebug

We recommend opening this project in Android Studio 3.5.1+ (at the moment PyTorch Android and the demo applications use the Android Gradle plugin version 3.5.0, which is supported only by Android Studio 3.5.1 and higher);
in that case you will be able to install the Android NDK and Android SDK using the Android Studio UI.

3. Gradle dependencies

PyTorch Android is added to HelloWorld as Gradle dependencies in build.gradle:

repositories {
    jcenter()
}

dependencies {
    implementation 'org.pytorch:pytorch_android_lite:1.9.0'
    implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'
}

Here org.pytorch:pytorch_android_lite is the main dependency with the PyTorch Android API, including the libtorch native library for all four Android ABIs (armeabi-v7a, arm64-v8a, x86, x86_64).
Later in this doc you can find how to rebuild it for a specific list of Android ABIs only.

org.pytorch:pytorch_android_torchvision is an additional library with utility functions for converting android.media.Image and android.graphics.Bitmap objects to tensors.

4. Reading image from Android Asset

All the logic happens in org.pytorch.helloworld.MainActivity.
As a first step we read image.jpg into an android.graphics.Bitmap using the standard Android API.

Bitmap bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
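
Note that getAssets().open(...) throws an IOException, so the call has to sit inside a try/catch block (or the exception must be declared). A minimal sketch of such handling (the log tag and the recovery path are illustrative, not the demo's exact code):

// Minimal sketch: decode the bundled asset, bailing out if it cannot be read.
Bitmap bitmap = null;
try {
  bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
} catch (IOException e) {
  Log.e("HelloWorld", "Error reading asset image.jpg", e);
  finish();
}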

5. Loading TorchScript Module

Module module = LiteModuleLoader.load(assetFilePath(this, "model.ptl"));

org.pytorch.Module represents a torch::jit::script::Module that can be loaded with the load method by specifying the file path to the serialized model file.
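
The load method takes a regular file path rather than an asset name, which is why the snippet above goes through an assetFilePath helper. A minimal sketch of such a helper, which copies the asset to the app's internal files directory once and returns its absolute path (the demo's actual implementation may differ in details):

// Copies an asset to the app's files directory (if not already there)
// and returns the absolute path that Module/LiteModuleLoader can load.
public static String assetFilePath(Context context, String assetName) throws IOException {
  File file = new File(context.getFilesDir(), assetName);
  if (file.exists() && file.length() > 0) {
    return file.getAbsolutePath(); // already copied on a previous run
  }
  try (InputStream is = context.getAssets().open(assetName);
       OutputStream os = new FileOutputStream(file)) {
    byte[] buffer = new byte[4 * 1024];
    int read;
    while ((read = is.read(buffer)) != -1) {
      os.write(buffer, 0, read);
    }
    os.flush();
  }
  return file.getAbsolutePath();
}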

6. Preparing Input

Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
    TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);

org.pytorch.torchvision.TensorImageUtils is part of the org.pytorch:pytorch_android_torchvision library.
The TensorImageUtils#bitmapToFloat32Tensor method creates tensors in the torchvision format, using an android.graphics.Bitmap as the source.

All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224.
The images have to be loaded into a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].

inputTensor's shape is 1x3xHxW, where H and W are the bitmap height and width, respectively.
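
Conceptually, bitmapToFloat32Tensor maps every pixel channel from the [0, 255] range to a normalized float using the mean and std above. A rough illustration of the per-pixel computation (this is not the library's actual code, which works on whole buffers natively):

// Illustration only: normalize one ARGB pixel into torchvision-style
// R, G, B float values, as bitmapToFloat32Tensor does in bulk.
static float[] normalizePixel(int argbPixel) {
  final float[] mean = {0.485f, 0.456f, 0.406f};
  final float[] std = {0.229f, 0.224f, 0.225f};
  float r = (((argbPixel >> 16) & 0xFF) / 255.0f - mean[0]) / std[0];
  float g = (((argbPixel >> 8) & 0xFF) / 255.0f - mean[1]) / std[1];
  float b = ((argbPixel & 0xFF) / 255.0f - mean[2]) / std[2];
  return new float[] {r, g, b};
}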

7. Run Inference

Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
float[] scores = outputTensor.getDataAsFloatArray();

The org.pytorch.Module.forward method runs the loaded module's forward method and returns the result as an org.pytorch.Tensor outputTensor with shape 1x1000.

8. Processing results

Its content is retrieved using the org.pytorch.Tensor.getDataAsFloatArray() method, which returns a Java array of floats with a score for every ImageNet class.

After that we just find the index with the maximum score and retrieve the predicted class name from the ImageNetClasses.IMAGENET_CLASSES array, which contains all ImageNet class names.

float maxScore = -Float.MAX_VALUE;
int maxScoreIdx = -1;
for (int i = 0; i < scores.length; i++) {
  if (scores[i] > maxScore) {
    maxScore = scores[i];
    maxScoreIdx = i;
  }
}
String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];

In the following sections you can find detailed explanations of the PyTorch Android API, a code walkthrough of a bigger demo application,
implementation details of the API, and how to customize and build it from source.

PyTorchDemoApp

Image Classification

This demo application, found in the same GitHub repo, does image classification on the camera output as well as text classification.

To get the device camera output it uses the Android CameraX API.
All the logic that works with CameraX is separated into the org.pytorch.demo.vision.AbstractCameraXActivity class.

void setupCameraX() {
    final PreviewConfig previewConfig = new PreviewConfig.Builder().build();
    final Preview preview = new Preview(previewConfig);
    preview.setOnPreviewOutputUpdateListener(output -> mTextureView.setSurfaceTexture(output.getSurfaceTexture()));

    final ImageAnalysisConfig imageAnalysisConfig =
        new ImageAnalysisConfig.Builder()
            .setTargetResolution(new Size(224, 224))
            .setCallbackHandler(mBackgroundHandler)
            .setImageReaderMode(ImageAnalysis.ImageReaderMode.ACQUIRE_LATEST_IMAGE)
            .build();
    final ImageAnalysis imageAnalysis = new ImageAnalysis(imageAnalysisConfig);
    imageAnalysis.setAnalyzer(
        (image, rotationDegrees) -> {
          analyzeImage(image, rotationDegrees);
        });

    CameraX.bindToLifecycle(this, preview, imageAnalysis);
}

void analyzeImage(android.media.Image image, int rotationDegrees)

The analyzeImage method processes the camera output, an android.media.Image.

It uses the TensorImageUtils.imageYUV420CenterCropToFloat32Tensor method to convert an android.media.Image in YUV420 format into an input tensor.

After getting the predicted scores from the model, it finds the top K classes with the highest scores and shows them in the UI.
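
A rough sketch of what such an analyzer can look like is shown below; it assumes mModule is the already loaded org.pytorch.Module, and the target size, K, and helper structure are illustrative rather than the demo's exact code:

// Illustrative sketch: convert the camera frame to a tensor, run the model,
// and collect the indices of the top-K highest-scoring classes.
void analyzeImage(android.media.Image image, int rotationDegrees) {
  Tensor inputTensor = TensorImageUtils.imageYUV420CenterCropToFloat32Tensor(
      image, rotationDegrees, 224, 224,
      TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
      TensorImageUtils.TORCHVISION_NORM_STD_RGB);

  float[] scores =
      mModule.forward(IValue.from(inputTensor)).toTensor().getDataAsFloatArray();

  // Naive top-K selection over the ImageNet scores.
  final int k = 3;
  int[] topKIdx = new int[k];
  float[] topKScores = new float[k];
  java.util.Arrays.fill(topKScores, -Float.MAX_VALUE);
  for (int i = 0; i < scores.length; i++) {
    for (int j = 0; j < k; j++) {
      if (scores[i] > topKScores[j]) {
        for (int m = k - 1; m > j; m--) {
          topKScores[m] = topKScores[m - 1];
          topKIdx[m] = topKIdx[m - 1];
        }
        topKScores[j] = scores[i];
        topKIdx[j] = i;
        break;
      }
    }
  }
  // topKIdx now holds the indices of the K best classes, ready for the UI.
}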

Language Processing Example

Another example is natural language processing, based on an LSTM model trained on a Reddit comments dataset.
The logic happens in TextClassificationActivity.

The result class names are packaged inside the TorchScript model and retrieved just after the module is loaded.
The module has a get_classes method that returns List[str], which can be called using Module.runMethod(methodName):

    mModule = Module.load(moduleFileAbsoluteFilePath);
    IValue getClassesOutput = mModule.runMethod("get_classes");

The returned IValue can be converted to a Java array of IValue objects using IValue.toList() and processed into an array of strings using IValue.toStr():

    IValue[] classesListIValue = getClassesOutput.toList();
    String[] moduleClasses = new String[classesListIValue.length];
    int i = 0;
    for (IValue iv : classesListIValue) {
      moduleClasses[i++] = iv.toStr();
    }

The entered text is converted to a Java array of bytes with UTF-8 encoding. Tensor.fromBlobUnsigned creates a tensor of dtype=uint8 from that array of bytes.

    byte[] bytes = text.getBytes(Charset.forName("UTF-8"));
    final long[] shape = new long[]{1, bytes.length};
    final Tensor inputTensor = Tensor.fromBlobUnsigned(bytes, shape);

Running inference with the model is similar to the previous examples:

Tensor outputTensor = mModule.forward(IValue.from(inputTensor)).toTensor();

After that, the code processes the output, finding the classes with the highest scores.
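
For instance, mapping the single best score back to a class name obtained earlier from get_classes could look like this (a minimal sketch; the actual activity reports several top classes):

// Minimal sketch: argmax over the scores, then look up the class name
// in the moduleClasses array filled from the model's get_classes method.
float[] scores = outputTensor.getDataAsFloatArray();
int bestIdx = 0;
for (int i = 1; i < scores.length; i++) {
  if (scores[i] > scores[bestIdx]) {
    bestIdx = i;
  }
}
String predictedClass = moduleClasses[bestIdx];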

Semantic Image Segmentation DeepLabV3 with Mobile Interpreter on Android

Introduction

This repo offers a Python script that converts the PyTorch DeepLabV3 model to a Lite Interpreter version of the model, also optimized for mobile, and an Android app that uses the model to segment images.

Prerequisites

  • PyTorch 1.9.0 and torchvision 0.10.0 (Optional)
  • Python 3.8 or above (Optional)
  • PyTorch Android libraries pytorch_android_lite:1.9.0 and pytorch_android_torchvision:1.9.0
  • Android Studio 4.0.1 or later

Quick Start

To test run the Image Segmentation Android app, follow the steps below:

1. Prepare the Model

If you don’t have a PyTorch 1.9.0 environment set up, you can download the optimized-for-mobile Mobile Interpreter version of the model file to the android-demo-app/ImageSegmentation/app/src/main/assets folder using the link here.

Otherwise, open a terminal window, first install PyTorch 1.9.0 and torchvision 0.10.0 using a command like pip install torch torchvision, then run the following commands:

git clone https://github.com/pytorch/android-demo-app
cd android-demo-app/ImageSegmentation
python deeplabv3.py

The Python script deeplabv3.py is used to generate the TorchScript-formatted models for mobile apps. For comparison, three versions of the model are generated: a full JIT version of the model, a Mobile Interpreter version of the model that is not optimized for mobile, and a Mobile Interpreter version of the model that is optimized for mobile, named deeplabv3_scripted_optimized.ptl. The last one is what should be used in mobile apps, as its inference speed is over 60% faster than the non-optimized Mobile Interpreter model, which in turn is about 6% faster than the non-optimized full JIT model.

2. Use Android Studio

Open the ImageSegmentation project using Android Studio. Note the app’s build.gradle file has the following lines:

implementation 'org.pytorch:pytorch_android_lite:1.9.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'

and in MainActivity.java, the code below is used to load the model:

mModule = LiteModuleLoader.load(MainActivity.assetFilePath(getApplicationContext(), "deeplabv3_scripted_optimized.ptl"));
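
Once loaded, running the model and reading its output could look roughly like the sketch below. It assumes inputTensor was built from the bitmap with TensorImageUtils, and that the scripted DeepLabV3 model returns a Dict[str, Tensor] keyed by "out" with shape 1 x numClasses x H x W (one channel per class), as the torchvision DeepLabV3 models do; the variable names are illustrative:

// Illustrative sketch: run the segmentation model and pick the most likely
// class for every pixel of the output.
Map<String, IValue> outputs = mModule.forward(IValue.from(inputTensor)).toDictStringKey();
Tensor outTensor = outputs.get("out").toTensor();
float[] scores = outTensor.getDataAsFloatArray();

long[] shape = outTensor.shape(); // [1, numClasses, height, width]
int numClasses = (int) shape[1];
int height = (int) shape[2];
int width = (int) shape[3];

int[] classMap = new int[height * width];
for (int y = 0; y < height; y++) {
  for (int x = 0; x < width; x++) {
    int best = 0;
    float bestScore = -Float.MAX_VALUE;
    for (int c = 0; c < numClasses; c++) {
      float score = scores[c * height * width + y * width + x];
      if (score > bestScore) {
        bestScore = score;
        best = c;
      }
    }
    classMap[y * width + x] = best; // predicted class index for this pixel
  }
}
// classMap can now be colored per class and drawn over the input image.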

3. Run the app

Select an Android emulator or device, then build and run the app to see the example image and its segmented result.


Note that the example image used in the repo is fairly large (400x400), so the segmentation process may take about 10 seconds. You may use a smaller image, but the segmentation result may be less accurate.

Tutorial

Read the tutorial here for detailed step-by-step instructions on how to prepare and run the PyTorch DeepLabV3 model on Android, as well as practical tips on how to successfully use a pre-trained PyTorch model on Android and avoid common pitfalls.

For more information on using Mobile Interpreter in Android, see the tutorial here.
