Respect passed in embedding model

This commit is contained in:
Ashwin Bharambe 2024-09-24 14:40:28 -07:00
parent bda974e660
commit 00352bd251
2 changed files with 16 additions and 14 deletions

View file

@@ -13,9 +13,9 @@ from typing import Any, Dict, List, Optional

 import fire
 import httpx
-from termcolor import cprint

 from llama_stack.distribution.datatypes import RemoteProviderConfig
+from termcolor import cprint

 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.providers.utils.memory.file_utils import data_url_from_file
@@ -120,7 +120,7 @@ async def run_main(host: str, port: int, stream: bool):
         name="test_bank",
         config=VectorMemoryBankConfig(
             bank_id="test_bank",
-            embedding_model="dragon-roberta-query-2",
+            embedding_model="all-MiniLM-L6-v2",
             chunk_size_in_tokens=512,
             overlap_size_in_tokens=64,
         ),
@@ -129,7 +129,7 @@ async def run_main(host: str, port: int, stream: bool):
     retrieved_bank = await client.get_memory_bank(bank.bank_id)
     assert retrieved_bank is not None
-    assert retrieved_bank.config.embedding_model == "dragon-roberta-query-2"
+    assert retrieved_bank.config.embedding_model == "all-MiniLM-L6-v2"

     urls = [
         "memory_optimizations.rst",

View file

@@ -25,20 +25,22 @@ from llama_stack.apis.memory import *  # noqa: F403

 ALL_MINILM_L6_V2_DIMENSION = 384

-EMBEDDING_MODEL = None
+EMBEDDING_MODELS = {}


-def get_embedding_model() -> "SentenceTransformer":
-    global EMBEDDING_MODEL
+def get_embedding_model(model: str) -> "SentenceTransformer":
+    global EMBEDDING_MODELS

-    if EMBEDDING_MODEL is None:
-        print("Loading sentence transformer")
+    loaded_model = EMBEDDING_MODELS.get(model)
+    if loaded_model is not None:
+        return loaded_model

-        from sentence_transformers import SentenceTransformer
+    print(f"Loading sentence transformer for {model}...")

-        EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
-    return EMBEDDING_MODEL
+    from sentence_transformers import SentenceTransformer
+
+    loaded_model = SentenceTransformer(model)
+    EMBEDDING_MODELS[model] = loaded_model
+    return loaded_model
 def parse_data_url(data_url: str):
@@ -151,7 +153,7 @@ class BankWithIndex:
         self,
         documents: List[MemoryBankDocument],
     ) -> None:
-        model = get_embedding_model()
+        model = get_embedding_model(self.bank.config.embedding_model)
         for doc in documents:
             content = await content_from_doc(doc)
             chunks = make_overlapped_chunks(
@@ -187,6 +189,6 @@ class BankWithIndex:
         else:
             query_str = _process(query)

-        model = get_embedding_model()
+        model = get_embedding_model(self.bank.config.embedding_model)
         query_vector = model.encode([query_str])[0].astype(np.float32)
         return await self.index.query(query_vector, k)