diff --git a/.gitignore b/.gitignore index 697658c22..e64048514 100644 --- a/.gitignore +++ b/.gitignore @@ -437,3 +437,5 @@ config.json .ipynb_aml_checkpoints *.ipynb.amltmp +## Android Studio IDE files +.idea/ diff --git a/mobile/examples/phi-3-vision/android/android/.gitignore b/mobile/examples/phi-3-vision/android/android/.gitignore new file mode 100644 index 000000000..5ef3c80cf --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/.gitignore @@ -0,0 +1,10 @@ +*.iml +.gradle +.idea +.DS_Store +build +/captures +.externalNativeBuild +.cxx +local.properties + diff --git a/mobile/examples/phi-3-vision/android/android/README.md b/mobile/examples/phi-3-vision/android/android/README.md new file mode 100644 index 000000000..63a44172e --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/README.md @@ -0,0 +1,136 @@ +# Local Chatbot on Android with Phi-3 Vision, ONNX Runtime Mobile and ONNX Runtime Generate() API + +## Overview + +This is a basic [Phi-3 Vision](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct-onnx-cpu) Android example application using [ONNX Runtime mobile](https://onnxruntime.ai/docs/tutorials/mobile/) and [ONNX Runtime Generate() API](https://github.com/microsoft/onnxruntime-genai) with support for efficiently running generative AI models. This tutorial will walk you through how to download and run the Phi-3 Vision App on your own mobile device and help you incorporate Phi-3 Vision into your own mobile developments. + +### Capabilities +[Phi-3 Vision](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct-onnx-cpu) is a multimodal model incorporating imaging into [Phi-3's](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) language input capabilities. This expands Phi-3's usages to include Optical Character Recognition (OCR), image captioning, table parsing, and more. 
+ +## Important Features + +### Java API +This app uses the [generate() Java API's](https://github.com/microsoft/onnxruntime-genai/tree/main/src/java/src/main/java/ai/onnxruntime/genai) GenAIException, Generator, GeneratorParams, Images, Model, MultiModalProcessor, NamedTensors, and TokenizerStream classes ([documentation](https://onnxruntime.ai/docs/genai/api/java.html)). The [generate() C API](https://onnxruntime.ai/docs/genai/api/c.html), [generate() C# API](https://onnxruntime.ai/docs/genai/api/csharp.html), and [generate() Python API](https://onnxruntime.ai/docs/genai/api/python.html) are also available. + +### Model Downloads +This app downloads the [Phi-3 Vision](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct-onnx-cpu) model through Hugging Face. To use a different model, change the path links to refer to your chosen model. +```java +final String baseUrl = "https://huggingface.co/microsoft/Phi-3-vision-128k-instruct-onnx-cpu/resolve/main/cpu-int4-rtn-block-32-acc-level-4/"; +List files = Arrays.asList( + "genai_config.json", + "phi-3-v-128k-instruct-text-embedding.onnx", + "phi-3-v-128k-instruct-text-embedding.onnx.data", + "phi-3-v-128k-instruct-text.onnx", + "phi-3-v-128k-instruct-text.onnx.data", + "phi-3-v-128k-instruct-vision.onnx", + "phi-3-v-128k-instruct-vision.onnx.data", + "processor_config.json", + "special_tokens_map.json", + "tokenizer.json", + "tokenizer_config.json"); +``` +These packages will only need to be downloaded once. While editing your app and running new versions, the downloads will skip since all files already exist. +```java +if (urlFilePairs.isEmpty()) { + // Display a message using Toast + Toast.makeText(this, "All files already exist. Skipping download.", Toast.LENGTH_SHORT).show(); + Log.d(TAG, "All files already exist. 
Skipping download."); + model = new Model(getFilesDir().getPath()); + multiModalProcessor = new MultiModalProcessor(model); + return; +} +``` +### Crash Prevention +Downloading the packages for the app on your mobile device takes ~15-30 minutes depending on which device you are using. The progress text indicates what percentage of the downloads has completed. +```java +public void onProgress(long lastBytesRead, long bytesRead, long bytesTotal) { + long lastPctDone = 100 * lastBytesRead / bytesTotal; + long pctDone = 100 * bytesRead / bytesTotal; + if (pctDone > lastPctDone) { + Log.d(TAG, "Downloading files: " + pctDone + "%"); + runOnUiThread(() -> { + progressText.setText("Downloading: " + pctDone + "%"); + }); + } +} +``` +Because the UI is already shown while the downloads are running, pressing the 'send' button before the model has been loaded only displays a message and returns, which prevents the app from crashing. +```java +if (model == null) { + // The model has not been loaded yet; ask the user to wait. + Toast.makeText(MainActivity.this, "Model not loaded yet, please wait...", Toast.LENGTH_SHORT).show(); + return; +} +``` +### Multimodal Processor +Since we are using Phi-3 Vision, we refer to the [MultiModalProcessor Class](https://github.com/microsoft/onnxruntime-genai/blob/main/src/java/src/main/java/ai/onnxruntime/genai/MultiModalProcessor.java) to include imaging as well as text input. In an application with no imaging, you can use the [Tokenizer Class](https://github.com/microsoft/onnxruntime-genai/blob/main/src/java/src/main/java/ai/onnxruntime/genai/Tokenizer.java) instead. + +### Prompt Template +On its own, this model's answers can be very long. To format the AI assistant's answers, you can adjust the prompt template. +```java +String promptQuestion = "<|user|>\n"; +if (inputImage != null) { + promptQuestion += "<|image_1|>\n"; +} +promptQuestion += userMsgEdt.getText().toString() + "You are a helpful AI assistant. 
Answer in two paragraphs or less<|end|>\n<|assistant|>\n"; +final String promptQuestion_formatted = promptQuestion; + +Log.i("GenAI: prompt question", promptQuestion_formatted); +``` +You can also include [parameters](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct-onnx-cpu/blob/main/cpu-int4-rtn-block-32-acc-level-4/genai_config.json) such as a max_length or length_penalty to your liking. +```java +generatorParams.setSearchOption("length_penalty", 1000); +generatorParams.setSearchOption("max_length", 500); +``` +NOTE: Including a max_length will cut off the assistant's answer once reaching the maximum number of tokens rather than formatting a complete response. + + +### Requirements +- [Android Studio](https://developer.android.com/studio) Giraffe | 2022.3.1 or later (installed on Mac/Windows/Linux) +- Android SDK 29+ +- Android NDK r22+ +- An Android device or an Android Emulator + +## Build And Run + +### Step 1: Clone the ONNX runtime mobile examples source code + +Clone this repository to get the sample application. + +`git clone git@github.com:microsoft/onnxruntime-inference-examples.git` + +### [Optional] Step 2: Prepare the model + +The current setup downloads the Phi-3 Vision model directly from the Hugging Face repo to the Android device folder. However, this takes time since the model data is >2.5 GB. 
+ +You can also download [**Phi-3-Vision**](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct-onnx-cpu/tree/main) +and manually copy it to the Android device file directory by following the instructions below: + +#### Steps for manually copying models to the Android device directory: +From Android Studio: + - create (if necessary) and run your emulator/device + - make sure it has at least 8GB of internal storage + - debug/run the app so it's deployed to the device and creates its `files` directory + - expected to be `/data/data/ai.onnxruntime.genai.vision.demo/files` + - this is the path returned by `getFilesDir()` + - Open Device Explorer in Android Studio + - Navigate to `/data/data/ai.onnxruntime.genai.vision.demo/files` + - adjust as needed if the value returned by getFilesDir() differs for your emulator or device + - copy the files inside the [Phi-3 Vision](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct-onnx-cpu/tree/main) model folder `cpu-int4-rtn-block-32-acc-level-4` directly into the `files` directory (the app looks for the model files in `files` itself, not in a subfolder) + +### Step 3: Connect Android Device and Run the app + Connect your Android Device to your computer or select the Android Emulator in Android Studio Device manager. + + Then select `Run -> Run app` and this will prompt the app to be built and installed on your device or emulator. + + Now you can try giving some sample prompt questions and test the chatbot android app by clicking the ">" action button. + +# +Here are some example screenshots of the app. 
// Build script for the Phi-3 Vision Android demo application module.
plugins {
    id("com.android.application")
}

android {
    namespace = "ai.onnxruntime.genai.vision.demo"
    compileSdk = 33

    defaultConfig {
        applicationId = "ai.onnxruntime.genai.vision.demo"
        minSdk = 27
        targetSdk = 33
        versionCode = 1
        versionName = "1.0"

        testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"

        ndk {
            //noinspection ChromeOsAbiSupport
            abiFilters += listOf("arm64-v8a", "x86_64")
        }
    }

    buildTypes {
        release {
            isMinifyEnabled = false
            proguardFiles(
                getDefaultProguardFile("proguard-android-optimize.txt"),
                "proguard-rules.pro"
            )
        }
    }

    compileOptions {
        sourceCompatibility = JavaVersion.VERSION_1_8
        targetCompatibility = JavaVersion.VERSION_1_8
    }

    buildFeatures {
        viewBinding = true
    }
}

