forked from phoenix-oss/llama-stack-mirror
Respect passed in embedding model
This commit is contained in:
parent
bda974e660
commit
00352bd251
2 changed files with 16 additions and 14 deletions
|
@ -13,9 +13,9 @@ from typing import Any, Dict, List, Optional
|
|||
|
||||
import fire
|
||||
import httpx
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.providers.utils.memory.file_utils import data_url_from_file
|
||||
|
@ -120,7 +120,7 @@ async def run_main(host: str, port: int, stream: bool):
|
|||
name="test_bank",
|
||||
config=VectorMemoryBankConfig(
|
||||
bank_id="test_bank",
|
||||
embedding_model="dragon-roberta-query-2",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
),
|
||||
|
@ -129,7 +129,7 @@ async def run_main(host: str, port: int, stream: bool):
|
|||
|
||||
retrieved_bank = await client.get_memory_bank(bank.bank_id)
|
||||
assert retrieved_bank is not None
|
||||
assert retrieved_bank.config.embedding_model == "dragon-roberta-query-2"
|
||||
assert retrieved_bank.config.embedding_model == "all-MiniLM-L6-v2"
|
||||
|
||||
urls = [
|
||||
"memory_optimizations.rst",
|
||||
|
|
|
@ -25,20 +25,22 @@ from llama_stack.apis.memory import * # noqa: F403
|
|||
|
||||
ALL_MINILM_L6_V2_DIMENSION = 384
|
||||
|
||||
EMBEDDING_MODEL = None
|
||||
EMBEDDING_MODELS = {}
|
||||
|
||||
|
||||
def get_embedding_model() -> "SentenceTransformer":
|
||||
global EMBEDDING_MODEL
|
||||
def get_embedding_model(model: str) -> "SentenceTransformer":
|
||||
global EMBEDDING_MODELS
|
||||
|
||||
if EMBEDDING_MODEL is None:
|
||||
print("Loading sentence transformer")
|
||||
loaded_model = EMBEDDING_MODELS.get(model)
|
||||
if loaded_model is not None:
|
||||
return loaded_model
|
||||
|
||||
print(f"Loading sentence transformer for {model}...")
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
|
||||
|
||||
return EMBEDDING_MODEL
|
||||
loaded_model = SentenceTransformer(model)
|
||||
EMBEDDING_MODELS[model] = loaded_model
|
||||
return loaded_model
|
||||
|
||||
|
||||
def parse_data_url(data_url: str):
|
||||
|
@ -151,7 +153,7 @@ class BankWithIndex:
|
|||
self,
|
||||
documents: List[MemoryBankDocument],
|
||||
) -> None:
|
||||
model = get_embedding_model()
|
||||
model = get_embedding_model(self.bank.config.embedding_model)
|
||||
for doc in documents:
|
||||
content = await content_from_doc(doc)
|
||||
chunks = make_overlapped_chunks(
|
||||
|
@ -187,6 +189,6 @@ class BankWithIndex:
|
|||
else:
|
||||
query_str = _process(query)
|
||||
|
||||
model = get_embedding_model()
|
||||
model = get_embedding_model(self.bank.config.embedding_model)
|
||||
query_vector = model.encode([query_str])[0].astype(np.float32)
|
||||
return await self.index.query(query_vector, k)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue