diff --git a/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java index ed020fd7..833b5883 100644 --- a/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java +++ b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java @@ -63,7 +63,8 @@ public class MainActivity extends AppCompatActivity implements Runnable, LlmCall private ImageButton mCameraButton; private ListView mMessagesView; private MessageAdapter mMessageAdapter; - private LlmModule mModule = null; + private List mLoadedModules = new ArrayList<>(); + private LlmModule mActiveModule = null; private Message mResultMessage = null; private ImageButton mSettingsButton; private TextView mMemoryView; @@ -145,7 +146,7 @@ private void setLocalModel( long runStartTime = System.currentTimeMillis(); // Create LlmModule with dataPath - mModule = + mActiveModule = new LlmModule( ModelUtils.getModelCategory( mCurrentSettingsFields.getModelType(), mCurrentSettingsFields.getBackendType()), @@ -153,8 +154,9 @@ private void setLocalModel( tokenizerPath, temperature, dataPath); + mLoadedModules.add(mActiveModule); - int loadResult = mModule.load(); + int loadResult = mActiveModule.load(); long loadDuration = System.currentTimeMillis() - runStartTime; String modelLoadError = ""; String modelInfo = ""; @@ -186,7 +188,7 @@ private void setLocalModel( if (mCurrentSettingsFields.getModelType() == ModelType.LLAVA_1_5) { ETLogging.getInstance().log("Llava start prefill prompt"); - startPos = mModule.prefillPrompt(PromptFormat.getLlavaPresetPrompt()); + startPos = mActiveModule.prefillPrompt(PromptFormat.getLlavaPresetPrompt()); ETLogging.getInstance().log("Llava completes prefill prompt"); } } @@ -658,7 +660,7 @@ private void showMediaPreview(List uris) { ETImage img = processedImageList.get(0); ETLogging.getInstance().log("Llava start prefill image"); startPos = - mModule.prefillImages( + mActiveModule.prefillImages( img.getInts(), img.getWidth(), img.getHeight(), @@ -687,7 +689,7 @@ private void onModelRunStarted() { mSendButton.setImageResource(R.drawable.baseline_stop_24); mSendButton.setOnClickListener( view -> { - mModule.stop(); + mActiveModule.stop(); }); } @@ -738,21 +740,21 @@ public void run() { mCurrentSettingsFields.getModelType(), mCurrentSettingsFields.getBackendType()) == ModelUtils.VISION_MODEL) { - mModule.generate( + mActiveModule.generate( finalPrompt, ModelUtils.VISION_MODEL_SEQ_LEN, MainActivity.this, false); } else if (mCurrentSettingsFields.getModelType() == ModelType.LLAMA_GUARD_3) { String llamaGuardPromptForClassification = PromptFormat.getFormattedLlamaGuardPrompt(rawPrompt); ETLogging.getInstance() .log("Running inference.. prompt=" + llamaGuardPromptForClassification); - mModule.generate( + mActiveModule.generate( llamaGuardPromptForClassification, llamaGuardPromptForClassification.length() + 64, MainActivity.this, false); } else { ETLogging.getInstance().log("Running inference.. prompt=" + finalPrompt); - mModule.generate( + mActiveModule.generate( finalPrompt, (int) (finalPrompt.length() * 0.75) + 64, MainActivity.this,