Respect passed-in embedding model

This commit is contained in:
Ashwin Bharambe 2024-09-24 14:40:28 -07:00
parent bda974e660
commit 00352bd251
2 changed files with 16 additions and 14 deletions

View file

@ -13,9 +13,9 @@ from typing import Any, Dict, List, Optional
import fire import fire
import httpx import httpx
from termcolor import cprint
from llama_stack.distribution.datatypes import RemoteProviderConfig from llama_stack.distribution.datatypes import RemoteProviderConfig
from termcolor import cprint
from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.utils.memory.file_utils import data_url_from_file from llama_stack.providers.utils.memory.file_utils import data_url_from_file
@ -120,7 +120,7 @@ async def run_main(host: str, port: int, stream: bool):
name="test_bank", name="test_bank",
config=VectorMemoryBankConfig( config=VectorMemoryBankConfig(
bank_id="test_bank", bank_id="test_bank",
embedding_model="dragon-roberta-query-2", embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512, chunk_size_in_tokens=512,
overlap_size_in_tokens=64, overlap_size_in_tokens=64,
), ),
@ -129,7 +129,7 @@ async def run_main(host: str, port: int, stream: bool):
retrieved_bank = await client.get_memory_bank(bank.bank_id) retrieved_bank = await client.get_memory_bank(bank.bank_id)
assert retrieved_bank is not None assert retrieved_bank is not None
assert retrieved_bank.config.embedding_model == "dragon-roberta-query-2" assert retrieved_bank.config.embedding_model == "all-MiniLM-L6-v2"
urls = [ urls = [
"memory_optimizations.rst", "memory_optimizations.rst",

View file

@ -25,20 +25,22 @@ from llama_stack.apis.memory import * # noqa: F403
ALL_MINILM_L6_V2_DIMENSION = 384 ALL_MINILM_L6_V2_DIMENSION = 384
EMBEDDING_MODEL = None EMBEDDING_MODELS = {}
def get_embedding_model() -> "SentenceTransformer": def get_embedding_model(model: str) -> "SentenceTransformer":
global EMBEDDING_MODEL global EMBEDDING_MODELS
if EMBEDDING_MODEL is None: loaded_model = EMBEDDING_MODELS.get(model)
print("Loading sentence transformer") if loaded_model is not None:
return loaded_model
print(f"Loading sentence transformer for {model}...")
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2") loaded_model = SentenceTransformer(model)
EMBEDDING_MODELS[model] = loaded_model
return EMBEDDING_MODEL return loaded_model
def parse_data_url(data_url: str): def parse_data_url(data_url: str):
@ -151,7 +153,7 @@ class BankWithIndex:
self, self,
documents: List[MemoryBankDocument], documents: List[MemoryBankDocument],
) -> None: ) -> None:
model = get_embedding_model() model = get_embedding_model(self.bank.config.embedding_model)
for doc in documents: for doc in documents:
content = await content_from_doc(doc) content = await content_from_doc(doc)
chunks = make_overlapped_chunks( chunks = make_overlapped_chunks(
@ -187,6 +189,6 @@ class BankWithIndex:
else: else:
query_str = _process(query) query_str = _process(query)
model = get_embedding_model() model = get_embedding_model(self.bank.config.embedding_model)
query_vector = model.encode([query_str])[0].astype(np.float32) query_vector = model.encode([query_str])[0].astype(np.float32)
return await self.index.query(query_vector, k) return await self.index.query(query_vector, k)