diff --git a/llama_stack/apis/memory/client.py b/llama_stack/apis/memory/client.py
index b4bfcb34d..04c2dab5b 100644
--- a/llama_stack/apis/memory/client.py
+++ b/llama_stack/apis/memory/client.py
@@ -13,9 +13,9 @@ from typing import Any, Dict, List, Optional
 
 import fire
 import httpx
+from termcolor import cprint
 
 from llama_stack.distribution.datatypes import RemoteProviderConfig
-from termcolor import cprint
 
 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.providers.utils.memory.file_utils import data_url_from_file
@@ -120,7 +120,7 @@ async def run_main(host: str, port: int, stream: bool):
         name="test_bank",
         config=VectorMemoryBankConfig(
             bank_id="test_bank",
-            embedding_model="dragon-roberta-query-2",
+            embedding_model="all-MiniLM-L6-v2",
             chunk_size_in_tokens=512,
             overlap_size_in_tokens=64,
         ),
@@ -129,7 +129,7 @@ async def run_main(host: str, port: int, stream: bool):
 
     retrieved_bank = await client.get_memory_bank(bank.bank_id)
     assert retrieved_bank is not None
-    assert retrieved_bank.config.embedding_model == "dragon-roberta-query-2"
+    assert retrieved_bank.config.embedding_model == "all-MiniLM-L6-v2"
 
     urls = [
         "memory_optimizations.rst",
diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py
index 929c91bda..1683ddaa1 100644
--- a/llama_stack/providers/utils/memory/vector_store.py
+++ b/llama_stack/providers/utils/memory/vector_store.py
@@ -25,20 +25,22 @@ from llama_stack.apis.memory import *  # noqa: F403
 
 ALL_MINILM_L6_V2_DIMENSION = 384
 
-EMBEDDING_MODEL = None
+EMBEDDING_MODELS = {}
 
 
-def get_embedding_model() -> "SentenceTransformer":
-    global EMBEDDING_MODEL
+def get_embedding_model(model: str) -> "SentenceTransformer":
+    global EMBEDDING_MODELS
 
-    if EMBEDDING_MODEL is None:
-        print("Loading sentence transformer")
+    loaded_model = EMBEDDING_MODELS.get(model)
+    if loaded_model is not None:
+        return loaded_model
 
-        from sentence_transformers import SentenceTransformer
+    print(f"Loading sentence transformer for {model}...")
+    from sentence_transformers import SentenceTransformer
 
-        EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
-
-    return EMBEDDING_MODEL
+    loaded_model = SentenceTransformer(model)
+    EMBEDDING_MODELS[model] = loaded_model
+    return loaded_model
 
 
 def parse_data_url(data_url: str):
@@ -151,7 +153,7 @@ class BankWithIndex:
         self,
         documents: List[MemoryBankDocument],
     ) -> None:
-        model = get_embedding_model()
+        model = get_embedding_model(self.bank.config.embedding_model)
         for doc in documents:
             content = await content_from_doc(doc)
             chunks = make_overlapped_chunks(
@@ -187,6 +189,6 @@ class BankWithIndex:
         else:
             query_str = _process(query)
 
-        model = get_embedding_model()
+        model = get_embedding_model(self.bank.config.embedding_model)
         query_vector = model.encode([query_str])[0].astype(np.float32)
         return await self.index.query(query_vector, k)