YOLO Browser

ONNX


ONNX (Open Neural Network Exchange) is an open standard that plays a pivotal role in Machine Learning Operations (MLOps) by providing a common format for representing and sharing deep learning models across different frameworks, libraries, and platforms. It acts as a bridge between deep learning frameworks, making it easier to build, train, and deploy models seamlessly. ONNX is designed to enhance interoperability, speed up development, and promote innovation in AI and machine learning.


Key Features of ONNX:

  1. Interoperability: ONNX allows models to be easily transferred between popular deep learning frameworks like TensorFlow, PyTorch, and more. This interoperability is crucial for MLOps as it enables teams to choose the best tools for specific tasks.
  2. Optimized Inference: ONNX Runtime, a runtime engine that supports the execution of ONNX models, is highly optimized for inference tasks. It ensures that models can be executed efficiently in production environments.
  3. Hardware Acceleration: ONNX can take advantage of hardware accelerators like GPUs and TPUs, making it suitable for deploying models in high-performance settings.
  4. Cross-Platform Compatibility: ONNX models can be deployed on various platforms, including edge devices, cloud servers, and even web browsers, making it versatile for different use cases.

ONNX in the Browser with WebGL

Now, let's focus on the specific topic for this week's session: Model Runtime on the Browser with WebGL using ONNX. This is an exciting development in the world of AI and MLOps because it allows machine learning models to be executed directly within web browsers, offering several advantages:

1. Low Latency Inference: ONNX models can be run in the browser, reducing the need for constant communication with remote servers. This results in lower latency and faster model execution, crucial for real-time applications like gaming, interactive websites, and more.

2. Privacy and Data Security: By running models locally in the browser, sensitive data can be kept on the client-side, enhancing privacy and data security. This is particularly important for applications involving personal or confidential information.

3. Offline Availability: ONNX models deployed in the browser remain accessible even without an internet connection. This is beneficial for applications that need to work in offline or intermittent connectivity scenarios.

4. Cross-Platform Compatibility: WebGL, a JavaScript API for rendering interactive 2D and 3D graphics, allows for the execution of ONNX models on a wide range of devices and platforms, including desktops, mobile phones, and VR headsets.

5. Web-Based AI: The combination of ONNX and WebGL enables the development of web-based AI applications, such as interactive demos, educational tools, and games that incorporate machine learning capabilities.

PyTorch to ONNX

pip install onnx onnxruntime timm  # timm provides the pretrained model used below

import torch
import torch.onnx
import timm

# load a pretrained ResNetV2-50, switch to eval mode, and script it
model = timm.create_model('resnetv2_50', pretrained=True)
model = model.eval()
model_script = torch.jit.script(model)

# export with named inputs/outputs and a dynamic batch dimension
torch.onnx.export(model_script, torch.randn(1, 3, 224, 224), "resnetv2_50.onnx",
                  verbose=True, input_names=['input'], output_names=['output'],
                  dynamic_axes={'input': {0: 'batch'}})

Visualize ONNX Graph

https://netron.app/

https://onnxruntime.ai/docs/get-started/with-python.html
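If you'd rather stay in Python, the onnx package can also validate a model and print a textual summary of its graph (a quick sketch using the standard onnx checker and helper APIs):

import onnx

# structural validity check plus a human-readable dump of the graph
onnx_model = onnx.load("resnetv2_50.onnx")
onnx.checker.check_model(onnx_model)
print(onnx.helper.printable_graph(onnx_model.graph))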

Tracing vs Scripting

Internally, [torch.onnx.export()](https://pytorch.org/docs/stable/onnx.html#torch.onnx.export) requires a [torch.jit.ScriptModule](https://pytorch.org/docs/stable/generated/torch.jit.ScriptModule.html#torch.jit.ScriptModule) rather than a [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module). If the passed-in model is not already a ScriptModule, export() will use tracing to convert it to one:

  • Tracing: If torch.onnx.export() is called with a Module that is not already a ScriptModule, it first does the equivalent of [torch.jit.trace()](https://pytorch.org/docs/stable/generated/torch.jit.trace.html#torch.jit.trace), which executes the model once with the given args and records all operations that happen during that execution. This means that if your model is dynamic, e.g., changes behavior depending on input data, the exported model will not capture this dynamic behavior. We recommend examining the exported model and making sure the operators look reasonable. Tracing will unroll loops and if statements, exporting a static graph that is exactly the same as the traced run. If you want to export your model with dynamic control flow, you will need to use scripting.
  • Scripting: Compiling a model via scripting preserves dynamic control flow and is valid for inputs of different sizes. To use scripting: use [torch.jit.script()](https://pytorch.org/docs/stable/generated/torch.jit.script.html#torch.jit.script) to produce a ScriptModule, then call torch.onnx.export() with the ScriptModule as the model. The args are still required, but they will be used internally only to produce example outputs, so that the types and shapes of the outputs can be captured. No tracing will be performed. (See the sketch after this list.)
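A minimal sketch of the difference, using a hypothetical toy module with data-dependent control flow:

import torch

class ToyModel(torch.nn.Module):
    def forward(self, x):
        # data-dependent branch: tracing records only the branch taken
        # by the example input, while scripting preserves the if/else
        if x.sum() > 0:
            return x * 2
        return x - 1

model = ToyModel()
traced = torch.jit.trace(model, torch.ones(3))  # bakes in the x * 2 branch
scripted = torch.jit.script(model)              # keeps both branches

x = -torch.ones(3)
print(model(x))     # eager: x - 1
print(traced(x))    # still x * 2, because the branch was frozen at trace time
print(scripted(x))  # x - 1, matching eager mode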

ONNX Opset

https://github.com/onnx/onnx/blob/main/docs/Operators.md

An opset is a versioned set of operator definitions and rules: it specifies which operators are available in a given version of the ONNX specification and how they behave.
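For example, you can pin the opset at export time and inspect what a saved model declares (a sketch reusing model_script from the export step; opset_version is a standard torch.onnx.export argument):

import onnx
import torch

# pin the operator set used by the exported graph
torch.onnx.export(model_script, torch.randn(1, 3, 224, 224), "resnetv2_50_op17.onnx",
                  opset_version=17, input_names=['input'], output_names=['output'])

# every ONNX model declares the opset(s) it targets
print(onnx.load("resnetv2_50_op17.onnx").opset_import)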

Verifying ONNX

import onnxruntime as ort
import numpy as np

Test with Random Input

ort_session = ort.InferenceSession("resnetv2_50.onnx")
ort_session.run(['output'], {'input': np.random.randn(1, 3, 224, 224).astype(np.float32)})
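Beyond a smoke test, you can check that ONNX Runtime agrees with the original PyTorch model on the same input (a sketch assuming model from the export step is still in scope):

import torch

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    torch_out = model(x).numpy()
ort_out = ort_session.run(['output'], {'input': x.numpy()})[0]

# the two runtimes should agree up to small numerical differences
np.testing.assert_allclose(torch_out, ort_out, rtol=1e-3, atol=1e-5)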

Image Input

import onnxruntime as ort
import numpy as np
from PIL import Image

# ImageNet normalization constants
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# load the ImageNet class names, one per line
with open("imagenet_classes.txt", "r") as f:
    classes_response = f.read()
classes_list = [line.strip() for line in classes_response.split('\n')]

ort_session = ort.InferenceSession("resnetv2_50.onnx")

# sanity check with a random input
output = ort_session.run(
    ['output'], {'input': np.random.randn(1, 3, 224, 224).astype(np.float32)})

print(f"random_output = {output}")

# load the test image and resize it to the model's input size
img = Image.open("test_image.jpeg")
img = img.convert("RGB")
img = img.resize((224, 224))
img_np = np.array(img)

print(f"image shape = {img_np.shape}")

# scale to [0, 1], standardize with ImageNet stats, permute HWC -> CHW
img_np = img_np / 255.0
img_np = (img_np - mean) / std
img_np = img_np.transpose(2, 0, 1)

# add the batch dimension and run inference
ort_outputs = ort_session.run(
    ['output'], {'input': img_np[None, ...].astype(np.float32)})

pred_class_idx = np.argmax(ort_outputs[0])

predicted_class = classes_list[pred_class_idx]

print(f"{predicted_class=}")

ONNX on Browser

npx create-next-app@latest
cp resnetv2_50.onnx public/

page.tsx

"use client";
 
import { useEffect, useState } from "react";
import * as ort from "onnxruntime-web";
import ndarray from "ndarray";
import ops from "ndarray-ops";
import { softmax } from "@/utils/math/softmax-2";
import { imagenetClassesTopK } from "@/utils/imagenet";
 
export default function Home() {
  const [selectedImage, setSelectedImage] = useState<File | null>(null);
  const [resizedImage, setResizedImage] = useState("");
  const [inferenceSession, setInferenceSession] =
    useState<ort.InferenceSession | null>(null);
 
  const [modelOutput, setModelOutput] = useState<
    { id: string; index: number; name: string; probability: number }[]
  >([]);
 
  useEffect(() => {
    // load model
    ort.InferenceSession.create("/resnetv2_50.onnx", {
      executionProviders: ["webgl"],
      graphOptimizationLevel: "all",
    }).then((session) => setInferenceSession(session));
  }, []);
 
  useEffect(() => {
    const mycode = async () => {
      try {
        if (!inferenceSession) return;
 
        const image = document.createElement("img");
        image.onload = async () => {
          const canvas = document.createElement("canvas");
          canvas.width = 224;
          canvas.height = 224;
 
          if (!canvas) return;
 
          const canvas2DCtx = canvas.getContext("2d");
 
          if (!canvas2DCtx) return;
 
          canvas2DCtx.drawImage(image, 0, 0, 224, 224);
 
          const resizedImage = canvas.toDataURL();
 
          setResizedImage(resizedImage);
 
          const imageData = canvas2DCtx.getImageData(
            0,
            0,
            canvas2DCtx.canvas.width,
            canvas2DCtx.canvas.height
          );
          const { data, width, height } = imageData;
 
          // data processing
          const dataTensor = ndarray(new Float32Array(data), [
            width,
            height,
            4,
          ]);
 
          const dataProcessedTensor = ndarray(
            new Float32Array(width * height * 3),
            [1, 3, width, height]
          );
 
          // permute [H, W, C] -> [B, C, H, W]
          ops.assign(
            dataProcessedTensor.pick(0, 0, null, null),
            dataTensor.pick(null, null, 0)
          );
          ops.assign(
            dataProcessedTensor.pick(0, 1, null, null),
            dataTensor.pick(null, null, 1)
          );
          ops.assign(
            dataProcessedTensor.pick(0, 2, null, null),
            dataTensor.pick(null, null, 2)
          );
 
          // image normalization with mean and std
          ops.divseq(dataProcessedTensor, 255);
          ops.subseq(dataProcessedTensor.pick(0, 0, null, null), 0.485);
          ops.subseq(dataProcessedTensor.pick(0, 1, null, null), 0.456);
          ops.subseq(dataProcessedTensor.pick(0, 2, null, null), 0.406);
 
          ops.divseq(dataProcessedTensor.pick(0, 0, null, null), 0.229);
          ops.divseq(dataProcessedTensor.pick(0, 1, null, null), 0.224);
          ops.divseq(dataProcessedTensor.pick(0, 2, null, null), 0.225);
 
          const tensor = new ort.Tensor(
            "float32",
            new Float32Array(width * height * 3),
            [1, 3, width, height]
          );
          (tensor.data as Float32Array).set(dataProcessedTensor.data);
 
 
          const results = await inferenceSession.run({
            input: tensor,
          });
 
          if (results.output) {
            const res = results.output;
 
            const output = softmax(Array.prototype.slice.call(res.data));
            const topK = imagenetClassesTopK(output, 5);
 
            setModelOutput(topK);
          }
        };
 
        if (selectedImage) {
          image.setAttribute("src", URL.createObjectURL(selectedImage));
        }
      } catch (e: any) {
        console.error(e, e.toString());
      }
    };
 
    if (selectedImage) {
      mycode();
    }
  }, [inferenceSession, selectedImage]);
 
  return (
    <main className="flex min-h-screen flex-col items-center p-24 gap-y-12">
      {!inferenceSession && "Loading Model..."}
      {!!inferenceSession && (
        <input
          type="file"
          name="myImage"
          onChange={(event) => {
            if (event.target.files && event.target.files.length > 0) {
              const file = event.target.files[0];
              console.log(file);
              setSelectedImage(file);
            }
          }}
        />
      )}
      {resizedImage && (
        // eslint-disable-next-line @next/next/no-img-element
        <img src={resizedImage} alt="Resized Image" className="rounded-md" />
      )}
      {modelOutput.length > 0 && (
        <table className="table-auto max-w-2xl w-full">
          <thead>
            <tr>
              <th className="py-3.5 px-4 text-sm font-normal text-left rtl:text-right text-gray-500 dark:text-gray-400">
                Index
              </th>
              <th className="py-3.5 px-4 text-sm font-normal text-left rtl:text-right text-gray-500 dark:text-gray-400">
                Name
              </th>
              <th className="py-3.5 px-4 text-sm font-normal text-left rtl:text-right text-gray-500 dark:text-gray-400">
                Probability
              </th>
            </tr>
          </thead>
          <tbody>
            {modelOutput.map((m, i) => (
              <tr key={i}>
                <td className="px-4 py-4 text-sm text-gray-500 dark:text-gray-300 whitespace-nowrap">
                  {m.index}
                </td>
                <td className="px-4 py-4 text-sm text-gray-500 dark:text-gray-300 whitespace-nowrap">
                  {m.name}
                </td>
                <td className="px-4 py-4 text-sm text-gray-500 dark:text-gray-300 whitespace-nowrap">
                  {m.probability.toFixed(2)}
                </td>
              </tr>
            ))}
          </tbody>
        </table>
      )}
    </main>
  );
}

Let’s break it down to understand what is happening

Loading the ONNX Model

useEffect(() => {
    // load model
    ort.InferenceSession.create("/resnetv2_50.onnx", {
      executionProviders: ["webgl"],
      graphOptimizationLevel: "all",
    }).then((session) => setInferenceSession(session));
  }, []);

executionProviders

WebAssembly backend

ONNX Runtime Web currently supports all operators in ai.onnx and ai.onnx.ml.

WebGL backend

ONNX Runtime Web currently supports a subset of the ai.onnx operator set. See operators.md for a complete, detailed list of which ONNX operators are supported by the WebGL backend.

WebGPU

https://developer.chrome.com/blog/webgpu-io2023/

Resizing the Image to 224×224

const image = document.createElement("img");
image.onload = async () => {
  const canvas = document.createElement("canvas");
  canvas.width = 224;
  canvas.height = 224;
 
  if (!canvas) return;
 
  const canvas2DCtx = canvas.getContext("2d");
 
  if (!canvas2DCtx) return;
 
  canvas2DCtx.drawImage(image, 0, 0, 224, 224);
 
  const resizedImage = canvas.toDataURL();
 
  setResizedImage(resizedImage);
 
  const imageData = canvas2DCtx.getImageData(
    0,
    0,
    canvas2DCtx.canvas.width,
    canvas2DCtx.canvas.height
  );
  const { data, width, height } = imageData;

Creating an empty placeholder for Preprocessed Image

// data processing
const dataTensor = ndarray(new Float32Array(data), [
  width,
  height,
  4,
]);
 
const dataProcessedTensor = ndarray(
  new Float32Array(width * height * 3),
  [1, 3, width, height]
);

HWC to BCHW

Note that getImageData returns pixels row-major as [height, width, 4] (RGBA), so dataTensor's true layout is [height, width, 4]; because the canvas here is square (224×224), labeling it [width, height, 4] above works out to the same buffer.

// permute [H, W, C] -> [B, C, H, W]
ops.assign(
  dataProcessedTensor.pick(0, 0, null, null),
  dataTensor.pick(null, null, 0)
);
ops.assign(
  dataProcessedTensor.pick(0, 1, null, null),
  dataTensor.pick(null, null, 1)
);
ops.assign(
  dataProcessedTensor.pick(0, 2, null, null),
  dataTensor.pick(null, null, 2)
);

Image Standardization with Mean and Std of ImageNet

// image normalization with mean and std
ops.divseq(dataProcessedTensor, 255);
ops.subseq(dataProcessedTensor.pick(0, 0, null, null), 0.485);
ops.subseq(dataProcessedTensor.pick(0, 1, null, null), 0.456);
ops.subseq(dataProcessedTensor.pick(0, 2, null, null), 0.406);
 
ops.divseq(dataProcessedTensor.pick(0, 0, null, null), 0.229);
ops.divseq(dataProcessedTensor.pick(0, 1, null, null), 0.224);
ops.divseq(dataProcessedTensor.pick(0, 2, null, null), 0.225);

Creating an ONNX Runtime Tensor

const tensor = new ort.Tensor(
  "float32",
  new Float32Array(width * height * 3),
  [1, 3, width, height]
);
(tensor.data as Float32Array).set(dataProcessedTensor.data);

Run Inference and Get the Class Name

const results = await inferenceSession.run({
  input: tensor,
});
 
if (results.output) {
  const res = results.output;
 
  const output = softmax(Array.prototype.slice.call(res.data));
  const topK = imagenetClassesTopK(output, 5);
 
  setModelOutput(topK);
}

There’s quite a lot more code involved in the browser for operations that are one-liners in Python.

imagenet.ts

import { imagenetClasses } from '@/config/imagenet-classes';
import _ from 'lodash';
 
/**
 * Find top k imagenet classes
 */
export function imagenetClassesTopK(classProbabilities: any, k = 5) {
  const probs =
      _.isTypedArray(classProbabilities) ? Array.prototype.slice.call(classProbabilities) : classProbabilities;
 
  const sorted = _.reverse(_.sortBy(probs.map((prob: any, index: number) => [prob, index]), probIndex => probIndex[0]));
 
  const topK = _.take(sorted, k).map((probIndex: any) => {
    const iClass = imagenetClasses[probIndex[1]];
    return {
      id: iClass[0],
      index: parseInt(probIndex[1], 10),
      name: iClass[1].replace(/_/g, ' '),
      probability: probIndex[0]
    };
  });
  return topK;
}
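For contrast, the same top-k lookup is a couple of lines in Python (reusing probs and classes_list from the earlier script):

# indices of the 5 largest probabilities, highest first
topk = np.argsort(probs)[::-1][:5]
for i in topk:
    print(classes_list[i], f"{probs[i]:.4f}")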

https://github.com/satyajitghana/web-onnx-classifier

ONNX Graph Optimization: https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html

Transformer Optimization Tool: https://onnxruntime.ai/docs/performance/transformers-optimization.html
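Graph optimizations can also be applied, and the optimized model saved, from Python; the SessionOptions fields below are part of the standard onnxruntime API (a sketch):

import onnxruntime as ort

opts = ort.SessionOptions()
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
# persist the optimized graph so it can be served to the browser directly
opts.optimized_model_filepath = "resnetv2_50_optimized.onnx"
session = ort.InferenceSession("resnetv2_50.onnx", opts)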

package.json

{
  "name": "web-onnx",
  "version": "0.1.0",
  "private": true,
  "scripts": {
    "dev": "next dev",
    "build": "next build",
    "start": "next start",
    "lint": "next lint"
  },
  "dependencies": {
    "@types/node": "20.4.2",
    "@types/react": "18.2.15",
    "@types/react-dom": "18.2.7",
    "autoprefixer": "10.4.14",
    "eslint": "8.45.0",
    "eslint-config-next": "13.4.10",
    "lodash": "^4.17.21",
    "ndarray": "^1.0.19",
    "ndarray-ops": "^1.2.2",
    "next": "13.4.10",
    "onnxruntime-web": "^1.15.1",
    "postcss": "8.4.26",
    "react": "18.2.0",
    "react-dom": "18.2.0",
    "tailwindcss": "3.3.3",
    "typescript": "5.1.6"
  },
  "devDependencies": {
    "@types/lodash": "^4.14.195",
    "@types/ndarray": "^1.0.11",
    "@types/ndarray-ops": "^1.2.4"
  }
}
npm run dev


YOLOv8


https://github.com/ultralytics/ultralytics

pip install ultralytics
sudo apt update && sudo apt install libgl1  # libGL is required by ultralytics' OpenCV dependency

from ultralytics import YOLO
 
# Create a new YOLO model from scratch
model = YOLO('yolov8n.yaml')
 
# Load a pretrained YOLO model (recommended for training)
model = YOLO('yolov8n.pt')
 
# Train the model using the 'coco128.yaml' dataset for 3 epochs
results = model.train(data='coco128.yaml', epochs=3)
 
# Evaluate the model's performance on the validation set
results = model.val()
 
# Perform object detection on an image using the model
results = model('https://ultralytics.com/images/bus.jpg')
 
# Export the model to ONNX format
success = model.export(format='onnx')
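To read the detections back out, the Results objects expose the boxes directly (a sketch using the standard boxes fields of the ultralytics API):

# iterate over detections: xyxy box coordinates, confidence, class id
for r in results:
    for box in r.boxes:
        print(box.xyxy, box.conf, box.cls)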

Convert YOLOv8 to ONNX

from ultralytics import YOLO
 
# Load a model
model = YOLO('yolov8n.pt')  # load an official model
model = YOLO('path/to/best.pt')  # or load a custom-trained model
 
# Export the model
model.export(format='onnx')

Using CLI

yolo export model=yolov8n.pt format=onnx  # export official model
yolo export model=path/to/best.pt format=onnx  # export custom trained model
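Before wiring the exported model into the browser app, it's worth confirming the input/output names and shapes from Python (a quick sketch; note the input is named "images", which the page.tsx warm-up below relies on):

import onnxruntime as ort

sess = ort.InferenceSession("yolov8n.onnx")
print([(i.name, i.shape) for i in sess.get_inputs()])   # expect [('images', [1, 3, 640, 640])]
print([(o.name, o.shape) for o in sess.get_outputs()])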

https://docs.ultralytics.com/tasks/detect/#export

ONNX NMS: https://github.com/Hyuto/fun/blob/master/test-onnx-graph-surgeon/nms-onnx-v8.py
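The nms-yolov8.onnx helper above packages non-maximum suppression as its own ONNX graph. Conceptually it performs the same filtering as a plain NMS call, sketched here with torchvision (an illustration, not the actual graph-surgeon script):

import torch
from torchvision.ops import nms

# three candidate boxes (xyxy) with scores; the second heavily overlaps the first
boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.],
                      [50., 50., 60., 60.]])
scores = torch.tensor([0.9, 0.8, 0.7])

keep = nms(boxes, scores, iou_threshold=0.45)
print(keep)  # tensor([0, 2]): the overlapping box is suppressed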

Download Models from: https://github.com/satyajitghana/web-yolo-onnx/tree/master/public/model

page.tsx

"use client";
import { useEffect, useRef, useState } from "react";
import cv from "@techstark/opencv-js";
import { Tensor, InferenceSession } from "onnxruntime-web";
import { detectImage } from "@/lib/utils";
 
const modelConfig = {
  name: "yolov8n.onnx",
  nmsModel: "nms-yolov8.onnx",
  inputShape: [1, 3, 640, 640],
  topK: 100,
  iouThreshold: 0.45,
  scoreThreshold: 0.25,
};
 
export default function Home() {
  const [session, setSession] = useState<{
    net: InferenceSession;
    nms: InferenceSession;
  } | null>(null);
 
  const [loading, setLoading] = useState(true);
  const [image, setImage] = useState<string | null>(null);
  const inputRef = useRef<HTMLInputElement>(null);
  const inputImageRef = useRef<HTMLImageElement>(null);
  const canvasOutputRef = useRef<HTMLCanvasElement>(null);
  const canvasInputRef = useRef<HTMLCanvasElement>(null);
 
  cv["onRuntimeInitialized"] = async () => {
    // create the YOLOv8 Model
    const yolov8 = await InferenceSession.create(`/model/${modelConfig.name}`, {
      executionProviders: ["wasm"],
    });
 
    // create the NMS Model
    const nms = await InferenceSession.create(
      `/model/${modelConfig.nmsModel}`,
      {
        executionProviders: ["wasm"],
      }
    );
 
    const tensor = new Tensor(
      "float32",
      new Float32Array(modelConfig.inputShape.reduce((a, b) => a * b)),
      modelConfig.inputShape
    );
    const res = await yolov8.run({ images: tensor });
    console.log("model warm up", res);
 
    setSession({
      net: yolov8,
      nms: nms,
    });
 
    setLoading(false);
  };
  return (
    <main className="flex min-h-screen flex-col items-center justify-between p-24">
      <h1 className="text-3xl">YOLOV8 - ONNX - WASM</h1>
      {loading && <>Loading Model...</>}
      <img
        ref={inputImageRef}
        src="#"
        alt=""
        // style={{ display: image ? "block" : "none" }}
        className="hidden absolute"
        onLoad={() => {
          if (!inputImageRef.current || !canvasOutputRef.current) return;
          if (!session) return;
          detectImage(
            inputImageRef.current,
            canvasOutputRef.current,
            session,
            modelConfig.topK,
            modelConfig.iouThreshold,
            modelConfig.scoreThreshold,
            modelConfig.inputShape
          );
        }}
      />
      <div className="relative min-h-[640px] min-w-[640px]">
        <div className="absolute flex flex-col items-center w-full justify-center z-20">
          <canvas
            width={modelConfig.inputShape[2]}
            height={modelConfig.inputShape[3]}
            ref={canvasInputRef}
            className="absolute left-0 top-0 rounded-md"
          />
          <canvas
            width={modelConfig.inputShape[2]}
            height={modelConfig.inputShape[3]}
            ref={canvasOutputRef}
            className="absolute left-0 top-0"
          />
        </div>
      </div>
      <input
        type="file"
        ref={inputRef}
        accept="image/*"
        onChange={(e) => {
          if (!inputImageRef.current) return;
          if (e.target.files?.length) {
            // handle next image to detect
            if (image) {
              URL.revokeObjectURL(image);
              setImage(null);
            }
 
            const url = URL.createObjectURL(e.target.files[0]); // create image url
            inputImageRef.current.src = url; // set image source
 
            const canvas2DCtx = canvasInputRef.current?.getContext("2d");
 
            inputImageRef.current.onload = async () => {
              if (!inputImageRef.current) return;
              if (canvas2DCtx) {
                canvas2DCtx.drawImage(
                  inputImageRef.current,
                  0,
                  0,
                  modelConfig.inputShape[2],
                  modelConfig.inputShape[3]
                );
              }
            };
 
            setImage(url);
          }
        }}
      />
    </main>
  );
}

Reference: https://github.com/microsoft/onnxruntime-nextjs-template/blob/main/next.config.js

npm run dev


https://github.com/satyajitghana/web-yolo-onnx
