Deploying PyTorch to an Android Environment

I have some inchoate thoughts about approaches to deployable AI products. Too much of the conversation around AI frames it from the angle of AI-as-a-service, something that a customer makes requests of. This is evidenced by reactions to things like DALL-E 2, which center on concern for the future of professional artists. A more useful frame of reference is AI-as-an-instrument (AIaaI), i.e. a tool that a user would wield like a pencil. Ideally, it would be something that could amplify arbitrary ideas that a user had. Of course, this would require much more generality than is currently trained into models and likely the ability to interface with a number of other APIs.

Exploring a product from this angle will require leveraging hardware accessible to a normal user, meaning a heavy focus on mobile and edge computing systems. To that end, I’m making a harder pivot to mobile. Fortunately, there is a large and growing body of work being done in the area of mobile deep learning. Before we get started on the really interesting stuff, we have some housekeeping to attend to.

Having done a little work with both, I actually think it will be easier to develop native Android apps instead of using something like React Native. I understand that PyTorch supports both, but the ecosystem for the former seems tighter to me.

Fair warning: because this is covering so much new material, this is going to get long.

GitHub Repo

There is a PyTorch GitHub repo focusing on demos for Android that is worth a look. Today I’ll be going through some of the relevant code in the HelloWorldApp section. Honestly, most of that code. Additionally, I somewhat followed the Quickstart with a HelloWorld Example.

The following is basically notes on things that were confusing or somehow tripped me up in the process of getting HelloWorldApp to run on my phone. If you’re following along and aren’t familiar with Android or Java, I heartily recommend going through the code yourself instead of just downloading and running the repo, especially if you’re most experienced with something like a desktop Python environment. Basically, this picks up from the point of opening Android Studio and starting an empty project.

Gradle Stuff

Gradle is the build tool being used here, and it warrants some attention for anyone coming from the Python landscape or anyone used to something like CMake.

Looking at the project structure in Android Studio with the Android view, it’s easy to see that there are two build.gradle files, one for the project and one for the module. For future reference, the project version is placed in the root folder and will define behavior for the entire project. A module, being isolated from the broader project, has its own file to define settings that are only relevant to that module.

settings.gradle

We need to configure this file at the top level. Possibly due to the specifics of my machine, I encountered build errors when starting with the boilerplate file, as follows.

pluginManagement {
    repositories {
        gradlePluginPortal()
        google()
        mavenCentral()
    }
}

dependencyResolutionManagement {
    repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)  
    repositories {
        google()
        mavenCentral()
    }
}

rootProject.name = "HelloWorld"
include ':app'

I initially had a look at how to configure a build focusing on settings.gradle and thought that dependencyResolutionManagement had to be removed, but this felt like poor form and raised questions about what would be done in later projects that actually needed it. Fortunately, this turns out to be a known issue with projects generated by the latest version of Android Studio.

The culprit comes down to a single setting: FAIL_ON_PROJECT_REPOS. With it set, the build fails because repositories are also being declared at the project level. If we comment out the repositoriesMode.set(...) line, we’re golden.

build.gradle (:app)

The module-level build.gradle file is a little more involved, and I think it’s worth going over it in sections. Worth noting that this is written in Groovy, which is a whole thing on its own and beyond the scope of this article.

plugins {
    id 'com.android.application'
}

Here is where we specify plugins. This project in particular just uses the one. com.android.application is used here because the project is an app. For examples of what else might be listed here, a library module would use com.android.library, and a Kotlin project would add org.jetbrains.kotlin.android. There are rich gradle docs going into more detail about the use of plugins.

android {
    compileSdk 32

    defaultConfig {
        applicationId "com.example.helloworld"
        minSdk 28
        targetSdk 31
        versionCode 1
        versionName "1.0"

        testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
    }

    buildTypes {
        release {
            minifyEnabled false
            proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
        }
    }

    compileOptions {
        sourceCompatibility JavaVersion.VERSION_1_8
        targetCompatibility JavaVersion.VERSION_1_8
    }  
}

Here is where the specifics of the Android build happen. Looking at defaultConfig, we can see that our application’s ID is com.example.helloworld, the earliest Android API level the app will run on is 28, it targets API level 31 and is compiled against API level 32, and the app itself is version 1.0.

minifyEnabled might seem a little strange. I ran a search and found this article on shrinking, obfuscating, and optimizing apps. These are things that improve security and shrink the build size. Minification is disabled by default in new Android Studio projects; a developer would flip minifyEnabled to true (and tune the ProGuard rules) for a real release build, and this is just a lightweight example app.

dependencies {
    implementation 'androidx.appcompat:appcompat:1.4.1'
    implementation 'com.google.android.material:material:1.4.0'
    implementation 'androidx.constraintlayout:constraintlayout:2.1.4'
    implementation 'org.pytorch:pytorch_android_lite:1.9.0'
    implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'
    testImplementation 'junit:junit:4.13.2'
    androidTestImplementation 'androidx.test.ext:junit:1.1.3'
    androidTestImplementation 'androidx.test.espresso:espresso-core:3.4.0'
}

The dependencies keyword itself is pretty straightforward, so let’s have a look at what we’re actually using.

Side note: the implementation keyword exists as a way to manage dependencies in chains of libraries and speed up the build time. See Implementation Vs API in Android Gradle plugin 3.0 for more information.

  • AndroidX is the namespace for the Android Jetpack libraries. It’s the replacement for the original Android Support Library, and it’s where the androidx.* dependencies above come from.
  • AppCompat is a library to make Android apps work across different versions of Android itself.
  • Material Design is the library for Android’s design language.
  • The two PyTorch libraries there provide the actual deep learning functionality that we’ll get to later in this article.
  • testImplementation extends the implementation configuration such that we can use JUnit for tests; a minimal example of such a test is sketched after this list.
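
For reference, here’s the kind of local unit test that testImplementation makes possible. This is just a sketch modeled on the ExampleUnitTest boilerplate that Android Studio generates for a new project; it isn’t part of the HelloWorld demo itself.

import org.junit.Test;

import static org.junit.Assert.assertEquals;

// A local unit test: runs on the development machine's JVM, not on the device.
public class ExampleUnitTest {
    @Test
    public void addition_isCorrect() {
        assertEquals(4, 2 + 2);
    }
}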

build.gradle (HelloWorld)

With that covered, we can get into the project-level build.gradle file.

buildscript {
    repositories {
        google()
        jcenter()
    }

    dependencies {
        classpath 'com.android.tools.build:gradle:7.2.0'
    }
}

The buildscript block provides the foundational information for the rest of the build script. Here you can see that we have added

  • google(), which is a shortcut to Google’s Maven repository, and
  • jcenter(), which enables access to a huge number of Java and Android OSS libraries.

Any dependencies resolved against these repositories (like the Android Gradle plugin below) will be downloaded if they aren’t already cached locally.

allprojects {
    repositories {
        google()
        jcenter()
    }
}

This might look redundant. allprojects is distinct from buildscript in that the latter declares repositories for Gradle’s own build dependencies (like the Android Gradle plugin above), while the former declares repositories for the dependencies of the modules being built by Gradle.

task clean(type: Delete) {
    delete rootProject.buildDir
}

Here we define a clean task that deletes the project’s build directory when the task is run (e.g., with gradle clean). If you want to get a better intuition around tasks, Build Script Basics has a lot of coverage there.

Design Stuff

activity_main.xml

Okay, I’m lying a little bit. Before going over the Java code proper, there is some preliminary work that needs to be done in the design. We need to take a look at activity_main.xml located in the res/layout folder, as it is the file that specifies the layout. Android Studio lets you look at the code, the GUI layout, and both side-by-side. Because we are dealing with a relatively simple layout, we’ll just be going over the code in this one.

<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
 xmlns:app="http://schemas.android.com/apk/res-auto"
 xmlns:tools="http://schemas.android.com/tools"
 android:layout_width="match_parent"
 android:layout_height="match_parent"
 tools:context=".MainActivity">
  ...
</androidx.constraintlayout.widget.ConstraintLayout>

This can be regarded as a bit of boilerplate. For the curious, ConstraintLayout is the parent container for all the design elements that we will be using. It plays nicely with the GUI layout editor in Android Studio and keeps the view hierarchy flat, avoiding deeply nested controls.

Because this is basically a static demonstration, we only need to put two controls inside of it, an ImageView and a TextView. We will explore things like buttons, check boxes, and so on in a later entry.

<ImageView
 android:id="@+id/image"
 android:layout_width="match_parent"
 android:layout_height="match_parent"
 android:scaleType="fitCenter"/>

The image view is where the static image gets displayed. For the sake of simplicity, we will be focusing on the original code and not adding constraints. Those will be part of an upcoming tutorial.

Here, android:id gives the ImageView a name that can be referenced by the activity. layout_width and layout_height are both set to match_parent, indicating that the view will be as wide and tall as the ConstraintLayout, minus its padding. scaleType being set to fitCenter scales the image uniformly so that it fits within the ImageView and centers it. Now, all of this can result in some visual quirks that should be explored when we get to the project that we are ultimately aiming for, but it is being done here to give a consistent presentation for the sake of the example.

<TextView
 android:id="@+id/text"
 android:layout_width="match_parent"
 android:layout_height="wrap_content"
 android:layout_gravity="top"
 android:text="TextView"
 android:textSize="18sp"
 tools:layout_editor_absoluteX="169dp"
 tools:layout_editor_absoluteY="540dp" />

Next is the TextView, which needs a little more explaining. We see here that layout_height is set to wrap_content. This means that, instead of taking up as much space as possible, the TextView only expands enough vertically to contain its text. Setting layout_gravity to top pushes it to the top of the screen without changing its size.

Funnily enough, if you compare this code to the GitHub repo, this is where it starts to become really apparent that I made a point of going through this myself instead of directly copying the repo. If you don’t change these settings, you’ll run into a warning about android:text being hardcoded instead of using a @string resource, which… yeah, fair enough, but I don’t want to do a separate section on strings.xml. At any rate, that’s something that Android Studio adds.

Something else specific to Android Studio is the tools namespace, which is used by the layout editor. Seriously, if you remove these two attributes and run this on the phone, it’ll behave the same. What’s going on here is that tools attributes only affect the design-time preview, keeping the layout GUI consistent with the XML without changing the built app.

Given that the last three attributes deal with layout dimensions, you might be wondering what’s going on with sp and dp. sp stands for scale-independent pixels, and dp stands for density-independent pixels. The functional difference between these is that sp is used for text, which is scaled by user preferences, and dp is used for everything else. Setting textSize to 18sp just makes the text conventionally readable.
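
To make the distinction a bit more concrete, here’s a small sketch (a hypothetical helper, not something the demo app uses) showing how these units get converted to raw pixels using the device’s display metrics. TypedValue does the math; sp additionally folds in the user’s font scale preference.

import android.content.Context;
import android.util.TypedValue;

public final class DisplayUnits {
    private DisplayUnits() {}

    // 18dp on a screen with a density of 2.0 works out to 36 physical pixels.
    public static float dpToPx(Context context, float dp) {
        return TypedValue.applyDimension(TypedValue.COMPLEX_UNIT_DIP, dp,
                context.getResources().getDisplayMetrics());
    }

    // sp goes through the same conversion, plus the user's text size setting.
    public static float spToPx(Context context, float sp) {
        return TypedValue.applyDimension(TypedValue.COMPLEX_UNIT_SP, sp,
                context.getResources().getDisplayMetrics());
    }
}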

Additional Files

Okay, there is some housekeeping that needs to be done. From the repo, you’ll need to copy the files in app/src/main/assets. These are just the test image and the pretrained model. Additionally, you’ll need to copy over ImageNetClasses.java from app/src/main/java/org/pytorch/helloworld, which gives you the labels for the classes in question. This goes right next to MainActivity.java. I would go over it if it were anything more complex than a class containing an array of strings, and it’ll be clear enough on the other side of the main Java material here.
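
For a sense of what that copied file looks like, here’s a sketch of its shape (abridged; the real file lists the full set of ImageNet labels, in order, so that an index into the array maps to a class name):

package org.pytorch.helloworld;

public class ImageNetClasses {
  public static String[] IMAGENET_CLASSES = new String[]{
      "tench, Tinca tinca",
      "goldfish, Carassius auratus",
      // ...the rest of the ImageNet labels...
  };
}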

Java Stuff

With that, we are finally ready to get to the meat of the project. The entire focus of the rest of this writeup will be in MainActivity.java.

Imports

First, we’ll have a look at our libraries.

package com.example.helloworld;

This declares the package that our class lives in. Packages are generally used to avoid name conflicts. It isn’t critical here, but it is a good practice to keep an eye out for.

import androidx.appcompat.app.AppCompatActivity;

import android.os.Bundle;

import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.util.Log;
import android.widget.ImageView;
import android.widget.TextView;

As we’ll see in a moment, AppCompatActivity here is the base class for our MainActivity. This allows us to use newer features on older devices.

Broadly speaking, the android package contains the set of tools and resources that will be used by the app. android.content provides basic data handling on the device itself. android.graphics deals more in things that are drawn to the screen, and will be needed to handle the images. I think it’s clear that android.util is used here to handle logging, but it also covers a lot of time, string, and numerical utilities. Finally, android.widget handles the UI elements such as our ImageView and TextView.

import org.pytorch.IValue;
import org.pytorch.LiteModuleLoader;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.helloworld.ImageNetClasses;
import org.pytorch.torchvision.TensorImageUtils;
import org.pytorch.MemoryFormat;

org.pytorch is the package containing the deep learning components of the app. Notice that the fifth import here is the ImageNetClasses file copied from the GitHub repo.

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

Finally, we add the java package. java.io just handles the input and output. Having to import basic file-handling classes one by one might seem strange to someone who is only familiar with interpreted languages like Python.

MainActivity

public class MainActivity extends AppCompatActivity {
...
}

This is wrapped around the entirety of our code here. Notice that MainActivity is a class that extends AppCompatActivity. The class declares MainActivity public, meaning that it is visible to all other classes. It contains only two methods. The first is onCreate(), which initializes the activity. The second is assetFilePath(), which is a helper method that lets us open our pretrained model.

onCreate()

First, we’ll have a look at onCreate() in a number of sections:

@Override
protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);
    setContentView(R.layout.activity_main);

    Bitmap bitmap = null;
    Module module = null;

@Override here indicates that we are overriding the method from the parent class, in this case AppCompatActivity.

Bundle savedInstanceState holds the activity’s previously saved state, typically dynamic and non-persistent data. This plays an important part in the application life cycle, for instance allowing us to navigate away from the app and come back to it with its data intact.
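
The demo never actually writes anything into the Bundle, but as a sketch of how it fits into the life cycle, an activity can save transient state before it is destroyed (say, on rotation) and read it back the next time onCreate() runs. The key and counter below are hypothetical, not part of HelloWorld.

import android.os.Bundle;

import androidx.appcompat.app.AppCompatActivity;

public class CounterActivity extends AppCompatActivity {
    private static final String KEY_COUNT = "count";
    private int count = 0;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        if (savedInstanceState != null) {
            // Restore whatever was stashed before the activity was destroyed.
            count = savedInstanceState.getInt(KEY_COUNT, 0);
        }
    }

    @Override
    protected void onSaveInstanceState(Bundle outState) {
        super.onSaveInstanceState(outState);
        // Stash transient state so it can be handed back to onCreate() later.
        outState.putInt(KEY_COUNT, count);
    }
}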

setContentView(R.layout.activity_main) points to the activity_main.xml layout that we went over earlier. R is a class that contains the definitions for all the resources used by the package, so we’ll be seeing it peppered throughout the rest of this writeup.

It is worth noting for anyone coming from Python that Java doesn’t do duck typing; every variable needs a type. Personal note: I’ve done enough with Python to know this is a part of the language I’m going to love. bitmap will ultimately be our image, and module will be the pretrained model that we downloaded from the repo. They are declared up here so that they remain accessible outside of the following try/catch statement:

try {
    bitmap = BitmapFactory.decodeStream(getAssets().open("poodle.jpeg"));
    module = LiteModuleLoader.load(assetFilePath(this, "model.pt"));
} catch (IOException e) {
    Log.e("PytorchHelloWorld", "Error reading assets", e);
    finish();
}

ImageView imageView = findViewById(R.id.image);
imageView.setImageBitmap(bitmap);

This statement will fail if either of those files is missing. getAssets() opens the file from the context of the environment, and that stream is decoded by BitmapFactory. LiteModuleLoader has to call the helper method to open the pretrained model, but we’ll take a look at that in just a moment. I made a point of looking at the docs, and there isn’t actually a plain ModuleLoader counterpart that the Lite in the name would imply.

Successful completion of the try/catch statement then allows the code to look up the ImageView and fill it with the bitmap we just opened.

final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
        TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB,
        MemoryFormat.CHANNELS_LAST);

Here is where the bitmap gets converted into a tensor that can be used by PyTorch. The keyword final here indicates that inputTensor is a constant. TensorImageUtils turns the bitmap into a tensor of 32-bit floating point numbers. The two arguments in the middle are hardcoded arrays of floating point values specifying the mean and standard deviation for the R, G, and B components of a pixel; each channel is shifted by its mean and divided by its standard deviation. Normalizing inputs this way is a machine learning convention that keeps values in a consistent range and avoids extremes, and the preprocessing applied here has to match what the model saw during training.
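
To make the normalization concrete, here’s a plain-Java sketch of the arithmetic applied to each channel. The mean and standard deviation values below are the standard ImageNet statistics, which, as far as I can tell, are what TORCHVISION_NORM_MEAN_RGB and TORCHVISION_NORM_STD_RGB hold.

public class NormalizationDemo {
    public static void main(String[] args) {
        float[] mean = {0.485f, 0.456f, 0.406f};  // per-channel means (R, G, B)
        float[] std  = {0.229f, 0.224f, 0.225f};  // per-channel standard deviations

        // A fully red pixel, with channel values already scaled to [0, 1].
        float[] pixel = {1.0f, 0.0f, 0.0f};

        for (int c = 0; c < 3; c++) {
            // Each channel is shifted by its mean and divided by its deviation.
            float normalized = (pixel[c] - mean[c]) / std[c];
            System.out.printf("channel %d -> %.3f%n", c, normalized);
        }
    }
}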

Finally, MemoryFormat.CHANNELS_LAST stores the tensor with the channel values for each pixel next to each other in memory, in contrast to a contiguous (channels-first) tensor. A contiguous tensor might have a shape of (3, 224, 224), the 3 indicating one channel for each of the R, G, and B values, and the 224s indicating height and width; channels last keeps the same values but orders them pixel by pixel rather than channel by channel, which matches how the bitmap itself stores its pixels.
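
If it helps, here’s a plain-Java sketch (not part of the app) of how the same R, G, or B value lands at a different flat offset under the two layouts.

public class LayoutDemo {
    // Contiguous / channels-first: all of channel 0, then all of channel 1, ...
    static int contiguousIndex(int c, int y, int x, int height, int width) {
        return c * height * width + y * width + x;
    }

    // Channels-last: the channel values of each pixel sit side by side.
    static int channelsLastIndex(int c, int y, int x, int width, int channels) {
        return (y * width + x) * channels + c;
    }

    public static void main(String[] args) {
        int channels = 3, height = 224, width = 224;
        // Offset of the green value (c = 1) for the pixel at row 10, column 20:
        System.out.println(contiguousIndex(1, 10, 20, height, width));
        System.out.println(channelsLastIndex(1, 10, 20, width, channels));
    }
}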

final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();

There are some moving parts here, but this is where the model tries to predict the image’s class. inputTensor is passed to IValue, short for Interpreter Value. IValue is a Java representation of a TorchScript value; it can hold any of a number of supported types, as determined by the from() factory methods. As we can see here, this one holds a Tensor. module.forward() runs the model and returns another IValue, which toTensor() converts back into a Tensor.

final float[] scores = outputTensor.getDataAsFloatArray();

float maxScore = -Float.MAX_VALUE;
int maxScoreIdx = -1;
for (int i = 0; i < scores.length; i++) {
    if (scores[i] > maxScore) {
        maxScore = scores[i];
        maxScoreIdx = i;
    }
}

scores is an array with one score for each of the possible classes. The class with the highest score is the model’s prediction for the image. To get that prediction, we loop over the array and set maxScoreIdx to the index where the largest number is found.

String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];

TextView textView = findViewById(R.id.text);
textView.setText(className);

className is just the human-readable name of the predicted class, looked up in the ImageNetClasses file. The TextView is found by its id via the R class, and the name of the prediction is assigned to its text.

assetFilePath()

This is a helper method that is used by LiteModuleLoader. Its purpose is to make sure the pretrained model exists as a file on the device and return its path, or throw an exception if something goes wrong in the process.

The first argument here is a Context, the global information about the environment provided by Android. As used by onCreate(), the Context in question is just this, a reference to the current object. For our purposes, assetName is simply the filename of the pretrained model, model.pt.

Unlike the onCreate() method, this uses the throws keyword. This indicates that the method might throw an IOException if something goes wrong in our file handling process.

public static String assetFilePath(Context context, String assetName) throws IOException {

    ...
}

Let’s move on to the implementation of this method.

    File file = new File(context.getFilesDir(), assetName);
    if (file.exists() && file.length() > 0) {
        return file.getAbsolutePath();
    }

It’s worth outlining the order of operations here. The File object here represents the pathname that we will try to open. context.getFilesDir() gets the absolute path to the app’s internal files directory, and, as mentioned before, assetName is the filename of the model we are opening, model.pt.

The if block here just checks that the file exists and has nonzero length; if so, the method returns the path, which is then opened by LiteModuleLoader in onCreate(). Simple enough, but what happens if either of these conditions is not met?

        try (InputStream is = context.getAssets().open(assetName)) {
            try (OutputStream os = new FileOutputStream(file)) {
                byte[] buffer = new byte[4 * 1024];
                int read;
                while ((read = is.read(buffer)) != -1) {
                    os.write(buffer, 0, read);
                }
                os.flush();
            }
            return file.getAbsolutePath();
        }

These are nested try-with-resources statements. They create their respective InputStream and OutputStream and ensure that both are closed when the statements complete.

What’s going on here? This creates the model file in the event that one does not already exist or the current one has zero length. Assets are packaged inside the APK, so the model has to be copied out to the app’s files directory before it can be opened by path. An InputStream is opened on the model.pt asset, meaning we are ready to read its bytes through is. Then a FileOutputStream is created so that we can write data to the destination file. A 4KB byte buffer is created (its values default to 0), and the while loop repeatedly reads up to 4KB from is into the buffer and writes that many bytes to os until the end of the stream is reached, at which point the output is flushed and the path to the file is returned.

New Beginnings

I now have a framework for experimenting with PyTorch in a mobile environment. Immediate next steps include actually letting a user select images, using an intent to take new images with the phone’s camera, and doing real-time classification with text displayed in the ImageView itself. Early days!