dependencies {

    implementation("androidx.appcompat:appcompat:1.6.1")
    implementation("com.google.android.material:material:1.9.0")
    implementation("androidx.constraintlayout:constraintlayout:2.1.4")
    testImplementation("junit:junit:4.13.2")
    androidTestImplementation("androidx.test.ext:junit:1.1.5")
    androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")

    // ONNX Runtime with GenAI
    implementation("com.microsoft.onnxruntime:onnxruntime-android:latest.release")
    // The GenAI AAR is referenced relative to this module (app/libs) so the project builds on
    // any machine. An absolute path into one developer's checkout must never be committed here;
    // to try a locally built AAR, copy it into app/libs and point this line at it.
    implementation(files("libs/onnxruntime-genai-android-0.4.1-dev.aar"))
}
+# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile \ No newline at end of file diff --git a/mobile/examples/phi-3-vision/android/android/app/src/androidTest/java/ai/onnxruntime/genai/vision/demo/ExampleInstrumentedTest.java b/mobile/examples/phi-3-vision/android/android/app/src/androidTest/java/ai/onnxruntime/genai/vision/demo/ExampleInstrumentedTest.java new file mode 100644 index 000000000..0e9868637 --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/app/src/androidTest/java/ai/onnxruntime/genai/vision/demo/ExampleInstrumentedTest.java @@ -0,0 +1,26 @@ +package ai.onnxruntime.genai.vision.demo; + +import android.content.Context; + +import androidx.test.platform.app.InstrumentationRegistry; +import androidx.test.ext.junit.runners.AndroidJUnit4; + +import org.junit.Test; +import org.junit.runner.RunWith; + +import static org.junit.Assert.*; + +/** + * Instrumented test, which will execute on an Android device. + * + * @see Testing documentation + */ +@RunWith(AndroidJUnit4.class) +public class ExampleInstrumentedTest { + @Test + public void useAppContext() { + // Context of the app under test. 
+ Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext(); + assertEquals("ai.onnxruntime.genai.vision.demo", appContext.getPackageName()); + } +} \ No newline at end of file diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/AndroidManifest.xml b/mobile/examples/phi-3-vision/android/android/app/src/main/AndroidManifest.xml new file mode 100644 index 000000000..d9e2be071 --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/app/src/main/AndroidManifest.xml @@ -0,0 +1,29 @@ + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/ic_launcher-playstore.png b/mobile/examples/phi-3-vision/android/android/app/src/main/ic_launcher-playstore.png new file mode 100644 index 000000000..06e0e5479 Binary files /dev/null and b/mobile/examples/phi-3-vision/android/android/app/src/main/ic_launcher-playstore.png differ diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/java/ai/onnxruntime/genai/vision/demo/GenAIImage.java b/mobile/examples/phi-3-vision/android/android/app/src/main/java/ai/onnxruntime/genai/vision/demo/GenAIImage.java new file mode 100644 index 000000000..0a6f5d7c8 --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/app/src/main/java/ai/onnxruntime/genai/vision/demo/GenAIImage.java @@ -0,0 +1,60 @@ +package ai.onnxruntime.genai.vision.demo; + +import android.content.Context; +import android.graphics.Bitmap; +import android.graphics.BitmapFactory; +import android.net.Uri; + +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.IOException; + +import ai.onnxruntime.genai.GenAIException; +import ai.onnxruntime.genai.Images; + +public class GenAIImage { + Images images = null; + Bitmap bitmap = null; + + GenAIImage(Context context, Uri uri, final int maxWidth, final int maxHeight) throws IOException, GenAIException { + Bitmap bmp = decodeUri(context, uri, maxWidth, 
maxHeight); + String filename = context.getFilesDir() + "/multimodalinput.png"; + FileOutputStream out = new FileOutputStream(filename); + bmp.compress(Bitmap.CompressFormat.PNG, 100, out); // bmp is your Bitmap instance + // PNG is a lossless format, the compression factor (100) is ignored + images = new Images(filename); + images = new Images(filename); + bitmap = BitmapFactory.decodeFile(filename); + } + + GenAIImage(Context context, Uri uri) throws IOException, GenAIException { + this(context, uri, 100000, 100000); + } + + public Images getImages() { + return images; + } + + public Bitmap getBitmap() { return bitmap; } + + private static Bitmap decodeUri(Context c, Uri uri, final int maxWidth, final int maxHeight) + throws FileNotFoundException { + BitmapFactory.Options o = new BitmapFactory.Options(); + o.inJustDecodeBounds = true; + BitmapFactory.decodeStream(c.getContentResolver().openInputStream(uri), null, o); + + int width_tmp = o.outWidth + , height_tmp = o.outHeight; + int scale = 1; + + while(width_tmp / 2 > maxWidth || height_tmp / 2 > maxHeight) { + width_tmp /= 2; + height_tmp /= 2; + scale *= 2; + } + + BitmapFactory.Options o2 = new BitmapFactory.Options(); + o2.inSampleSize = scale; + return BitmapFactory.decodeStream(c.getContentResolver().openInputStream(uri), null, o2); + } +} diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/java/ai/onnxruntime/genai/vision/demo/MainActivity.java b/mobile/examples/phi-3-vision/android/android/app/src/main/java/ai/onnxruntime/genai/vision/demo/MainActivity.java new file mode 100644 index 000000000..000365d07 --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/app/src/main/java/ai/onnxruntime/genai/vision/demo/MainActivity.java @@ -0,0 +1,351 @@ +package ai.onnxruntime.genai.vision.demo; + +import androidx.activity.result.ActivityResultLauncher; +import androidx.activity.result.PickVisualMediaRequest; +import androidx.activity.result.contract.ActivityResultContracts; +import 
androidx.appcompat.app.AppCompatActivity; + +import android.content.Context; +import android.content.Intent; +import android.database.Cursor; +import android.net.Uri; +import android.os.Build; +import android.os.Bundle; +import android.text.method.ScrollingMovementMethod; +import android.util.Log; +import android.util.Pair; +import android.view.View; +import android.view.WindowManager; +import android.webkit.MimeTypeMap; +import android.widget.Button; +import android.widget.EditText; +import android.widget.ImageButton; +import android.widget.TextView; +import android.widget.Toast; + +import java.io.BufferedInputStream; +import java.io.BufferedOutputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.function.Consumer; + +import ai.onnxruntime.genai.GenAIException; +import ai.onnxruntime.genai.Generator; +import ai.onnxruntime.genai.GeneratorParams; +import ai.onnxruntime.genai.Images; +import ai.onnxruntime.genai.Model; +import ai.onnxruntime.genai.MultiModalProcessor; +import ai.onnxruntime.genai.NamedTensors; +import ai.onnxruntime.genai.Sequences; +import ai.onnxruntime.genai.SimpleGenAI; +import ai.onnxruntime.genai.Tokenizer; +import ai.onnxruntime.genai.TokenizerStream; +import ai.onnxruntime.genai.vision.demo.databinding.ActivityMainBinding; + +public class MainActivity extends AppCompatActivity implements Consumer { + + private ActivityMainBinding binding; + private EditText userMsgEdt; + private Model model; + //private Tokenizer tokenizer; + private MultiModalProcessor multiModalProcessor; + private ImageButton sendMsgIB; + private ImageButton selectPhotoIB; + private TextView generatedTV; + private TextView promptTV; + private TextView progressText; + private 
static final String TAG = "genai.demo.MainActivity"; + + private final int PICK_IMAGE_FILE = 2; + private GenAIImage inputImage = null; + + @Override + public void onActivityResult(int requestCode, int resultCode, + Intent resultData) { + if (requestCode == PICK_IMAGE_FILE) { + if (resultCode == RESULT_OK) { + // The result data contains a URI for the document or directory that + // the user selected. + inputImage = null; + if (resultData != null && resultData.getData() != null) { + Uri uri = resultData.getData(); + try { + inputImage = new GenAIImage(this, uri); + if (inputImage.getBitmap() != null) { + runOnUiThread(() -> { + selectPhotoIB.setImageBitmap(inputImage.getBitmap()); + }); + } + } catch (IOException | GenAIException e) { + throw new RuntimeException(e); + } + } + } + } + super.onActivityResult(requestCode, resultCode, resultData); + } + private static boolean fileExists(Context context, String fileName) { + File file = new File(context.getFilesDir(), fileName); + return file.exists(); + } + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + + binding = ActivityMainBinding.inflate(getLayoutInflater()); + setContentView(binding.getRoot()); + + sendMsgIB = findViewById(R.id.idIBSend); + selectPhotoIB = findViewById(R.id.idIBPhoto); + userMsgEdt = findViewById(R.id.idEdtMessage); + generatedTV = findViewById(R.id.sample_text); + promptTV = findViewById(R.id.user_text); + progressText = findViewById(R.id.idProgressStatus); + + // Trigger the download operation when the application is created + try { + downloadModels( + getApplicationContext()); + } catch (GenAIException e) { + throw new RuntimeException(e); + } + + Consumer tokenListener = this; + + //enable scrolling and resizing of text boxes + generatedTV.setMovementMethod(new ScrollingMovementMethod()); + getWindow().setSoftInputMode(WindowManager.LayoutParams.SOFT_INPUT_ADJUST_RESIZE); + + selectPhotoIB.setOnClickListener(new 
View.OnClickListener() { + @Override + public void onClick(View v) { + + Intent chooseFile = new Intent(Intent.ACTION_GET_CONTENT); + chooseFile.addCategory(Intent.CATEGORY_OPENABLE); + chooseFile.setType("image/*"); + startActivityForResult( + Intent.createChooser(chooseFile, "Choose an image"), + PICK_IMAGE_FILE + ); + + }}); + + // adding on click listener for send message button. + sendMsgIB.setOnClickListener(new View.OnClickListener() { + @Override + public void onClick(View v) { + if (model == null) { + // if the edit text is empty display a toast message. + Toast.makeText(MainActivity.this, "Model not loaded yet, please wait...", Toast.LENGTH_SHORT).show(); + return; + } + + // Checking if the message entered + // by user is empty or not. + if (userMsgEdt.getText().toString().isEmpty()) { + // if the edit text is empty display a toast message. + Toast.makeText(MainActivity.this, "Please enter your message..", Toast.LENGTH_SHORT).show(); + return; + } + + String promptQuestion = "<|user|>\n"; + if (inputImage != null) { + promptQuestion += "<|image_1|>\n"; + } + promptQuestion += userMsgEdt.getText().toString() + "You are a helpful AI assistant. Answer in two paragraphs or less<|end|>\n<|assistant|>\n"; + final String promptQuestion_formatted = promptQuestion; + + Log.i("GenAI: prompt question", promptQuestion_formatted); + setVisibility(); + + // Disable send button while responding to prompt. + sendMsgIB.setEnabled(false); + + promptTV.setText(userMsgEdt.getText().toString()); + // Clear Edit Text or prompt question. 
+ userMsgEdt.setText(""); + if (inputImage != null) { + generatedTV.setText("[analyzing image...]\n"); + } + else { + generatedTV.setText(""); + } + + new Thread(new Runnable() { + @Override + public void run() { + TokenizerStream stream = null; + GeneratorParams generatorParams = null; + Generator generator = null; + Sequences encodedPrompt = null; + Images images = null; + NamedTensors inputTensors = null; + try { + stream = multiModalProcessor.createStream(); + + generatorParams = model.createGeneratorParams(); + //examples for optional parameters to format AI response + //generatorParams.setSearchOption("length_penalty", 1000); + //generatorParams.setSearchOption("max_length", 500); + + if (inputImage != null) { + images = inputImage.getImages(); + } + + + inputTensors = multiModalProcessor.processImages(promptQuestion_formatted, images); + generatorParams.setInput(inputTensors); + + generator = new Generator(model, generatorParams); + + while (!generator.isDone()) { + generator.computeLogits(); + generator.generateNextToken(); + + int token = generator.getLastTokenInSequence(0); + + tokenListener.accept(stream.decode(token)); + } + generator.close(); + encodedPrompt.close(); + stream.close(); + generatorParams.close(); + images.close(); + inputTensors.close(); + } + catch (GenAIException e) { + Log.e(TAG, "Exception occurred during model query: " + e.getMessage()); + if (generator != null) generator.close(); + if (encodedPrompt != null) encodedPrompt.close(); + if (stream != null) stream.close(); + if (generatorParams != null) generatorParams.close(); + if (images != null) images.close(); + if (inputTensors != null) inputTensors.close(); + throw new RuntimeException(e); + } + + runOnUiThread(() -> { + sendMsgIB.setEnabled(true); + }); + } + }).start(); + } + }); + } + + @Override + protected void onDestroy() { + multiModalProcessor.close(); + multiModalProcessor = null; + model.close(); + model = null; + super.onDestroy(); + } + + + private void 
downloadModels(Context context) throws GenAIException { + + final String baseUrl = "https://huggingface.co/microsoft/Phi-3-vision-128k-instruct-onnx-cpu/resolve/main/cpu-int4-rtn-block-32-acc-level-4/"; + List files = Arrays.asList( + "genai_config.json", + "phi-3-v-128k-instruct-text-embedding.onnx", + "phi-3-v-128k-instruct-text-embedding.onnx.data", + "phi-3-v-128k-instruct-text.onnx", + "phi-3-v-128k-instruct-text.onnx.data", + "phi-3-v-128k-instruct-vision.onnx", + "phi-3-v-128k-instruct-vision.onnx.data", + "processor_config.json", + "special_tokens_map.json", + "tokenizer.json", + "tokenizer_config.json"); + + + List> urlFilePairs = new ArrayList<>(); + for (String file : files) { + if (/*file.endsWith(".data") ||*/ !fileExists(context, file)) { + urlFilePairs.add(new Pair<>( + baseUrl + file,// + "?download=true", + file)); + } + } + if (urlFilePairs.isEmpty()) { + // Display a message using Toast + Toast.makeText(this, "All files already exist. Skipping download.", Toast.LENGTH_SHORT).show(); + Log.d(TAG, "All files already exist. Skipping download."); + model = new Model(getFilesDir().getPath()); + multiModalProcessor = new MultiModalProcessor(model); + return; + } + + progressText.setText("Downloading..."); + progressText.setVisibility(View.VISIBLE); + + Toast.makeText(this, + "Downloading model for the app... 
Model Size greater than 2GB, please allow a few minutes to download.", + Toast.LENGTH_SHORT).show(); + + ExecutorService executor = Executors.newSingleThreadExecutor(); + executor.execute(() -> { + ModelDownloader.downloadModel(context, urlFilePairs, new ModelDownloader.DownloadCallback() { + @Override + public void onProgress(long lastBytesRead, long bytesRead, long bytesTotal) { + long lastPctDone = 100 * lastBytesRead / bytesTotal; + long pctDone = 100 * bytesRead / bytesTotal; + if (pctDone > lastPctDone) { + Log.d(TAG, "Downloading files: " + pctDone + "%"); + runOnUiThread(() -> { + progressText.setText("Downloading: " + pctDone + "%"); + }); + } + } + @Override + public void onDownloadComplete() { + Log.d(TAG, "All downloads completed."); + + // Last download completed, create SimpleGenAI + try { + model = new Model(getFilesDir().getPath()); + multiModalProcessor = new MultiModalProcessor(model); + runOnUiThread(() -> { + Toast.makeText(context, "All downloads completed", Toast.LENGTH_SHORT).show(); + progressText.setVisibility(View.INVISIBLE); + }); + } catch (GenAIException e) { + e.printStackTrace(); + throw new RuntimeException(e); + } + + } + }); + }); + executor.shutdown(); + } + + @Override + public void accept(String token) { + runOnUiThread(() -> { + // Update and aggregate the generated text and write to text box. 
+ CharSequence generated = generatedTV.getText(); + generatedTV.setText(generated + token); + generatedTV.invalidate(); + final int scrollAmount = generatedTV.getLayout().getLineTop(generatedTV.getLineCount()) - generatedTV.getHeight(); + generatedTV.scrollTo(0, Math.max(scrollAmount, 0)); + }); + } + + public void setVisibility() { + TextView view = (TextView) findViewById(R.id.user_text); + view.setVisibility(View.VISIBLE); + TextView botView = (TextView) findViewById(R.id.sample_text); + botView.setVisibility(View.VISIBLE); + } +} diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/java/ai/onnxruntime/genai/vision/demo/ModelDownloader.java b/mobile/examples/phi-3-vision/android/android/app/src/main/java/ai/onnxruntime/genai/vision/demo/ModelDownloader.java new file mode 100644 index 000000000..3d544a69d --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/app/src/main/java/ai/onnxruntime/genai/vision/demo/ModelDownloader.java @@ -0,0 +1,104 @@ +package ai.onnxruntime.genai.vision.demo; + +import static androidx.constraintlayout.helper.widget.MotionEffect.TAG; + +import android.content.Context; +import android.util.Log; +import android.util.Pair; +import android.widget.Toast; + +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileOutputStream; +import java.io.FileWriter; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.net.HttpURLConnection; +import java.net.URL; +import java.util.ArrayList; +import java.util.List; + +import ai.onnxruntime.genai.GenAIException; + +public class ModelDownloader { + interface DownloadCallback { + void onProgress(long lastBytesRead, long bytesRead, long bytesTotal); + void onDownloadComplete() throws GenAIException; + } + + public static void downloadModel(Context context, List> urlFilePairs, DownloadCallback callback) { + try { + + List connections = new ArrayList<>(); + long totalDownloadBytes 
= 0; + for (int i = 0; i < urlFilePairs.size(); i++) { + String url = urlFilePairs.get(i).first; + URL modelUrl = new URL(url); + HttpURLConnection connection = (HttpURLConnection) modelUrl.openConnection(); + connections.add(connection); + long totalFileSize = connection.getHeaderFieldLong("Content-Length",-1); + totalDownloadBytes += totalFileSize; + } + + long totalBytesRead = 0; + for (int i = 0; i < urlFilePairs.size(); i++) { + String fileName = urlFilePairs.get(i).second; + HttpURLConnection connection = connections.get(i); + + File file = new File(context.getFilesDir(), fileName); + File tempFile = new File(context.getFilesDir(), fileName + ".tmp"); + Log.d(TAG, "Downloading file: " + fileName); + connection.connect(); + + // Check if response code is OK + if (connection.getResponseCode() == HttpURLConnection.HTTP_OK) { + InputStream inputStream = connection.getInputStream(); + FileOutputStream outputStream = new FileOutputStream(tempFile); + + long begin = System.currentTimeMillis(); + + byte[] buffer = new byte[4096]; + int bytesRead; + + while ((bytesRead = inputStream.read(buffer)) != -1) { + outputStream.write(buffer, 0, bytesRead); + if (callback != null) { + callback.onProgress(totalBytesRead, totalBytesRead + bytesRead, totalDownloadBytes); + } + totalBytesRead += bytesRead; + } + + outputStream.flush(); + outputStream.close(); + inputStream.close(); + connection.disconnect(); + + long duration = System.currentTimeMillis() - begin; + + // File downloaded successfully + if (tempFile.renameTo(file)) { + if (duration > 0) { + Log.d(TAG, "File downloaded successfully: " + fileName + "(" + totalBytesRead + " bytes, " + (totalBytesRead / duration) + "KBps)"); + } else { + Log.d(TAG, "File downloaded successfully: " + fileName + "(" + totalBytesRead + " bytes, " + (duration / 1000.0) + "s)"); + } + } else { + Log.e(TAG, "Failed to rename temp file to original file"); + } + } else { + Log.e(TAG, "Failed to download model. 
HTTP response code: " + connection.getResponseCode()); + } + } + if (callback != null) { + callback.onDownloadComplete(); + } + } catch (IOException e) { + e.printStackTrace(); + Log.e(TAG, "Exception occurred during model download: " + e.getMessage()); + } catch (GenAIException e) { + throw new RuntimeException(e); + } + } +} \ No newline at end of file diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/res/drawable/bot_svgrepo_com.xml b/mobile/examples/phi-3-vision/android/android/app/src/main/res/drawable/bot_svgrepo_com.xml new file mode 100644 index 000000000..7b2f450ed --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/app/src/main/res/drawable/bot_svgrepo_com.xml @@ -0,0 +1,41 @@ + + + + + + + diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/res/drawable/ic_launcher_background.xml b/mobile/examples/phi-3-vision/android/android/app/src/main/res/drawable/ic_launcher_background.xml new file mode 100644 index 000000000..617786d51 --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/app/src/main/res/drawable/ic_launcher_background.xml @@ -0,0 +1,10 @@ + + + + diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/res/drawable/ic_launcher_foreground.xml b/mobile/examples/phi-3-vision/android/android/app/src/main/res/drawable/ic_launcher_foreground.xml new file mode 100644 index 000000000..2f45a10f2 --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/app/src/main/res/drawable/ic_launcher_foreground.xml @@ -0,0 +1,16 @@ + + + + + + + \ No newline at end of file diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/res/drawable/ic_send.xml b/mobile/examples/phi-3-vision/android/android/app/src/main/res/drawable/ic_send.xml new file mode 100644 index 000000000..3abc6cb33 --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/app/src/main/res/drawable/ic_send.xml @@ -0,0 +1,5 @@ + + + diff --git 
a/mobile/examples/phi-3-vision/android/android/app/src/main/res/drawable/rounded_corner.xml b/mobile/examples/phi-3-vision/android/android/app/src/main/res/drawable/rounded_corner.xml new file mode 100644 index 000000000..64831ea95 --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/app/src/main/res/drawable/rounded_corner.xml @@ -0,0 +1,13 @@ + + + + + + + + + \ No newline at end of file diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/res/drawable/rounded_corner2.xml b/mobile/examples/phi-3-vision/android/android/app/src/main/res/drawable/rounded_corner2.xml new file mode 100644 index 000000000..913738d1d --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/app/src/main/res/drawable/rounded_corner2.xml @@ -0,0 +1,13 @@ + + + + + + + + + \ No newline at end of file diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/res/drawable/user_svgrepo_com.xml b/mobile/examples/phi-3-vision/android/android/app/src/main/res/drawable/user_svgrepo_com.xml new file mode 100644 index 000000000..af7fb9cef --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/app/src/main/res/drawable/user_svgrepo_com.xml @@ -0,0 +1,20 @@ + + + + diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/res/layout/activity_main.xml b/mobile/examples/phi-3-vision/android/android/app/src/main/res/layout/activity_main.xml new file mode 100644 index 000000000..0068a252b --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/app/src/main/res/layout/activity_main.xml @@ -0,0 +1,112 @@ + + + + + + + + + + + + + + + + + diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml b/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml new file mode 100644 index 000000000..7353dbd1f --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml @@ -0,0 +1,5 @@ + + + + 
+ \ No newline at end of file diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml b/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml new file mode 100644 index 000000000..7353dbd1f --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml @@ -0,0 +1,5 @@ + + + + + \ No newline at end of file diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-hdpi/ic_launcher.webp b/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-hdpi/ic_launcher.webp new file mode 100644 index 000000000..2d42b11fb Binary files /dev/null and b/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-hdpi/ic_launcher.webp differ diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp b/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp new file mode 100644 index 000000000..937daee63 Binary files /dev/null and b/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp differ diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-mdpi/ic_launcher.webp b/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-mdpi/ic_launcher.webp new file mode 100644 index 000000000..5301822c9 Binary files /dev/null and b/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-mdpi/ic_launcher.webp differ diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp b/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp new file mode 100644 index 000000000..63172f756 Binary files /dev/null and b/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp differ 
diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-xhdpi/ic_launcher.webp b/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-xhdpi/ic_launcher.webp new file mode 100644 index 000000000..4de312af5 Binary files /dev/null and b/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-xhdpi/ic_launcher.webp differ diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp b/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp new file mode 100644 index 000000000..5d3f69699 Binary files /dev/null and b/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp differ diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp b/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp new file mode 100644 index 000000000..f9ade715d Binary files /dev/null and b/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp differ diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp b/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp new file mode 100644 index 000000000..5c5c8efa2 Binary files /dev/null and b/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp differ diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp b/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp new file mode 100644 index 000000000..57e37a887 Binary files /dev/null and b/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp differ diff --git 
a/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp b/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp new file mode 100644 index 000000000..a8a9cffb3 Binary files /dev/null and b/mobile/examples/phi-3-vision/android/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp differ diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/res/values-night/themes.xml b/mobile/examples/phi-3-vision/android/android/app/src/main/res/values-night/themes.xml new file mode 100644 index 000000000..82ba79c5c --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/app/src/main/res/values-night/themes.xml @@ -0,0 +1,16 @@ + + + + \ No newline at end of file diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/res/values/colors.xml b/mobile/examples/phi-3-vision/android/android/app/src/main/res/values/colors.xml new file mode 100644 index 000000000..31b1630c0 --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/app/src/main/res/values/colors.xml @@ -0,0 +1,10 @@ + + + #CC03A9F4 + #CC03A9F4 + #CC03A9F4 + #a26cf3 + #7626ef + #FF000000 + #FFFFFFFF + \ No newline at end of file diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/res/values/ic_launcher_background.xml b/mobile/examples/phi-3-vision/android/android/app/src/main/res/values/ic_launcher_background.xml new file mode 100644 index 000000000..30d36e13c --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/app/src/main/res/values/ic_launcher_background.xml @@ -0,0 +1,4 @@ + + + #a26cf3 + \ No newline at end of file diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/res/values/strings.xml b/mobile/examples/phi-3-vision/android/android/app/src/main/res/values/strings.xml new file mode 100644 index 000000000..eaeec2693 --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/app/src/main/res/values/strings.xml @@ 
-0,0 +1,3 @@ + + Local Multimodal LLM + \ No newline at end of file diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/res/values/themes.xml b/mobile/examples/phi-3-vision/android/android/app/src/main/res/values/themes.xml new file mode 100644 index 000000000..5e976ac5e --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/app/src/main/res/values/themes.xml @@ -0,0 +1,16 @@ + + + + \ No newline at end of file diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/res/xml/backup_rules.xml b/mobile/examples/phi-3-vision/android/android/app/src/main/res/xml/backup_rules.xml new file mode 100644 index 000000000..fa0f996d2 --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/app/src/main/res/xml/backup_rules.xml @@ -0,0 +1,13 @@ + + + + \ No newline at end of file diff --git a/mobile/examples/phi-3-vision/android/android/app/src/main/res/xml/data_extraction_rules.xml b/mobile/examples/phi-3-vision/android/android/app/src/main/res/xml/data_extraction_rules.xml new file mode 100644 index 000000000..9ee9997b0 --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/app/src/main/res/xml/data_extraction_rules.xml @@ -0,0 +1,19 @@ + + + + + + + \ No newline at end of file diff --git a/mobile/examples/phi-3-vision/android/android/app/src/test/java/ai/onnxruntime/genai/vision/demo/README.md b/mobile/examples/phi-3-vision/android/android/app/src/test/java/ai/onnxruntime/genai/vision/demo/README.md new file mode 100644 index 000000000..2d75f43a3 --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/app/src/test/java/ai/onnxruntime/genai/vision/demo/README.md @@ -0,0 +1,5 @@ +**Note**: + +We are not implementing any unit tests here because there are no simple unit tests that can be written without having a model available locally. + +Debugging/testing should be done by running the app on an emulator or an actual Android device. 
\ No newline at end of file diff --git a/mobile/examples/phi-3-vision/android/android/build.gradle.kts b/mobile/examples/phi-3-vision/android/android/build.gradle.kts new file mode 100644 index 000000000..1c3d467e4 --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/build.gradle.kts @@ -0,0 +1,4 @@ +// Top-level build file where you can add configuration options common to all sub-projects/modules. +plugins { + id("com.android.application") version "8.1.0" apply false +} \ No newline at end of file diff --git a/mobile/examples/phi-3-vision/android/android/gradle.properties b/mobile/examples/phi-3-vision/android/android/gradle.properties new file mode 100644 index 000000000..2c53b8e1f --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/gradle.properties @@ -0,0 +1,21 @@ +# Project-wide Gradle settings. +# IDE (e.g. Android Studio) users: +# Gradle settings configured through the IDE *will override* +# any settings specified in this file. +# For more details on how to configure your build environment visit +# http://www.gradle.org/docs/current/userguide/build_environment.html +# Specifies the JVM arguments used for the daemon process. +# The setting is particularly useful for tweaking memory settings. +org.gradle.jvmargs=-Xmx8192m -Dfile.encoding=UTF-8 +# When configured, Gradle will run in incubating parallel mode. +# This option should only be used with decoupled projects. 
More details, visit +# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects +# org.gradle.parallel=true +# AndroidX package structure to make it clearer which packages are bundled with the +# Android operating system, and which are packaged with your app's APK +# https://developer.android.com/topic/libraries/support-library/androidx-rn +android.useAndroidX=true +# Enables namespacing of each library's R class so that its R class includes only the +# resources declared in the library itself and none from the library's dependencies, +# thereby reducing the size of the R class for that library +android.nonTransitiveRClass=true \ No newline at end of file diff --git a/mobile/examples/phi-3-vision/android/android/gradle/wrapper/gradle-wrapper.jar b/mobile/examples/phi-3-vision/android/android/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 000000000..e708b1c02 Binary files /dev/null and b/mobile/examples/phi-3-vision/android/android/gradle/wrapper/gradle-wrapper.jar differ diff --git a/mobile/examples/phi-3-vision/android/android/gradle/wrapper/gradle-wrapper.properties b/mobile/examples/phi-3-vision/android/android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 000000000..2cc0653c4 --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,6 @@ +#Mon Mar 25 10:44:29 AEST 2024 +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-8.0-bin.zip +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/mobile/examples/phi-3-vision/android/android/gradlew b/mobile/examples/phi-3-vision/android/android/gradlew new file mode 100644 index 000000000..4f906e0c8 --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/gradlew @@ -0,0 +1,185 @@ +#!/usr/bin/env sh + +# +# Copyright 2015 the original author or authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +############################################################################## +## +## Gradle start up script for UN*X +## +############################################################################## + +# Attempt to set APP_HOME +# Resolve links: $0 may be a link +PRG="$0" +# Need this for relative symlinks. +while [ -h "$PRG" ] ; do + ls=`ls -ld "$PRG"` + link=`expr "$ls" : '.*-> \(.*\)$'` + if expr "$link" : '/.*' > /dev/null; then + PRG="$link" + else + PRG=`dirname "$PRG"`"/$link" + fi +done +SAVED="`pwd`" +cd "`dirname \"$PRG\"`/" >/dev/null +APP_HOME="`pwd -P`" +cd "$SAVED" >/dev/null + +APP_NAME="Gradle" +APP_BASE_NAME=`basename "$0"` + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD="maximum" + +warn () { + echo "$*" +} + +die () { + echo + echo "$*" + echo + exit 1 +} + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "`uname`" in + CYGWIN* ) + cygwin=true + ;; + Darwin* ) + darwin=true + ;; + MINGW* ) + msys=true + ;; + NONSTOP* ) + nonstop=true + ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + + +# Determine the Java command to use to start the JVM. 
+if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + else + JAVACMD="$JAVA_HOME/bin/java" + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD="java" + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." +fi + +# Increase the maximum file descriptors if we can. +if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then + MAX_FD_LIMIT=`ulimit -H -n` + if [ $? -eq 0 ] ; then + if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then + MAX_FD="$MAX_FD_LIMIT" + fi + ulimit -n $MAX_FD + if [ $? -ne 0 ] ; then + warn "Could not set maximum file descriptor limit: $MAX_FD" + fi + else + warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" + fi +fi + +# For Darwin, add options to specify how the application appears in the dock +if $darwin; then + GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" +fi + +# For Cygwin or MSYS, switch paths to Windows format before running java +if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then + APP_HOME=`cygpath --path --mixed "$APP_HOME"` + CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` + + JAVACMD=`cygpath --unix "$JAVACMD"` + + # We build the pattern for arguments to be converted via cygpath + ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` + SEP="" + for dir in $ROOTDIRSRAW ; do + ROOTDIRS="$ROOTDIRS$SEP$dir" + SEP="|" + done + OURCYGPATTERN="(^($ROOTDIRS))" + # Add a user-defined pattern to the cygpath arguments + if [ "$GRADLE_CYGPATTERN" != "" ] ; then + 
OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" + fi + # Now convert the arguments - kludge to limit ourselves to /bin/sh + i=0 + for arg in "$@" ; do + CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` + CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option + + if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition + eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` + else + eval `echo args$i`="\"$arg\"" + fi + i=`expr $i + 1` + done + case $i in + 0) set -- ;; + 1) set -- "$args0" ;; + 2) set -- "$args0" "$args1" ;; + 3) set -- "$args0" "$args1" "$args2" ;; + 4) set -- "$args0" "$args1" "$args2" "$args3" ;; + 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; + 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; + 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; + 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; + 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; + esac +fi + +# Escape application args +save () { + for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done + echo " " +} +APP_ARGS=`save "$@"` + +# Collect all arguments for the java command, following the shell quoting and substitution rules +eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" + +exec "$JAVACMD" "$@" diff --git a/mobile/examples/phi-3-vision/android/android/gradlew.bat b/mobile/examples/phi-3-vision/android/android/gradlew.bat new file mode 100644 index 000000000..107acd32c --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/gradlew.bat @@ -0,0 +1,89 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. 
+@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem + +@if "%DEBUG%" == "" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%" == "" set DIRNAME=. +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if "%ERRORLEVEL%" == "0" goto execute + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. 
+ +goto fail + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* + +:end +@rem End local scope for the variables with windows NT shell +if "%ERRORLEVEL%"=="0" goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 +exit /b 1 + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/mobile/examples/phi-3-vision/android/android/images/Local_LLM_1.jpg b/mobile/examples/phi-3-vision/android/android/images/Local_LLM_1.jpg new file mode 100644 index 000000000..59f54cd8e Binary files /dev/null and b/mobile/examples/phi-3-vision/android/android/images/Local_LLM_1.jpg differ diff --git a/mobile/examples/phi-3-vision/android/android/images/Local_LLM_2.jpg b/mobile/examples/phi-3-vision/android/android/images/Local_LLM_2.jpg new file mode 100644 index 000000000..2eacdb25f Binary files /dev/null and b/mobile/examples/phi-3-vision/android/android/images/Local_LLM_2.jpg differ diff --git a/mobile/examples/phi-3-vision/android/android/images/Local_LLM_3.jpg b/mobile/examples/phi-3-vision/android/android/images/Local_LLM_3.jpg new file mode 100644 index 000000000..6d4e04e7a Binary files /dev/null and b/mobile/examples/phi-3-vision/android/android/images/Local_LLM_3.jpg differ diff --git a/mobile/examples/phi-3-vision/android/android/settings.gradle.kts b/mobile/examples/phi-3-vision/android/android/settings.gradle.kts new file mode 100644 index 000000000..49ef4d372 --- /dev/null +++ b/mobile/examples/phi-3-vision/android/android/settings.gradle.kts @@ -0,0 +1,18 @@ +pluginManagement { + repositories { + google() + mavenCentral() + gradlePluginPortal() + } +} +dependencyResolutionManagement { + 
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS) + repositories { + google() + mavenCentral() + } +} + +rootProject.name = "ORT GenAI Vision Demo" +include(":app") + \ No newline at end of file diff --git a/mobile/examples/phi-3/android/app/build.gradle.kts b/mobile/examples/phi-3/android/app/build.gradle.kts index 3598433ca..e8d3478a5 100644 --- a/mobile/examples/phi-3/android/app/build.gradle.kts +++ b/mobile/examples/phi-3/android/app/build.gradle.kts @@ -14,15 +14,10 @@ android { versionName = "1.0" testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner" - externalNativeBuild { - cmake { - cppFlags += "-std=c++17" - } - } ndk { //noinspection ChromeOsAbiSupport - abiFilters += listOf("arm64-v8a", "x86_64") + abiFilters += listOf("arm64-v8a") } } @@ -41,12 +36,6 @@ android { targetCompatibility = JavaVersion.VERSION_1_8 } - externalNativeBuild { - cmake { - path = file("src/main/cpp/CMakeLists.txt") - version = "3.22.1" - } - } buildFeatures { viewBinding = true } @@ -60,4 +49,9 @@ dependencies { testImplementation("junit:junit:4.13.2") androidTestImplementation("androidx.test.ext:junit:1.1.5") androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1") + + // ONNX Runtime with GenAI + implementation("com.microsoft.onnxruntime:onnxruntime-android:latest.release") + implementation(files("libs/onnxruntime-genai-android-0.4.0-dev.aar")) + } \ No newline at end of file diff --git a/mobile/examples/phi-3/android/app/libs/onnxruntime-genai-android-0.4.0-dev.aar b/mobile/examples/phi-3/android/app/libs/onnxruntime-genai-android-0.4.0-dev.aar new file mode 100644 index 000000000..8df8c6dcb Binary files /dev/null and b/mobile/examples/phi-3/android/app/libs/onnxruntime-genai-android-0.4.0-dev.aar differ diff --git a/mobile/examples/phi-3/android/app/src/main/cpp/CMakeLists.txt b/mobile/examples/phi-3/android/app/src/main/cpp/CMakeLists.txt deleted file mode 100644 index 2401f11dc..000000000 --- 
a/mobile/examples/phi-3/android/app/src/main/cpp/CMakeLists.txt +++ /dev/null @@ -1,43 +0,0 @@ -# For more information about using CMake with Android Studio, read the -# documentation: https://d.android.com/studio/projects/add-native-code.html. -# For more examples on how to use CMake, see https://github.com/android/ndk-samples. - -# Sets the minimum CMake version required for this project. -cmake_minimum_required(VERSION 3.22.1) - -# Declares the project name. The project name can be accessed via ${ PROJECT_NAME}, -# Since this is the top level CMakeLists.txt, the project name is also accessible -# with ${CMAKE_PROJECT_NAME} (both CMake variables are in-sync within the top level -# build script scope). -project("genai") - -#set(APP_LIB_DIR ${PROJECT_SOURCE_DIR}/../../../libs) - -# Creates and names a library, sets it as either STATIC -# or SHARED, and provides the relative paths to its source code. -# You can define multiple libraries, and CMake builds them for you. -# Gradle automatically packages shared libraries with your APK. -# -# In this top level CMakeLists.txt, ${CMAKE_PROJECT_NAME} is used to define -# the target library name; in the sub-module's CMakeLists.txt, ${PROJECT_NAME} -# is preferred for the same purpose. -# -# In order to load a library into your app from Java/Kotlin, you must call -# System.loadLibrary() and pass the name of the library defined here; -# for GameActivity/NativeActivity derived applications, the same library name must be -# used in the AndroidManifest.xml file. -add_library(${CMAKE_PROJECT_NAME} SHARED - # List C/C++ source files with relative paths to this CMakeLists.txt. - native-lib.cpp) - -target_link_directories(${CMAKE_PROJECT_NAME} PRIVATE ${PROJECT_SOURCE_DIR}/../jniLibs/${CMAKE_ANDROID_ARCH_ABI}) - -# Specifies libraries CMake should link to your target library. You -# can link libraries from various origins, such as libraries defined in this -# build script, prebuilt third-party libraries, or Android system libraries. 
-target_link_libraries(${CMAKE_PROJECT_NAME} - # List libraries link to the target library - onnxruntime-genai - onnxruntime - android - log) diff --git a/mobile/examples/phi-3/android/app/src/main/cpp/README.md b/mobile/examples/phi-3/android/app/src/main/cpp/README.md deleted file mode 100644 index 87f1120e8..000000000 --- a/mobile/examples/phi-3/android/app/src/main/cpp/README.md +++ /dev/null @@ -1,5 +0,0 @@ -**Note:** - - -`ort_genai_c.h` file is copied from ORT GenAI package so it matches GenAI libraries. -We copy it here because Android Studio code completion doesn't work nicely with the C++ code if the file is not in this directory. \ No newline at end of file diff --git a/mobile/examples/phi-3/android/app/src/main/cpp/native-lib.cpp b/mobile/examples/phi-3/android/app/src/main/cpp/native-lib.cpp deleted file mode 100644 index 854fb0ff0..000000000 --- a/mobile/examples/phi-3/android/app/src/main/cpp/native-lib.cpp +++ /dev/null @@ -1,181 +0,0 @@ -#include -#include -#include - -#include - -#include "ort_genai_c.h" - -namespace { - void ThrowException(JNIEnv *env, OgaResult *result) { - __android_log_write(ANDROID_LOG_DEBUG, "native", "ThrowException"); - // copy error so we can release the OgaResult - jstring jerr_msg = env->NewStringUTF(OgaResultGetError(result)); - OgaDestroyResult(result); - - static const char *className = "ai/onnxruntime/genai/demo/GenAIException"; - jclass exClazz = env->FindClass(className); - jmethodID exConstructor = env->GetMethodID(exClazz, "", "(Ljava/lang/String;)V"); - jobject javaException = env->NewObject(exClazz, exConstructor, jerr_msg); - env->Throw(static_cast(javaException)); - } - - void ThrowIfError(JNIEnv *env, OgaResult *result) { - if (result != nullptr) { - ThrowException(env, result); - } - } - - // handle conversion/release of jstring to const char* - struct CString { - CString(JNIEnv *env, jstring str) - : env_{env}, str_{str}, cstr{env->GetStringUTFChars(str, /* isCopy */ nullptr)} { - } - - const char *cstr; 
- - operator const char *() const { return cstr; } - - ~CString() { - env_->ReleaseStringUTFChars(str_, cstr); - } - - private: - JNIEnv *env_; - jstring str_; - }; -} - -extern "C" JNIEXPORT jlong JNICALL -Java_ai_onnxruntime_genai_demo_GenAIWrapper_loadModel(JNIEnv *env, jobject thiz, jstring model_path) { - CString path{env, model_path}; - OgaModel *model = nullptr; - OgaResult *result = OgaCreateModel(path, &model); - - ThrowIfError(env, result); - - return (jlong)model; -} - -extern "C" JNIEXPORT void JNICALL -Java_ai_onnxruntime_genai_demo_GenAIWrapper_releaseModel(JNIEnv *env, jobject thiz, jlong native_model) { - auto* model = reinterpret_cast(native_model); - OgaDestroyModel(model); -} - -extern "C" JNIEXPORT jlong JNICALL -Java_ai_onnxruntime_genai_demo_GenAIWrapper_createTokenizer(JNIEnv *env, jobject thiz, jlong native_model) { - const auto* model = reinterpret_cast(native_model); - OgaTokenizer *tokenizer = nullptr; - OgaResult* result = OgaCreateTokenizer(model, &tokenizer); - - ThrowIfError(env, result); - - return (jlong)tokenizer; -} - -extern "C" JNIEXPORT void JNICALL -Java_ai_onnxruntime_genai_demo_GenAIWrapper_releaseTokenizer(JNIEnv *env, jobject thiz, jlong native_tokenizer) { - auto* tokenizer = reinterpret_cast(native_tokenizer); - OgaDestroyTokenizer(tokenizer); -} - -extern "C" -JNIEXPORT jstring JNICALL -Java_ai_onnxruntime_genai_demo_GenAIWrapper_run(JNIEnv *env, jobject thiz, jlong native_model, jlong native_tokenizer, - jstring jprompt, jboolean use_callback) { - using SequencesPtr = std::unique_ptr>; - using GeneratorParamsPtr = std::unique_ptr>; - using TokenizerStreamPtr = std::unique_ptr>; - using GeneratorPtr = std::unique_ptr>; - - auto* model = reinterpret_cast(native_model); - auto* tokenizer = reinterpret_cast(native_tokenizer); - - CString prompt{env, jprompt}; - - const auto check_result = [env](OgaResult* result) { - ThrowIfError(env, result); - }; - - OgaSequences* sequences = nullptr; - 
check_result(OgaCreateSequences(&sequences)); - SequencesPtr seq_cleanup{sequences, OgaDestroySequences}; - - check_result(OgaTokenizerEncode(tokenizer, prompt, sequences)); - - OgaGeneratorParams* generator_params = nullptr; - check_result(OgaCreateGeneratorParams(model, &generator_params)); - GeneratorParamsPtr gp_cleanup{generator_params, OgaDestroyGeneratorParams}; - - check_result(OgaGeneratorParamsSetSearchNumber(generator_params, "max_length", 120)); - check_result(OgaGeneratorParamsSetInputSequences(generator_params, sequences)); - - __android_log_print(ANDROID_LOG_DEBUG, "native", "starting token generation"); - - const auto decode_tokens = [&](const int32_t* tokens, size_t num_tokens){ - const char* output_text = nullptr; - check_result(OgaTokenizerDecode(tokenizer, tokens, num_tokens, &output_text)); - jstring text = env->NewStringUTF(output_text); - OgaDestroyString(output_text); - return text; - }; - - jstring output_text; - - if (!use_callback) { - OgaSequences *output_sequences = nullptr; - check_result(OgaGenerate(model, generator_params, &output_sequences)); - SequencesPtr output_seq_cleanup(output_sequences, OgaDestroySequences); - - size_t num_sequences = OgaSequencesCount(output_sequences); - __android_log_print(ANDROID_LOG_DEBUG, "native", "%zu sequences generated", num_sequences); - - // We don't handle batched requests, so there will only be one sequence and we can hardcode using `0` as the index. 
- const int32_t* tokens = OgaSequencesGetSequenceData(output_sequences, 0); - size_t num_tokens = OgaSequencesGetSequenceCount(output_sequences, 0); - - output_text = decode_tokens(tokens, num_tokens); - } - else { - OgaTokenizerStream* tokenizer_stream = nullptr; - check_result(OgaCreateTokenizerStream(tokenizer, &tokenizer_stream)); - TokenizerStreamPtr stream_cleanup(tokenizer_stream, OgaDestroyTokenizerStream); - - OgaGenerator *generator = nullptr; - check_result(OgaCreateGenerator(model, generator_params, &generator)); - GeneratorPtr gen_cleanup(generator, OgaDestroyGenerator); - - // setup the callback to GenAIWrapper::gotNextToken - jclass genai_wrapper = env->GetObjectClass(thiz); - jmethodID callback_id = env->GetMethodID(genai_wrapper, "gotNextToken", "(Ljava/lang/String;)V"); - const auto do_callback = [&](const char* token){ - jstring jtoken = env->NewStringUTF(token); - env->CallVoidMethod(thiz, callback_id, jtoken); - env->DeleteLocalRef(jtoken); - }; - - while (!OgaGenerator_IsDone(generator)) { - check_result(OgaGenerator_ComputeLogits(generator)); - check_result(OgaGenerator_GenerateNextToken(generator)); - - const int32_t* seq = OgaGenerator_GetSequenceData(generator, 0); - size_t seq_len = OgaGenerator_GetSequenceCount(generator, 0); // last token - const char* token = nullptr; - check_result(OgaTokenizerStreamDecode(tokenizer_stream, seq[seq_len - 1], &token)); - do_callback(token); - // Destroy is (assumably) not required for OgaTokenizerStreamDecode based on this which seems to indicate - // the tokenizer is re-using memory for each call. 
- // `'out' is valid until the next call to OgaTokenizerStreamDecode - // or when the OgaTokenizerStream is destroyed` - // OgaDestroyString(token); This causes 'Scudo ERROR: misaligned pointer when deallocating address' - } - - // decode overall - const int32_t* tokens = OgaGenerator_GetSequenceData(generator, 0); - size_t num_tokens = OgaGenerator_GetSequenceCount(generator, 0); - output_text = decode_tokens(tokens, num_tokens); - } - - return output_text; -} diff --git a/mobile/examples/phi-3/android/app/src/main/cpp/ort_genai_c.h b/mobile/examples/phi-3/android/app/src/main/cpp/ort_genai_c.h deleted file mode 100644 index ac8472a92..000000000 --- a/mobile/examples/phi-3/android/app/src/main/cpp/ort_genai_c.h +++ /dev/null @@ -1,239 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -#ifdef _WIN32 -#ifdef BUILDING_ORT_GENAI_C -#define OGA_EXPORT __declspec(dllexport) -#else -#define OGA_EXPORT __declspec(dllimport) -#endif -#define OGA_API_CALL _stdcall -#else -// To make symbols visible on macOS/iOS -#ifdef __APPLE__ -#define OGA_EXPORT __attribute__((visibility("default"))) -#else -#define OGA_EXPORT -#endif -#define OGA_API_CALL -#endif - -// ONNX Runtime Generative AI C API -// This API is not thread safe. - -typedef struct OgaResult OgaResult; -typedef struct OgaGeneratorParams OgaGeneratorParams; -typedef struct OgaGenerator OgaGenerator; -typedef struct OgaModel OgaModel; -// OgaSequences is an array of token arrays where the number of token arrays can be obtained using -// OgaSequencesCount and the number of tokens in each token array can be obtained using OgaSequencesGetSequenceCount. 
-typedef struct OgaSequences OgaSequences; -typedef struct OgaTokenizer OgaTokenizer; -typedef struct OgaTokenizerStream OgaTokenizerStream; - -/* \brief Call this on process exit to cleanly shutdown the genai library & its onnxruntime usage - * \return Error message contained in the OgaResult. The const char* is owned by the OgaResult - * and can will be freed when the OgaResult is destroyed. - */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaShutdown(); - -/* - * \param[in] result OgaResult that contains the error message. - * \return Error message contained in the OgaResult. The const char* is owned by the OgaResult - * and can will be freed when the OgaResult is destroyed. - */ -OGA_EXPORT const char* OGA_API_CALL OgaResultGetError(const OgaResult* result); - -/* - * \param[in] Set logging options, see logging.h 'struct LogItems' for the list of available options - */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaSetLogBool(const char* name, bool value); -OGA_EXPORT OgaResult* OGA_API_CALL OgaSetLogString(const char* name, const char* value); - -/* - * \param[in] result OgaResult to be destroyed. - */ -OGA_EXPORT void OGA_API_CALL OgaDestroyResult(OgaResult*); -OGA_EXPORT void OGA_API_CALL OgaDestroyString(const char*); - -OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateSequences(OgaSequences** out); - -/* - * \param[in] sequences OgaSequences to be destroyed. 
- */ -OGA_EXPORT void OGA_API_CALL OgaDestroySequences(OgaSequences* sequences); - -/* - * \brief Returns the number of sequences in the OgaSequences - * \param[in] sequences - * \return The number of sequences in the OgaSequences - */ -OGA_EXPORT size_t OGA_API_CALL OgaSequencesCount(const OgaSequences* sequences); - -/* - * \brief Returns the number of tokens in the sequence at the given index - * \param[in] sequences - * \return The number of tokens in the sequence at the given index - */ -OGA_EXPORT size_t OGA_API_CALL OgaSequencesGetSequenceCount(const OgaSequences* sequences, size_t sequence_index); - -/* - * \brief Returns a pointer to the sequence data at the given index. The number of tokens in the sequence - * is given by OgaSequencesGetSequenceCount - * \param[in] sequences - * \return The pointer to the sequence data at the given index. The pointer is valid until the OgaSequences is destroyed. - */ -OGA_EXPORT const int32_t* OGA_API_CALL OgaSequencesGetSequenceData(const OgaSequences* sequences, size_t sequence_index); - -/* - * \brief Creates a model from the given configuration directory and device type. - * \param[in] config_path The path to the model configuration directory. The path is expected to be encoded in UTF-8. - * \param[in] device_type The device type to use for the model. - * \param[out] out The created model. - * \return OgaResult containing the error message if the model creation failed. - */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateModel(const char* config_path, OgaModel** out); - -/* - * \brief Destroys the given model. - * \param[in] model The model to be destroyed. - */ -OGA_EXPORT void OGA_API_CALL OgaDestroyModel(OgaModel* model); - -/* - * \brief Generates an array of token arrays from the model execution based on the given generator params. - * \param[in] model The model to use for generation. - * \param[in] generator_params The parameters to use for generation. - * \param[out] out The generated sequences of tokens. 
The caller is responsible for freeing the sequences using OgaDestroySequences - * after it is done using the sequences. - * \return OgaResult containing the error message if the generation failed. - */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerate(OgaModel* model, const OgaGeneratorParams* generator_params, OgaSequences** out); - -/* - * \brief Creates a OgaGeneratorParams from the given model. - * \param[in] model The model to use for generation. - * \param[out] out The created generator params. - * \return OgaResult containing the error message if the generator params creation failed. - */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateGeneratorParams(const OgaModel* model, OgaGeneratorParams** out); - -/* - * \brief Destroys the given generator params. - * \param[in] generator_params The generator params to be destroyed. - */ -OGA_EXPORT void OGA_API_CALL OgaDestroyGeneratorParams(OgaGeneratorParams* generator_params); - -OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchNumber(OgaGeneratorParams* generator_params, const char* name, double value); -OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchBool(OgaGeneratorParams* generator_params, const char* name, bool value); -OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsTryGraphCaptureWithMaxBatchSize(OgaGeneratorParams* generator_params, int32_t max_batch_size); - -/* - * \brief Sets the input ids for the generator params. The input ids are used to seed the generation. - * \param[in] generator_params The generator params to set the input ids on. - * \param[in] input_ids The input ids array of size input_ids_count = batch_size * sequence_length. - * \param[in] input_ids_count The total number of input ids. - * \param[in] sequence_length The sequence length of the input ids. - * \param[in] batch_size The batch size of the input ids. - * \return OgaResult containing the error message if the setting of the input ids failed. 
- */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputIDs(OgaGeneratorParams* generator_params, const int32_t* input_ids, - size_t input_ids_count, size_t sequence_length, size_t batch_size); - -/* - * \brief Sets the input id sequences for the generator params. The input id sequences are used to seed the generation. - * \param[in] generator_params The generator params to set the input ids on. - * \param[in] sequences The input id sequences. - * \return OgaResult containing the error message if the setting of the input id sequences failed. - */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputSequences(OgaGeneratorParams* generator_params, const OgaSequences* sequences); - -OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetWhisperInputFeatures(OgaGeneratorParams*, const int32_t* inputs, size_t count); -OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetWhisperDecoderInputIDs(OgaGeneratorParams*, const int32_t* input_ids, size_t input_ids_count); - -/* - * \brief Creates a generator from the given model and generator params. - * \param[in] model The model to use for generation. - * \param[in] params The parameters to use for generation. - * \param[out] out The created generator. - * \return OgaResult containing the error message if the generator creation failed. - */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateGenerator(OgaModel* model, const OgaGeneratorParams* params, OgaGenerator** out); - -/* - * \brief Destroys the given generator. - * \param[in] generator The generator to be destroyed. - */ -OGA_EXPORT void OGA_API_CALL OgaDestroyGenerator(OgaGenerator* generator); - -/* - * \brief Returns true if the generator has finished generating all the sequences. - * \param[in] generator The generator to check if it is done with generating all sequences. - * \return True if the generator has finished generating all the sequences, false otherwise. 
- */ -OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator); - -/* - * \brief Computes the logits from the model based on the input ids and the past state. The computed logits are stored in the generator. - * \param[in] generator The generator to compute the logits for. - * \return OgaResult containing the error message if the computation of the logits failed. - */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_ComputeLogits(OgaGenerator* generator); -OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator); - -/* - * \brief Returns the number of tokens in the sequence at the given index. - * \param[in] generator The generator to get the count of the tokens for the sequence at the given index. - * \return The number tokens in the sequence at the given index. - */ -OGA_EXPORT size_t OGA_API_CALL OgaGenerator_GetSequenceCount(const OgaGenerator* generator, size_t index); - -/* - * \brief Returns a pointer to the sequence data at the given index. The number of tokens in the sequence - * is given by OgaGenerator_GetSequenceCount - * \param[in] generator The generator to get the sequence data for the sequence at the given index. - * \return The pointer to the sequence data at the given index. The sequence data is owned by the OgaGenerator - * and will be freed when the OgaGenerator is destroyed. The caller must copy the data if it needs to - * be used after the OgaGenerator is destroyed. - */ -OGA_EXPORT const int32_t* OGA_API_CALL OgaGenerator_GetSequenceData(const OgaGenerator* generator, size_t index); - -OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateTokenizer(const OgaModel* model, OgaTokenizer** out); -OGA_EXPORT void OGA_API_CALL OgaDestroyTokenizer(OgaTokenizer*); - -/* Encodes a single string and adds the encoded sequence of tokens to the OgaSequences. The OgaSequences must be freed with OgaDestroySequences - when it is no longer needed. 
- */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerEncode(const OgaTokenizer*, const char* str, OgaSequences* sequences); - -/* Decode a single token sequence and returns a null terminated utf8 string. out_string must be freed with OgaDestroyString - */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerDecode(const OgaTokenizer*, const int32_t* tokens, size_t token_count, const char** out_string); - -/* OgaTokenizerStream is to decoded token strings incrementally, one token at a time. - */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateTokenizerStream(const OgaTokenizer*, OgaTokenizerStream** out); -OGA_EXPORT void OGA_API_CALL OgaDestroyTokenizerStream(OgaTokenizerStream*); - -/* - * Decode a single token in the stream. If this results in a word being generated, it will be returned in 'out'. - * The caller is responsible for concatenating each chunk together to generate the complete result. - * 'out' is valid until the next call to OgaTokenizerStreamDecode or when the OgaTokenizerStream is destroyed - */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerStreamDecode(OgaTokenizerStream*, int32_t token, const char** out); - -OGA_EXPORT OgaResult* OGA_API_CALL OgaSetCurrentGpuDeviceId(int device_id); -OGA_EXPORT OgaResult* OGA_API_CALL OgaGetCurrentGpuDeviceId(int* device_id); - -#ifdef __cplusplus -} -#endif \ No newline at end of file diff --git a/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/GenAIException.java b/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/GenAIException.java deleted file mode 100644 index 94ba3c27b..000000000 --- a/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/GenAIException.java +++ /dev/null @@ -1,7 +0,0 @@ -package ai.onnxruntime.genai.demo; - -public class GenAIException extends Exception { - public GenAIException(String message) { - super(message); - } -} \ No newline at end of file diff --git 
a/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/GenAIWrapper.java b/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/GenAIWrapper.java deleted file mode 100644 index 3599946c6..000000000 --- a/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/GenAIWrapper.java +++ /dev/null @@ -1,64 +0,0 @@ -package ai.onnxruntime.genai.demo; - -import android.util.Log; - -public class GenAIWrapper implements AutoCloseable { - // Load the GenAI library on application startup. - static { - System.loadLibrary("genai"); // JNI layer - System.loadLibrary("onnxruntime-genai"); - System.loadLibrary("onnxruntime"); - } - - private long nativeModel; - private long nativeTokenizer; - private TokenUpdateListener listener; - - public interface TokenUpdateListener { - void onTokenUpdate(String token); - } - - public GenAIWrapper(String modelPath) throws GenAIException { - nativeModel = loadModel(modelPath); - nativeTokenizer = createTokenizer(nativeModel); - } - - public void setTokenUpdateListener(TokenUpdateListener listener) { - this.listener = listener; - } - - void run(String prompt) throws GenAIException { - run(nativeModel, nativeTokenizer, prompt, /* useCallback */ true); - } - - @Override - public void close() throws Exception { - if (nativeTokenizer != 0) { - releaseTokenizer(nativeTokenizer); - } - - if (nativeModel != 0) { - releaseModel(nativeModel); - } - - nativeTokenizer = 0; - nativeModel = 0; - } - - public void gotNextToken(String token) { - Log.i("GenAI", "gotNextToken: " + token); - if (listener != null) { - listener.onTokenUpdate(token); - } - } - - private native long loadModel(String modelPath); - - private native void releaseModel(long nativeModel); - - private native long createTokenizer(long nativeModel); - - private native void releaseTokenizer(long nativeTokenizer); - - private native String run(long nativeModel, long nativeTokenizer, String prompt, boolean useCallback); -} diff --git 
a/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/MainActivity.java b/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/MainActivity.java index 5e228aa44..befdb6f95 100644 --- a/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/MainActivity.java +++ b/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/MainActivity.java @@ -4,9 +4,11 @@ import android.content.Context; import android.os.Bundle; +import android.text.method.ScrollingMovementMethod; import android.util.Log; import android.util.Pair; import android.view.View; +import android.view.WindowManager; import android.widget.EditText; import android.widget.ImageButton; import android.widget.TextView; @@ -18,17 +20,27 @@ import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.function.Consumer; +import ai.onnxruntime.genai.GenAIException; +import ai.onnxruntime.genai.Generator; +import ai.onnxruntime.genai.GeneratorParams; +import ai.onnxruntime.genai.Sequences; +import ai.onnxruntime.genai.TokenizerStream; import ai.onnxruntime.genai.demo.databinding.ActivityMainBinding; +import ai.onnxruntime.genai.Model; +import ai.onnxruntime.genai.Tokenizer; -public class MainActivity extends AppCompatActivity implements GenAIWrapper.TokenUpdateListener { +public class MainActivity extends AppCompatActivity implements Consumer { private ActivityMainBinding binding; private EditText userMsgEdt; - private GenAIWrapper genAIWrapper; + private Model model; + private Tokenizer tokenizer; private ImageButton sendMsgIB; private TextView generatedTV; private TextView promptTV; + private TextView progressText; private static final String TAG = "genai.demo.MainActivity"; private static boolean fileExists(Context context, String fileName) { @@ -56,10 +68,22 @@ protected void onCreate(Bundle savedInstanceState) { generatedTV = findViewById(R.id.sample_text); promptTV = 
findViewById(R.id.user_text); + Consumer tokenListener = this; + + //enable scrolling and resizing of text boxes + generatedTV.setMovementMethod(new ScrollingMovementMethod()); + getWindow().setSoftInputMode(WindowManager.LayoutParams.SOFT_INPUT_ADJUST_RESIZE); + // adding on click listener for send message button. sendMsgIB.setOnClickListener(new View.OnClickListener() { @Override public void onClick(View v) { + if (model == null) { + // if the edit text is empty display a toast message. + Toast.makeText(MainActivity.this, "Model not loaded yet, please wait...", Toast.LENGTH_SHORT).show(); + return; + } + // Checking if the message entered // by user is empty or not. if (userMsgEdt.getText().toString().isEmpty()) { @@ -69,7 +93,7 @@ public void onClick(View v) { } String promptQuestion = userMsgEdt.getText().toString(); - String promptQuestion_formatted = "<|user|>\n" + promptQuestion + "<|end|>\n<|assistant|>"; + String promptQuestion_formatted = "You are a helpful AI assistant. Answer in two paragraphs or less<|end|><|user|>"+promptQuestion+"<|end|>\n"; Log.i("GenAI: prompt question", promptQuestion_formatted); setVisibility(); @@ -85,8 +109,32 @@ public void onClick(View v) { @Override public void run() { try { - genAIWrapper.run(promptQuestion_formatted); - } catch (GenAIException e) { + TokenizerStream stream = tokenizer.createStream(); + + GeneratorParams generatorParams = model.createGeneratorParams(); + //examples for optional parameters to format AI response + //generatorParams.setSearchOption("length_penalty", 1000); + //generatorParams.setSearchOption("max_length", 500); + + Sequences encodedPrompt = tokenizer.encode(promptQuestion_formatted); + generatorParams.setInput(encodedPrompt); + + Generator generator = new Generator(model, generatorParams); + + while (!generator.isDone()) { + generator.computeLogits(); + generator.generateNextToken(); + + int token = generator.getLastTokenInSequence(0); + + tokenListener.accept(stream.decode(token)); + } + + 
generator.close(); + generatorParams.close(); + + } + catch (GenAIException e) { throw new RuntimeException(e); } @@ -101,99 +149,98 @@ public void run() { @Override protected void onDestroy() { - try { - genAIWrapper.close(); - } catch (Exception e) { - Log.e(TAG, "exception from closing genAIWrapper", e); - } - genAIWrapper = null; + tokenizer.close(); + tokenizer = null; + model.close(); + model = null; super.onDestroy(); } private void downloadModels(Context context) throws GenAIException { - List> urlFilePairs = Arrays.asList( - new Pair<>( - "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/added_tokens.json?download=true", - "added_tokens.json"), - new Pair<>( - "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/config.json?download=true", - "config.json"), - new Pair<>( - "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/configuration_phi3.py?download=true", - "configuration_phi3.py"), - new Pair<>( - "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/genai_config.json?download=true", - "genai_config.json"), - new Pair<>( - "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx?download=true", - "phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx"), - new Pair<>( - "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx.data?download=true", - "phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx.data"), - new Pair<>( - 
"https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/special_tokens_map.json?download=true", - "special_tokens_map.json"), - new Pair<>( - "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/tokenizer.json?download=true", - "tokenizer.json"), - new Pair<>( - "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/tokenizer.model?download=true", - "tokenizer.model"), - new Pair<>( - "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/tokenizer_config.json?download=true", - "tokenizer_config.json")); + + final String baseUrl = "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/"; + List files = Arrays.asList( + "added_tokens.json", + "config.json", + "configuration_phi3.py", + "genai_config.json", + "phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx", + "phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx.data", + "special_tokens_map.json", + "tokenizer.json", + "tokenizer.model", + "tokenizer_config.json"); + + List> urlFilePairs = new ArrayList<>(); + for (String file : files) { + if (/*file.endsWith(".data") ||*/ !fileExists(context, file)) { + urlFilePairs.add(new Pair<>( + baseUrl + file,// + "?download=true", + file)); + } + } + if (urlFilePairs.isEmpty()) { + // Display a message using Toast + Toast.makeText(this, "All files already exist. Skipping download.", Toast.LENGTH_SHORT).show(); + Log.d(TAG, "All files already exist. Skipping download."); + model = new Model(getFilesDir().getPath()); + tokenizer = model.createTokenizer(); + return; + } + + progressText.setText("Downloading..."); + progressText.setVisibility(View.VISIBLE); + Toast.makeText(this, "Downloading model for the app... 
Model Size greater than 2GB, please allow a few minutes to download.", Toast.LENGTH_SHORT).show(); ExecutorService executor = Executors.newSingleThreadExecutor(); - for (int i = 0; i < urlFilePairs.size(); i++) { - final int index = i; - String url = urlFilePairs.get(index).first; - String fileName = urlFilePairs.get(index).second; - if (fileExists(context, fileName)) { - // Display a message using Toast - Toast.makeText(this, "File already exists. Skipping Download.", Toast.LENGTH_SHORT).show(); - - Log.d(TAG, "File " + fileName + " already exists. Skipping download."); - // note: since we always download the files lists together for once, - // so assuming if one filename exists, then the download model step has already - // be - // done. - genAIWrapper = createGenAIWrapper(); - break; - } - executor.execute(() -> { - ModelDownloader.downloadModel(context, url, fileName, new ModelDownloader.DownloadCallback() { - @Override - public void onDownloadComplete() throws GenAIException { - Log.d(TAG, "Download complete for " + fileName); - if (index == urlFilePairs.size() - 1) { - // Last download completed, create GenAIWrapper - genAIWrapper = createGenAIWrapper(); - Log.d(TAG, "All downloads completed"); - } + executor.execute(() -> { + ModelDownloader.downloadModel(context, urlFilePairs, new ModelDownloader.DownloadCallback() { + @Override + public void onProgress(long lastBytesRead, long bytesRead, long bytesTotal) { + long lastPctDone = 100 * lastBytesRead / bytesTotal; + long pctDone = 100 * bytesRead / bytesTotal; + if (pctDone > lastPctDone) { + Log.d(TAG, "Downloading files: " + pctDone + "%"); + runOnUiThread(() -> { + progressText.setText("Downloading: " + pctDone + "%"); + }); + } + } + @Override + public void onDownloadComplete() { + Log.d(TAG, "All downloads completed."); + + // Last download completed, create SimpleGenAI + try { + model = new Model(getFilesDir().getPath()); + tokenizer = model.createTokenizer(); + runOnUiThread(() -> { + 
Toast.makeText(context, "All downloads completed", Toast.LENGTH_SHORT).show(); + progressText.setVisibility(View.INVISIBLE); + }); + } catch (GenAIException e) { + e.printStackTrace(); + throw new RuntimeException(e); } - }); + + } }); - } + }); executor.shutdown(); } - private GenAIWrapper createGenAIWrapper() throws GenAIException { - // Create GenAIWrapper object and load model from android device file path. - GenAIWrapper wrapper = new GenAIWrapper(getFilesDir().getPath()); - wrapper.setTokenUpdateListener(this); - return wrapper; - } - @Override - public void onTokenUpdate(String token) { + public void accept(String token) { runOnUiThread(() -> { // Update and aggregate the generated text and write to text box. CharSequence generated = generatedTV.getText(); generatedTV.setText(generated + token); generatedTV.invalidate(); + final int scrollAmount = generatedTV.getLayout().getLineTop(generatedTV.getLineCount()) - generatedTV.getHeight(); + generatedTV.scrollTo(0, Math.max(scrollAmount, 0)); }); } diff --git a/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/ModelDownloader.java b/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/ModelDownloader.java index a79c4e6c5..e309a11da 100644 --- a/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/ModelDownloader.java +++ b/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/ModelDownloader.java @@ -15,6 +15,8 @@ import java.net.HttpURLConnection; import java.net.URL; +import ai.onnxruntime.genai.GenAIException; + public class ModelDownloader { interface DownloadCallback { void onDownloadComplete() throws GenAIException; @@ -59,7 +61,7 @@ public static void downloadModel(Context context, String url, String fileName, D e.printStackTrace(); Log.e(TAG, "Exception occurred during model download: " + e.getMessage()); } catch (GenAIException e) { - throw new RuntimeException(e); + throw new RuntimeException(e); } } } \ No newline at 
end of file diff --git a/mobile/examples/phi-3/android/app/src/main/jniLibs/arm64-v8a/libonnxruntime-genai.so b/mobile/examples/phi-3/android/app/src/main/jniLibs/arm64-v8a/libonnxruntime-genai.so deleted file mode 100755 index f5e6e6e05..000000000 Binary files a/mobile/examples/phi-3/android/app/src/main/jniLibs/arm64-v8a/libonnxruntime-genai.so and /dev/null differ diff --git a/mobile/examples/phi-3/android/app/src/main/jniLibs/arm64-v8a/libonnxruntime.so b/mobile/examples/phi-3/android/app/src/main/jniLibs/arm64-v8a/libonnxruntime.so deleted file mode 100644 index c14976892..000000000 Binary files a/mobile/examples/phi-3/android/app/src/main/jniLibs/arm64-v8a/libonnxruntime.so and /dev/null differ diff --git a/mobile/examples/phi-3/android/app/src/main/jniLibs/arm64-v8a/libonnxruntime4j_jni.so b/mobile/examples/phi-3/android/app/src/main/jniLibs/arm64-v8a/libonnxruntime4j_jni.so deleted file mode 100644 index f75540a19..000000000 Binary files a/mobile/examples/phi-3/android/app/src/main/jniLibs/arm64-v8a/libonnxruntime4j_jni.so and /dev/null differ diff --git a/mobile/examples/phi-3/android/app/src/main/jniLibs/x86_64/libonnxruntime-genai.so b/mobile/examples/phi-3/android/app/src/main/jniLibs/x86_64/libonnxruntime-genai.so deleted file mode 100755 index f986355ab..000000000 Binary files a/mobile/examples/phi-3/android/app/src/main/jniLibs/x86_64/libonnxruntime-genai.so and /dev/null differ diff --git a/mobile/examples/phi-3/android/app/src/main/jniLibs/x86_64/libonnxruntime.so b/mobile/examples/phi-3/android/app/src/main/jniLibs/x86_64/libonnxruntime.so deleted file mode 100644 index c14976892..000000000 Binary files a/mobile/examples/phi-3/android/app/src/main/jniLibs/x86_64/libonnxruntime.so and /dev/null differ diff --git a/mobile/examples/phi-3/android/app/src/main/jniLibs/x86_64/libonnxruntime4j_jni.so b/mobile/examples/phi-3/android/app/src/main/jniLibs/x86_64/libonnxruntime4j_jni.so deleted file mode 100644 index f75540a19..000000000 Binary files 
a/mobile/examples/phi-3/android/app/src/main/jniLibs/x86_64/libonnxruntime4j_jni.so and /dev/null differ diff --git a/mobile/examples/phi-3/android/app/src/main/res/drawable/rounded_corner2.xml b/mobile/examples/phi-3/android/app/src/main/res/drawable/rounded_corner2.xml index de9870ee9..0fdf4eb96 100644 --- a/mobile/examples/phi-3/android/app/src/main/res/drawable/rounded_corner2.xml +++ b/mobile/examples/phi-3/android/app/src/main/res/drawable/rounded_corner2.xml @@ -1,7 +1,7 @@ - + @@ -12,12 +12,14 @@ - - - - @color/blue_700 @color/black - @color/teal_200 - @color/teal_200 + @color/purple_200 + @color/purple_700 @color/black ?attr/colorPrimaryVariant diff --git a/mobile/examples/phi-3/android/app/src/main/res/values/colors.xml b/mobile/examples/phi-3/android/app/src/main/res/values/colors.xml index 8a660986c..644390572 100644 --- a/mobile/examples/phi-3/android/app/src/main/res/values/colors.xml +++ b/mobile/examples/phi-3/android/app/src/main/res/values/colors.xml @@ -3,8 +3,8 @@ #CC03A9F4 #CC03A9F4 #CC03A9F4 - #FF03DAC5 - #FF018786 + #a367e8 + #8a38e9 #FF000000 #FFFFFFFF \ No newline at end of file diff --git a/mobile/examples/phi-3/android/app/src/main/res/values/themes.xml b/mobile/examples/phi-3/android/app/src/main/res/values/themes.xml index 53ebc75d6..5e976ac5e 100644 --- a/mobile/examples/phi-3/android/app/src/main/res/values/themes.xml +++ b/mobile/examples/phi-3/android/app/src/main/res/values/themes.xml @@ -6,8 +6,8 @@ @color/blue_700 @color/white - @color/teal_200 - @color/teal_700 + @color/purple_200 + @color/purple_700 @color/black ?attr/colorPrimaryVariant