
Image Classification

In this tutorial, you will build an app that can take pictures and classify objects in each image using an on-device image classification model.

Viewing this Demo

To view this demo, download the PlayTorch app.

Preview

If you want a sneak peek at what you'll be building, run this Snack by scanning the QR code in the PlayTorch app!

Overview

We'll go through the following steps:

  1. Create a new project with Snack by Expo
  2. Run the project in the PlayTorch app
  3. Add PlayTorch dependencies
  4. Add a camera view
  5. Process an image
  6. Run a model
  7. Display results

Starting a New Project

We will be using a tool called Snack by Expo to write our code in the browser and then run it on our device. To learn more about Snack, visit this link.

Open a new tab in your web browser and navigate to snack.expo.dev.

You will see a code editor with the App.js file open. On the right side of the window, you will see several options for running your code. It defaults to "Web", but let's select "My Device" so we can use the PlayTorch app to enable ML in our app.

Run the New Project

Open the PlayTorch app and from the home screen, tap "Scan QR Code".

If you have never done this before, it will ask for camera permissions. Grant the app camera permissions and scan the QR code from the right side of the Snack window.

If you haven't made any changes to the Snack, you should see the default screen of the Snack starter app.

Try changing the backgroundColor to #800080 on line 29 and you will see that your app screen changes in real time to match it.

Add PlayTorch Dependencies

In order to add ML to this simple demo, we first need to add the PlayTorch dependencies.

On the left side of the Snack window, you will see a list of the files used in your Snack. Open package.json and replace its contents with the following:

{
  "dependencies": {
    "react-native-pytorch-core": "0.2.0",
    "react-native-safe-area-context": "3.3.2"
  }
}

This is a list of external libraries that we will be using to build our ML powered demo.

Add a Camera View

Now that we have the extra dependencies loaded, we can use them to prepare our user interface for collecting images to classify.

Go ahead and replace the contents of App.js with the following. Let's walk through what changes are included:

  1. Import dependencies. Note that we import the Camera component from the react-native-pytorch-core package, which is the core PlayTorch SDK.
  2. Update the App function to render our new UI
    1. Get the "safe area insets", which tell us how much of the screen we can render to without being covered by camera notches and bottom bars.
    2. Make the Camera view fill the whole screen except for the unsafe area at the bottom, so the capture button doesn't get obscured
    3. Create a label container that floats near the top, which will show the label once we begin classifying images
  3. Create a styles object that sets the styles for our label container
App.js
// 1. Import dependencies
import * as React from 'react';
import {StyleSheet, Text, View} from 'react-native';
import {Camera} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';

// 2. App function to render a camera and a label
export default function App() {
  // 2.i. Safe area insets to compensate for notches and bottom bars
  const insets = useSafeAreaInsets();
  return (
    <View style={StyleSheet.absoluteFill}>
      {/* 2.ii. Render the camera and make it fill its parent */}
      <Camera style={[StyleSheet.absoluteFill, {bottom: insets.bottom}]} />
      {/* 2.iii. Label container with custom render style and a text label */}
      <View style={styles.labelContainer}>
        <Text>Label will go here</Text>
      </View>
    </View>
  );
}

// 3. Custom render style for label container
const styles = StyleSheet.create({
  labelContainer: {
    padding: 20,
    margin: 20,
    marginTop: 40,
    borderRadius: 10,
    backgroundColor: 'white',
  },
});

Once you make these changes, open the Snack back up in the PlayTorch app and you will see the camera view filling the screen and our label container with a placeholder label.

Notice that tapping the capture button doesn't do anything yet. Let's fix that.

The added lines below do the following:

  1. Create an async function called handleImage (an async function can await asynchronous work without blocking the UI) that simply:
    1. Logs the image object passed to it
    2. Releases the image from memory. If image.release() is not called, the camera will not provide an image on subsequent presses of the capture button, and the only way to recover is to force close the PlayTorch app and reopen it.
  2. Set the handleImage function to be called every time the Camera component captures an image
App.js
// Import dependencies
import * as React from 'react';
import {StyleSheet, Text, View} from 'react-native';
import {Camera} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';

// App function to render a camera and a label
export default function App() {
  // Safe area insets to compensate for notches and bottom bars
  const insets = useSafeAreaInsets();

  // 1. Function to handle images whenever the user presses the capture button
  async function handleImage(image) {
    // 1.i. Log the image object to the console
    console.log(image);
    // 1.ii. Release the image from memory
    image.release();
  }

  return (
    <View style={StyleSheet.absoluteFill}>
      {/* Render the camera and make it fill its parent */}
      <Camera
        style={[StyleSheet.absoluteFill, {bottom: insets.bottom}]}
        // 2. Add handle image callback on the camera component
        onCapture={handleImage}
      />
      {/* Label container with custom render style and a text label */}
      <View style={styles.labelContainer}>
        <Text>Label will go here</Text>
      </View>
    </View>
  );
}

// Custom render style for label container
const styles = StyleSheet.create({
  labelContainer: {
    padding: 20,
    margin: 20,
    marginTop: 40,
    borderRadius: 10,
    backgroundColor: 'white',
  },
});
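Releasing the image matters most when something goes wrong. If the code between capture and release throws, for example after we add the model later in this tutorial, image.release() is never reached. A common JavaScript pattern for "always clean up" is try/finally. The sketch below assumes a hypothetical withRelease helper, which is not part of the PlayTorch SDK, and uses a fake image object so it runs outside the app.

```javascript
// Hypothetical helper (not part of the PlayTorch SDK): runs `work` with an
// image and always calls image.release(), even if `work` throws.
async function withRelease(image, work) {
  try {
    return await work(image);
  } finally {
    // Runs on success and on failure, so the camera keeps delivering images.
    image.release();
  }
}

// A stand-in image object, used only to demonstrate the control flow.
function makeFakeImage() {
  return {
    released: false,
    release() {
      this.released = true;
    },
  };
}

async function demoRelease() {
  // Successful work: the result is returned and the image is released.
  const ok = makeFakeImage();
  const result = await withRelease(ok, async () => 'a label');

  // Failing work: the error propagates, but the image is still released.
  const failing = makeFakeImage();
  let caught = null;
  try {
    await withRelease(failing, async () => {
      throw new Error('model failed');
    });
  } catch (e) {
    caught = e.message;
  }

  return {
    result,
    okReleased: ok.released,
    failingReleased: failing.released,
    caught,
  };
}
```

Once classifyImage exists, handleImage could call `const result = await withRelease(image, classifyImage);`. The image is then released whether classification succeeds or fails.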

Open the logs in the Snack window by clicking the settings gear icon at the bottom of the window, enabling the Panel, and clicking the logs tab of the newly opened panel.

After taking a picture, you should see a logged object with an ID field.

Now that we can capture images, let's write some code to prepare them for machine learning!

Process an Image

In order for us to run machine learning on our image to classify it, we first need to translate it to a format that the ML model understands.

ML models don't work with images directly. They work with tensors (multidimensional arrays) that have their own specific data format.

The MobileNet model we will be using expects its input tensor to hold a 224 by 224 RGB image with normalized values, arranged in a specific shape.
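To make the expected format concrete, here is a small plain-JavaScript sketch with no PlayTorch calls. It tracks how the tensor shape changes through the preprocessing steps this tutorial walks through, and shows where the center crop lands. The helper names preprocessShapes and centerCropBox are hypothetical. The crop offsets use Math.floor, so they may differ by a pixel from the library's own rounding. Dividing by 255 and normalizing do not change the shape, so they are omitted.

```javascript
// Hypothetical helper that mirrors how the tensor shape changes in the
// preprocessing pipeline (division and normalization keep the shape).
function preprocessShapes(width, height, size = 224) {
  const side = Math.min(width, height); // center crop keeps the shorter side
  return [
    ['fromBlob (HWC)', [height, width, 3]],
    ['permute -> CHW', [3, height, width]],
    ['centerCrop', [3, side, side]],
    ['resize', [3, size, size]],
    ['unsqueeze(0)', [1, 3, size, size]],
  ];
}

// Hypothetical helper: offsets of the centered square crop in a
// width x height image (rounding may differ slightly from the library).
function centerCropBox(width, height) {
  const side = Math.min(width, height);
  return {
    left: Math.floor((width - side) / 2),
    top: Math.floor((height - side) / 2),
    side,
  };
}

// Print the shape after each step for a 1920 x 1080 (landscape) photo.
for (const [step, shape] of preprocessShapes(1920, 1080)) {
  console.log(step.padEnd(16), JSON.stringify(shape));
}
```

For a 1920 by 1080 photo, the crop keeps a 1080 by 1080 square starting 420 pixels from the left edge. The resize then brings it down to 224 by 224.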

Create a new file by clicking the new file button in the left pane of the Snack window, and name it ImageClassifier.js.

Let's walk through the code below to see how we get our image converted to a proper tensor:

  1. Import torch, torchvision, and media from the react-native-pytorch-core package (the PlayTorch SDK)
  2. Create an alias called T for the transforms object from torchvision to make the transform functions shorter to access
  3. Create an async function called classifyImage that takes in an image and does the following:
    1. Grab the width and the height of the image
    2. Create a blob of the image (a blob is just a raw data representation of something). In this case, the blob holds a byte representation of the image in the format height, width, and channels, or HWC for short.
    3. Create a tensor from the blob with the shape height by width by channels (RGB). It is important that the order of HWC is aligned with the byte representation of the image.
    4. Rearrange the tensor shape to be channels (RGB) by height by width. This is important because the image classification model used in this tutorial requires the tensor dimensions to be in this order.
    5. Divide all of the values in the tensor by 255. This is important because the image classification model requires the tensor values to be in the range [0, 1].
    6. Center crop the image data within the tensor. The center crop results in a square image tensor whose side length is the image's shorter side.
    7. Resize the tensor to 3 by 224 by 224 (or tensor shape [3, 224, 224]) to match the input size the model expects.
    8. Normalize each color channel with the ImageNet mean and standard deviation values shown in the code.
    9. Add one more leading dimension to the tensor to be in the shape 1 by 3 by 224 by 224 (or tensor shape [1, 3, 224, 224]). The image classification model can classify multiple images in parallel. The leading 1 represents the batch size, which is 1 because in this tutorial it only needs to process one image at a time.
    10. Return the shape of the tensor, which is [1, 3, 224, 224]. For now, we return only the shape so we can check that the preprocessing works.
ImageClassifier.js
// 1. Import torch, torchvision, and media from PlayTorch SDK
import {torch, torchvision, media} from 'react-native-pytorch-core';

// 2. Alias for torchvision transforms
const T = torchvision.transforms;

// 3. The classifyImage function that will eventually process an image and
// return the top class label (for now it returns the tensor shape)
export default async function classifyImage(image) {
  // 3.i. Get image width and height
  const width = image.getWidth();
  const height = image.getHeight();

  // 3.ii. Convert image to blob, which is a byte representation of the image
  // in the format height (H), width (W), and channels (C), or HWC for short
  const blob = media.toBlob(image);

  // 3.iii. Get a tensor from the image blob and define the format (HWC)
  // of the image blob
  let tensor = torch.fromBlob(blob, [height, width, 3]);

  // 3.iv. Rearrange the tensor shape to be [CHW]
  tensor = tensor.permute([2, 0, 1]);

  // 3.v. Divide the tensor values by 255 to get values between [0, 1]
  tensor = tensor.div(255);

  // 3.vi. Crop the image in the center to be a square image
  const centerCrop = T.centerCrop(Math.min(width, height));
  tensor = centerCrop(tensor);

  // 3.vii. Resize the image tensor to 3 x 224 x 224
  const resize = T.resize(224);
  tensor = resize(tensor);

  // 3.viii. Normalize the tensor image with mean and standard deviation
  const normalize = T.normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]);
  tensor = normalize(tensor);

  // 3.ix. Unsqueeze adds 1 leading dimension to the tensor
  tensor = tensor.unsqueeze(0);

  // 3.x. Return the tensor shape [1, 3, 224, 224]
  return tensor.shape;
}
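Steps 3.iv, 3.v, and 3.viii are simple per-byte arithmetic, so they can be illustrated without PlayTorch. The sketch below is plain JavaScript over a flat Uint8Array in HWC order, which is how this tutorial describes the blob. The function name hwcToNormalizedChw is hypothetical, and the sketch skips the crop and resize steps for brevity. It shows why the permute matters. In HWC, the three channel values of a pixel sit next to each other. In CHW, each channel is stored as its own contiguous plane.

```javascript
// Plain-JavaScript illustration (no PlayTorch calls) of the HWC -> CHW
// permute, scaling to [0, 1], and per-channel normalization.
const MEAN = [0.485, 0.456, 0.406];
const STD = [0.229, 0.224, 0.225];

function hwcToNormalizedChw(bytes, height, width) {
  const channels = 3;
  const out = new Float32Array(channels * height * width);
  for (let h = 0; h < height; h++) {
    for (let w = 0; w < width; w++) {
      for (let c = 0; c < channels; c++) {
        // HWC: the channels of one pixel are stored next to each other.
        const src = (h * width + w) * channels + c;
        // CHW: each channel is a contiguous height x width plane.
        const dst = c * height * width + h * width + w;
        const scaled = bytes[src] / 255; // values now in [0, 1]
        out[dst] = (scaled - MEAN[c]) / STD[c]; // per-channel normalization
      }
    }
  }
  return out;
}

// A 1 x 2 image (height 1, width 2): one pure red pixel, then one pure blue
// pixel, in RGB order.
const pixels = Uint8Array.from([255, 0, 0, 0, 0, 255]);
const chw = hwcToNormalizedChw(pixels, 1, 2);
console.log(Array.from(chw));
```

In the output, the first two values are the red plane, the next two are green, and the last two are blue. The red pixel's red channel becomes (1 - 0.485) / 0.229, about 2.25, while channels with a value of 0 become negative.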

Let's double check the output of this function to make sure we are on the right track.

Go back to App.js and instead of just logging the image object, let's run the classifyImage function on the image object first and log the result instead.

  1. Import the classifyImage function from the ImageClassifier.js file.
  2. Call classifyImage function with the image from the camera.
  3. Log the result to the console.
App.js
// Import dependencies
import * as React from 'react';
import {StyleSheet, Text, View} from 'react-native';
import {Camera} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';
// 1. Import classify image function
import classifyImage from './ImageClassifier';

// App function to render a camera and a label
export default function App() {
  // Safe area insets to compensate for notches and bottom bars
  const insets = useSafeAreaInsets();

  // Function to handle images whenever the user presses the capture button
  async function handleImage(image) {
    // 2. Call the classify image function with the camera image
    const result = await classifyImage(image);
    // 3. Log the result from classify image to the console
    console.log(result);
    // Release the image from memory
    image.release();
  }

  return (
    <View style={StyleSheet.absoluteFill}>
      {/* Render the camera and make it fill its parent */}
      <Camera
        style={[StyleSheet.absoluteFill, {bottom: insets.bottom}]}
        // Add handle image callback on the camera component
        onCapture={handleImage}
      />
      {/* Label container with custom render style and a text label */}
      <View style={styles.labelContainer}>
        <Text>Label will go here</Text>
      </View>
    </View>
  );
}

// Custom render style for label container
const styles = StyleSheet.create({
  labelContainer: {
    padding: 20,
    margin: 20,
    marginTop: 40,
    borderRadius: 10,
    backgroundColor: 'white',
  },
});

When you check the log output after capturing an image now, you should see [1,3,224,224], which is the tensor shape we need.

Now that the image has been converted to a properly formatted tensor, we are ready to run the machine learning model!

Run the Model

Let's head back to our ImageClassifier.js file and make some updates to the classifyImage function to actually classify the image.

For the changes we make to ImageClassifier.js, you'll need to upload a file containing the labels for the different things the model knows how to classify.

Click here to download the JSON file, then drag and drop it into the Snack window to upload it. Make sure the uploaded file is named ImageNetClasses.json, since the code below imports it by that name.

Here's a quick summary of the changes we are making to run the model:

  1. Import the MobileModel to help us load our machine learning model
  2. Import the class labels from the ImageNetClasses.json file. The JSON file maps class indices to class labels.
  3. Store the URL for the model we'll be using in a variable for later access
  4. Create a variable for storing our model and set it to null initially
  5. After we have the tensor ready, check whether our model is still null. If it is, initialize it by downloading it and loading it into memory.
  6. Run the model on the image tensor by calling model.forward(tensor). The return value is a tensor of shape [1, 1000], where 1 is the batch size (remember that this tutorial processes only one image at a time) and 1000 is the number of scores, one for each class in ImageNetClasses.json.
  7. Find the index with the highest score, which represents the most likely class detected in the image.
  8. Resolve the most likely image class by mapping the index to the class label and return it.
ImageClassifier.js
// 1. Add import for MobileModel from PlayTorch SDK
import {
  MobileModel,
  torch,
  torchvision,
  media,
} from 'react-native-pytorch-core';
// 2. Import the ImageNetClasses JSON file, which is used below to map the
// processed model result to a class label
import * as ImageNetClasses from './ImageNetClasses.json';

// Alias for torchvision transforms
const T = torchvision.transforms;

// 3. URL to the image classification model that is used in this example
const MODEL_URL =
  'https://github.com/facebookresearch/playtorch/releases/download/v0.1.0/mobilenet_v3_small.ptl';

// 4. Variable to hold a reference to the loaded ML model
let model = null;

// The classifyImage function that will process an image and return the top
// class label
export default async function classifyImage(image) {
  // Get image width and height
  const width = image.getWidth();
  const height = image.getHeight();

  // Convert image to blob, which is a byte representation of the image
  // in the format height (H), width (W), and channels (C), or HWC for short
  const blob = media.toBlob(image);

  // Get a tensor from the image blob and define the format (HWC)
  // of the image blob
  let tensor = torch.fromBlob(blob, [height, width, 3]);

  // Rearrange the tensor shape to be [CHW]
  tensor = tensor.permute([2, 0, 1]);

  // Divide the tensor values by 255 to get values between [0, 1]
  tensor = tensor.div(255);

  // Crop the image in the center to be a square image
  const centerCrop = T.centerCrop(Math.min(width, height));
  tensor = centerCrop(tensor);

  // Resize the image tensor to 3 x 224 x 224
  const resize = T.resize(224);
  tensor = resize(tensor);

  // Normalize the tensor image with mean and standard deviation
  const normalize = T.normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]);
  tensor = normalize(tensor);

  // Unsqueeze adds 1 leading dimension to the tensor
  tensor = tensor.unsqueeze(0);

  // 5. If the model has not been loaded already, it will be downloaded from
  // the URL and then loaded into memory.
  if (model == null) {
    const filePath = await MobileModel.download(MODEL_URL);
    model = await torch.jit._loadForMobile(filePath);
  }

  // 6. Run the ML inference with the pre-processed image tensor
  const output = await model.forward(tensor);

  // 7. Get the index of the value with the highest score
  const maxIdx = output.argmax().item();

  // 8. Resolve the most likely class label and return it
  return ImageNetClasses[maxIdx];
}
Note

The first call to classifyImage is slower because it has to download the model and load it into memory. Subsequent calls are much faster because the model is already loaded.

If you do not wish to upload your model to a publicly accessible server, you can instead place the model file in a directory of your choice in your project and replace the line const filePath = await MobileModel.download(MODEL_URL); with const filePath = await MobileModel.download(require('./path/to/model.ptl'));.
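As written, classifyImage checks `model == null` and then awaits the download. Two captures in quick succession could therefore both start downloading before the first one finishes. One way to avoid this, assuming you want the model loaded exactly once, is to cache the promise instead of the model. In the sketch below, the once helper and the fake loader are hypothetical. In ImageClassifier.js, the loader would wrap the MobileModel.download and torch.jit._loadForMobile calls shown above.

```javascript
// Hypothetical helper: wraps an async loader so it runs at most once, even if
// several callers ask for it before the first load has finished.
function once(loader) {
  let promise = null;
  return function get() {
    if (promise == null) {
      promise = loader().catch(err => {
        promise = null; // allow a retry after a failed load
        throw err;
      });
    }
    return promise;
  };
}

// Demo with a fake loader standing in for downloading and loading the model.
let loadCount = 0;
const getModel = once(async () => {
  loadCount += 1;
  await new Promise(resolve => setTimeout(resolve, 10)); // simulate a slow download
  return {name: 'fake-model'};
});

async function demoLazyLoad() {
  // Two "captures" arrive back to back and share the same in-flight load.
  const [a, b] = await Promise.all([getModel(), getModel()]);
  return {sameModel: a === b, loadCount};
}
```

In ImageClassifier.js, this could take the form `const getModel = once(async () => torch.jit._loadForMobile(await MobileModel.download(MODEL_URL)));` at module level, with `const model = await getModel();` inside classifyImage in place of the null check.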

Now that we are actually running the model, let's try it out in the PlayTorch app again and see what it logs. You should see a class label in the logs, which is a word or a short list of words.

Excellent! It's logging the classification of each picture!
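The argmax and label lookup in steps 7 and 8 of the model section can also be illustrated in plain JavaScript. In the sketch below, argmax and softmax are hypothetical helpers that operate on an ordinary array rather than a PlayTorch tensor. The three scores and labels are made-up stand-ins for the model's 1000 outputs and ImageNetClasses.json. The optional softmax turns raw scores into probabilities that sum to 1, which could be useful for displaying a confidence value. It does not change which class wins.

```javascript
// Index of the largest value (the first one wins on ties).
function argmax(values) {
  let best = 0;
  for (let i = 1; i < values.length; i++) {
    if (values[i] > values[best]) best = i;
  }
  return best;
}

// Optional: convert raw scores into probabilities that sum to 1.
// Subtracting the max first keeps Math.exp from overflowing.
function softmax(values) {
  const max = Math.max(...values);
  const exps = values.map(v => Math.exp(v - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map(e => e / sum);
}

// Hypothetical three-entry stand-in for ImageNetClasses.json, keyed by index.
const labels = {0: 'tench', 1: 'goldfish', 2: 'great white shark'};
const scores = [1.2, 4.7, 0.3]; // made-up model scores

const topIdx = argmax(scores);
const topLabel = labels[topIdx];
const topProb = softmax(scores)[topIdx];
console.log(topLabel, topProb.toFixed(3));
```

Here index 1 has the highest score, so the lookup returns 'goldfish'. Softmax assigns that class roughly 96% of the probability mass.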

Display the Result

Lastly, let's update our UI to display the result of our model!

Go back to App.js and make the following changes:

  1. Create a state variable to store the topClass we get from the model
  2. In the handleImage function, set the topClass state variable to the result of the classifyImage function
  3. Change the text in the UI to display the topClass state variable instead of the placeholder text
App.js
// Import dependencies
import * as React from 'react';
import {StyleSheet, Text, View} from 'react-native';
import {Camera} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';
// Import classify image function
import classifyImage from './ImageClassifier';

// App function to render a camera and a label
export default function App() {
  // Safe area insets to compensate for notches and bottom bars
  const insets = useSafeAreaInsets();
  // 1. Create a React state to store the top class returned from the
  // classifyImage function
  const [topClass, setTopClass] = React.useState(
    "Press capture button to classify what's in the camera view!",
  );

  // Function to handle images whenever the user presses the capture button
  async function handleImage(image) {
    // Call the classify image function with the camera image
    const result = await classifyImage(image);
    // 2. Set result as top class label state
    setTopClass(result);
    // Release the image from memory
    image.release();
  }

  return (
    <View style={StyleSheet.absoluteFill}>
      {/* Render the camera and make it fill its parent */}
      <Camera
        style={[StyleSheet.absoluteFill, {bottom: insets.bottom}]}
        // Add handle image callback on the camera component
        onCapture={handleImage}
      />
      {/* Label container with custom render style and a text label */}
      <View style={styles.labelContainer}>
        {/* 3. Change the text to render the top class label */}
        <Text>{topClass}</Text>
      </View>
    </View>
  );
}

// Custom render style for label container
const styles = StyleSheet.create({
  labelContainer: {
    padding: 20,
    margin: 20,
    marginTop: 40,
    borderRadius: 10,
    backgroundColor: 'white',
  },
});

And with that, you have a working image classifier!