
Implementation:Mlc ai Web llm Multi Model RAG Engine

From Leeroopedia


Overview

Multi_Model_RAG_Engine is a wrapper doc describing how web-llm's MLCEngine can load multiple models simultaneously to enable in-browser RAG pipelines. It covers the multi-model loading mechanism in CreateMLCEngine() and engine.reload(), the model-dispatched embedding() and chatCompletion() methods, and the complete RAG example from the official repository.

Code Reference

Multi-Model Engine Creation

From src/engine.ts at lines 90-98, CreateMLCEngine() accepts an array of model IDs:

export async function CreateMLCEngine(
  modelId: string | string[],
  engineConfig?: MLCEngineConfig,
  chatOpts?: ChatOptions | ChatOptions[],
): Promise<MLCEngine> {
  const engine = new MLCEngine(engineConfig);
  await engine.reload(modelId, chatOpts);
  return engine;
}

Sequential Model Loading

From src/engine.ts at lines 194-237, reload() loads models sequentially:

async reload(
  modelId: string | string[],
  chatOpts?: ChatOptions | ChatOptions[],
): Promise<void> {
  await this.unload();
  if (!Array.isArray(modelId)) {
    modelId = [modelId];
  }
  // Validate unique model IDs
  if (new Set(modelId).size < modelId.length) {
    throw new ReloadModelIdNotUniqueError(modelId);
  }
  // Load each model sequentially
  this.reloadController = new AbortController();
  try {
    for (let i = 0; i < modelId.length; i++) {
      await this.reloadInternal(
        modelId[i],
        chatOpts ? chatOpts[i] : undefined,
      );
    }
  } catch (error) {
    if (error instanceof DOMException && error.name === "AbortError") {
      log.warn("Reload() is aborted.", error.message);
      return;
    }
    throw error;
  }
}

Pipeline Type Dispatch

From src/engine.ts at lines 379-394, during reloadInternal(), the engine creates the appropriate pipeline based on model_type:

let newPipeline: LLMChatPipeline | EmbeddingPipeline;
if (modelRecord.model_type === ModelType.embedding) {
  newPipeline = new EmbeddingPipeline(tvm, tokenizer, curModelConfig);
} else {
  newPipeline = new LLMChatPipeline(
    tvm, tokenizer, curModelConfig, logitProcessor,
  );
}
await newPipeline.asyncLoadWebGPUPipelines();
this.loadedModelIdToPipeline.set(modelId, newPipeline);

Internal State Maps

The MLCEngine maintains four maps to track loaded models (lines 117-128):

// Maps each loaded model's modelId to its pipeline
private loadedModelIdToPipeline: Map<string, LLMChatPipeline | EmbeddingPipeline>;
// Maps each loaded model's modelId to its chatConfig
private loadedModelIdToChatConfig: Map<string, ChatConfig>;
// Maps each loaded model's modelId to its modelType
private loadedModelIdToModelType: Map<string, ModelType>;
// Maps each loaded model's modelId to a lock (one request at a time per model)
private loadedModelIdToLock: Map<string, CustomLock>;

Model Selection for API Calls

When multiple models are loaded, the model parameter in API requests becomes required. The engine resolves which model to use via getModelStates() (lines 1209-1267), which:

  1. Retrieves all loaded model IDs
  2. Calls getModelIdToUse() to select the appropriate model (auto-selects if only one model of the correct type is loaded; requires explicit specification otherwise)
  3. Validates that the selected model has the correct pipeline type (EmbeddingPipeline for embedding requests, LLMChatPipeline for chat/completion requests)
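The selection rule can be sketched as a standalone function. This is a simplified re-implementation for illustration only, not the library's actual getModelIdToUse(); the function and error messages here are hypothetical:

```typescript
type ModelTypeName = "llm" | "embedding";

// Simplified sketch of the selection rule described above:
// auto-select when exactly one loaded model matches the requested
// pipeline type; otherwise require an explicit `model` parameter.
function pickModel(
  loaded: Map<string, ModelTypeName>,
  requestedType: ModelTypeName,
  explicitModel?: string,
): string {
  if (explicitModel !== undefined) {
    const type = loaded.get(explicitModel);
    if (type === undefined) {
      throw new Error(`Model not loaded: ${explicitModel}`);
    }
    if (type !== requestedType) {
      throw new Error(`Wrong pipeline type for ${explicitModel}`);
    }
    return explicitModel;
  }
  const candidates = [...loaded.entries()]
    .filter(([, t]) => t === requestedType)
    .map(([id]) => id);
  if (candidates.length === 1) return candidates[0];
  throw new Error(
    `${candidates.length} ${requestedType} models loaded; pass \`model\` explicitly`,
  );
}
```

With an embedding model and two LLMs loaded, an embedding request resolves automatically, while a chat request must name its model.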

Embedding API

engine.embeddings.create() dispatches to MLCEngine.embedding() at lines 1084-1130, which:

  • Resolves the embedding model pipeline
  • Acquires the model's lock
  • Calls EmbeddingPipeline.embedStep()
  • Returns CreateEmbeddingResponse
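The returned CreateEmbeddingResponse follows the OpenAI embeddings shape: one data entry per input, each carrying an index and an embedding vector. A small helper (illustrative only, run here against a hand-built mock rather than a live engine) shows how to pull the vectors out in input order:

```typescript
// Minimal shape of the OpenAI-style embedding response
// (fields trimmed to what this helper needs).
interface EmbeddingResponseLike {
  data: { index: number; embedding: number[] }[];
}

// Collect embeddings sorted by their input index, so the
// returned array lines up with the original input strings.
function toVectors(resp: EmbeddingResponseLike): number[][] {
  return [...resp.data]
    .sort((a, b) => a.index - b.index)
    .map((d) => d.embedding);
}
```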

Chat Completion API

engine.chat.completions.create() dispatches to MLCEngine.chatCompletion() at lines 767-945, which:

  • Resolves the LLM pipeline
  • Acquires the model's lock
  • Runs prefill and decode to generate text
  • Returns ChatCompletion or streams ChatCompletionChunk
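Because each model holds its own lock, concurrent requests to the same model serialize while requests to different models can overlap. A minimal promise-chain mutex (a sketch of the idea, not web-llm's actual CustomLock) illustrates the serialization:

```typescript
// A minimal async mutex: each runExclusive() call waits for all
// previously queued calls to finish before running its callback.
class SimpleLock {
  private tail: Promise<void> = Promise.resolve();

  async runExclusive<T>(fn: () => Promise<T>): Promise<T> {
    const prev = this.tail;
    let release!: () => void;
    this.tail = new Promise<void>((res) => (release = res));
    await prev; // wait for earlier holders
    try {
      return await fn();
    } finally {
      release(); // let the next queued caller proceed
    }
  }
}
```

Even if a second request arrives while the first is mid-generation, it runs only after the first releases the lock.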

Official RAG Example

From examples/embeddings/src/embeddings.ts at lines 160-206:

async function simpleRAG() {
  // 0. Load both embedding model and LLM to a single WebLLM Engine
  const embeddingModelId = "snowflake-arctic-embed-m-q0f32-MLC-b4";
  const llmModelId = "gemma-2-2b-it-q4f32_1-MLC-1k";
  const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
    [embeddingModelId, llmModelId],
    {
      initProgressCallback: initProgressCallback,
      logLevel: "INFO",
    },
  );

  const vectorStore = await MemoryVectorStore.fromTexts(
    ["mitochondria is the powerhouse of the cell"],
    [{ id: 1 }],
    new WebLLMEmbeddings(engine, embeddingModelId),
  );
  const retriever = vectorStore.asRetriever();

  const prompt =
    PromptTemplate.fromTemplate(`Answer the question based only on the following context:
  {context}

  Question: {question}`);

  const chain = RunnableSequence.from([
    {
      context: retriever.pipe(formatDocumentsAsString),
      question: new RunnablePassthrough(),
    },
    prompt,
  ]);

  const formattedPrompt = (
    await chain.invoke("What is the powerhouse of the cell?")
  ).toString();
  const reply = await engine.chat.completions.create({
    messages: [{ role: "user", content: formattedPrompt }],
    model: llmModelId,
  });

  console.log(reply.choices[0].message.content);
  // "The powerhouse of the cell is the mitochondria."
}

I/O Contract

Import:

import {
  CreateMLCEngine,
  MLCEngine,
  MLCEngineConfig,
  ModelType,
  prebuiltAppConfig,
  EmbeddingCreateParams,
  CreateEmbeddingResponse,
  ChatCompletionRequest,
  ChatCompletion,
} from "@mlc-ai/web-llm";

Multi-model creation signature:

// Load multiple models at once
const engine: MLCEngine = await CreateMLCEngine(
  [embeddingModelId, llmModelId],  // string[]
  engineConfig?,                    // MLCEngineConfig
  [embeddingChatOpts?, llmChatOpts?],  // ChatOptions[]
);

Embedding call (requires model parameter when multiple models loaded):

const embResult: CreateEmbeddingResponse = await engine.embeddings.create({
  input: "text to embed",
  model: embeddingModelId,  // REQUIRED when multiple models loaded
});

Chat completion call (requires model parameter when multiple models loaded):

const chatResult: ChatCompletion = await engine.chat.completions.create({
  messages: [{ role: "user", content: "..." }],
  model: llmModelId,  // REQUIRED when multiple models loaded
});

Constraints:

  • All model IDs in the array must be unique (ReloadModelIdNotUniqueError)
  • The chatOpts array size must match the modelId array size (ReloadArgumentSizeUnmatchedError)
  • Total VRAM must accommodate all loaded models simultaneously
  • When multiple models are loaded, the model parameter is required in all API calls
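The first two constraints can be checked up front before calling CreateMLCEngine. This is an illustrative standalone helper mirroring the engine's own validation; the function name and error messages are hypothetical:

```typescript
// Mirror of the engine's argument validation as a standalone check:
// model IDs must be unique, and if per-model ChatOptions are given,
// the two arrays must be the same length.
function validateReloadArgs(
  modelIds: string[],
  chatOpts?: unknown[],
): void {
  if (new Set(modelIds).size < modelIds.length) {
    throw new Error("Model IDs must be unique");
  }
  if (chatOpts !== undefined && chatOpts.length !== modelIds.length) {
    throw new Error("chatOpts array size must match modelId array size");
  }
}
```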

Usage Examples

Complete RAG Pipeline Without LangChain

import * as webllm from "@mlc-ai/web-llm";

// Helper: dot product for cosine similarity of normalized vectors
function dotProduct(a: number[], b: number[]): number {
  let sum = 0;
  for (let i = 0; i < a.length; i++) sum += a[i] * b[i];
  return sum;
}

// Configuration
const embModelId = "snowflake-arctic-embed-m-q0f32-MLC-b4";
const llmId = "Llama-3.2-1B-Instruct-q4f32_1-MLC";

// Stage 1: Load both models
const engine = await webllm.CreateMLCEngine([embModelId, llmId], {
  initProgressCallback: (report) => {
    console.log(`${report.text} (${(report.progress * 100).toFixed(1)}%)`);
  },
});

// Stage 2: Index documents
const QUERY_PREFIX =
  "Represent this sentence for searching relevant passages: ";
const corpus = [
  "The Great Wall of China is over 13,000 miles long.",
  "Photosynthesis converts sunlight into chemical energy in plants.",
  "The Python programming language was created by Guido van Rossum.",
  "Mount Everest is the tallest mountain on Earth at 29,032 feet.",
];

const docFormatted = corpus.map((d) => `[CLS] ${d} [SEP]`);
const docEmbeddings = await engine.embeddings.create({
  input: docFormatted,
  model: embModelId,
});

// Stage 3 & 4: Embed query and retrieve
const question = "What is the tallest mountain?";
const queryFormatted = `[CLS] ${QUERY_PREFIX}${question} [SEP]`;
const queryEmbedding = await engine.embeddings.create({
  input: queryFormatted,
  model: embModelId,
});

const queryVec = queryEmbedding.data[0].embedding;
const scores = corpus.map((doc, i) => ({
  doc,
  score: dotProduct(queryVec, docEmbeddings.data[i].embedding),
}));
scores.sort((a, b) => b.score - a.score);
const topDocs = scores.slice(0, 2);

console.log("Retrieved documents:");
for (const d of topDocs) {
  console.log(`  [${d.score.toFixed(4)}] ${d.doc}`);
}

// Stage 5: Generate grounded answer
const context = topDocs.map((d) => d.doc).join("\n");
const reply = await engine.chat.completions.create({
  messages: [
    {
      role: "user",
      content: `Answer based only on this context:\n${context}\n\nQuestion: ${question}`,
    },
  ],
  model: llmId,
});

console.log("Answer:", reply.choices[0].message.content);

RAG with LangChain RunnableSequence

import * as webllm from "@mlc-ai/web-llm";
import { MemoryVectorStore } from "langchain/vectorstores/memory";
import type { EmbeddingsInterface } from "@langchain/core/embeddings";
import { formatDocumentsAsString } from "langchain/util/document";
import { PromptTemplate } from "@langchain/core/prompts";
import {
  RunnableSequence,
  RunnablePassthrough,
} from "@langchain/core/runnables";

class WebLLMEmbeddings implements EmbeddingsInterface {
  engine: webllm.MLCEngineInterface;
  modelId: string;
  constructor(engine: webllm.MLCEngineInterface, modelId: string) {
    this.engine = engine;
    this.modelId = modelId;
  }
  async embedQuery(text: string): Promise<number[]> {
    const reply = await this.engine.embeddings.create({
      input: [text],
      model: this.modelId,
    });
    return reply.data[0].embedding;
  }
  async embedDocuments(texts: string[]): Promise<number[][]> {
    const reply = await this.engine.embeddings.create({
      input: texts,
      model: this.modelId,
    });
    return reply.data.map((d) => d.embedding);
  }
}

// Load models
const embId = "snowflake-arctic-embed-m-q0f32-MLC-b4";
const llmId = "gemma-2-2b-it-q4f32_1-MLC-1k";
const engine = await webllm.CreateMLCEngine([embId, llmId], {
  initProgressCallback: (r) => console.log(r.text),
  logLevel: "INFO",
});

// Build vector store with multiple documents
const vectorStore = await MemoryVectorStore.fromTexts(
  [
    "mitochondria is the powerhouse of the cell",
    "DNA contains the genetic blueprint of organisms",
    "neurons transmit electrical signals in the nervous system",
  ],
  [{ id: 1 }, { id: 2 }, { id: 3 }],
  new WebLLMEmbeddings(engine, embId),
);

// Build LangChain retrieval chain
const retriever = vectorStore.asRetriever({ k: 2 });
const prompt = PromptTemplate.fromTemplate(
  `Answer the question based only on the following context:
{context}

Question: {question}`,
);

const chain = RunnableSequence.from([
  {
    context: retriever.pipe(formatDocumentsAsString),
    question: new RunnablePassthrough(),
  },
  prompt,
]);

// Run RAG
const formattedPrompt = (
  await chain.invoke("What is the powerhouse of the cell?")
).toString();

const reply = await engine.chat.completions.create({
  messages: [{ role: "user", content: formattedPrompt }],
  model: llmId,
});
console.log(reply.choices[0].message.content);

Streaming RAG Response

import * as webllm from "@mlc-ai/web-llm";

const embModelId = "snowflake-arctic-embed-s-q0f32-MLC-b4";
const llmId = "Llama-3.2-1B-Instruct-q4f32_1-MLC";
const engine = await webllm.CreateMLCEngine([embModelId, llmId]);

// ... (assume documents are already embedded and top documents retrieved)
const retrievedContext = "WebGPU provides GPU-accelerated computation in browsers.";
const question = "How can I use the GPU in a browser?";

// Stream the RAG response for real-time display
const stream = await engine.chat.completions.create({
  messages: [
    {
      role: "user",
      content: `Context: ${retrievedContext}\n\nQuestion: ${question}`,
    },
  ],
  model: llmId,
  stream: true,
  stream_options: { include_usage: true },
});

let fullResponse = "";
for await (const chunk of stream) {
  const delta = chunk.choices[0]?.delta?.content || "";
  fullResponse += delta;
  process.stdout.write(delta);
  // In a browser, update a DOM element instead:
  // document.getElementById("output").innerText = fullResponse;
}
console.log("\n\nFull response:", fullResponse);
