forked from phoenix-oss/llama-stack-mirror
Respect passed in embedding model
parent bda974e660
commit 00352bd251
2 changed files with 16 additions and 14 deletions
@@ -13,9 +13,9 @@ from typing import Any, Dict, List, Optional
 
 import fire
 import httpx
+from termcolor import cprint
 
 from llama_stack.distribution.datatypes import RemoteProviderConfig
-from termcolor import cprint
 
 from llama_stack.apis.memory import * # noqa: F403
 from llama_stack.providers.utils.memory.file_utils import data_url_from_file
@@ -120,7 +120,7 @@ async def run_main(host: str, port: int, stream: bool):
         name="test_bank",
         config=VectorMemoryBankConfig(
             bank_id="test_bank",
-            embedding_model="dragon-roberta-query-2",
+            embedding_model="all-MiniLM-L6-v2",
             chunk_size_in_tokens=512,
             overlap_size_in_tokens=64,
         ),
@@ -129,7 +129,7 @@ async def run_main(host: str, port: int, stream: bool):
 
     retrieved_bank = await client.get_memory_bank(bank.bank_id)
     assert retrieved_bank is not None
-    assert retrieved_bank.config.embedding_model == "dragon-roberta-query-2"
+    assert retrieved_bank.config.embedding_model == "all-MiniLM-L6-v2"
 
     urls = [
         "memory_optimizations.rst",
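
The three hunks above are the client-side test: the bank is now created with the stock all-MiniLM-L6-v2 model, and the test asserts that the retrieved bank reports the model it was created with. Below is a minimal sketch of how the same calls could check that each bank keeps its own model; the create_memory_bank/get_memory_bank coroutine names, the helper function, and the second model name are assumptions based on this test, not shown in the diff:

# Hypothetical extension of the test above (sketch only): create banks with two
# different embedding models and verify each bank echoes back its own model.
# `client` is assumed to be the same memory client instance used in run_main().
async def check_embedding_models(client) -> None:
    for model in ["all-MiniLM-L6-v2", "all-mpnet-base-v2"]:  # second name is just an example
        bank = await client.create_memory_bank(
            name=f"test_bank_{model}",
            config=VectorMemoryBankConfig(
                bank_id=f"test_bank_{model}",
                embedding_model=model,
                chunk_size_in_tokens=512,
                overlap_size_in_tokens=64,
            ),
        )
        retrieved = await client.get_memory_bank(bank.bank_id)
        assert retrieved is not None
        # Before this commit the provider ignored the requested model;
        # after it, the stored config should match what was passed in.
        assert retrieved.config.embedding_model == model
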
@@ -25,20 +25,22 @@ from llama_stack.apis.memory import * # noqa: F403
 
 ALL_MINILM_L6_V2_DIMENSION = 384
 
-EMBEDDING_MODEL = None
+EMBEDDING_MODELS = {}
 
 
-def get_embedding_model() -> "SentenceTransformer":
-    global EMBEDDING_MODEL
+def get_embedding_model(model: str) -> "SentenceTransformer":
+    global EMBEDDING_MODELS
 
-    if EMBEDDING_MODEL is None:
-        print("Loading sentence transformer")
+    loaded_model = EMBEDDING_MODELS.get(model)
+    if loaded_model is not None:
+        return loaded_model
 
-        from sentence_transformers import SentenceTransformer
+    print(f"Loading sentence transformer for {model}...")
+    from sentence_transformers import SentenceTransformer
 
-        EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
-
-    return EMBEDDING_MODEL
+    loaded_model = SentenceTransformer(model)
+    EMBEDDING_MODELS[model] = loaded_model
+    return loaded_model
 
 
 def parse_data_url(data_url: str):
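
The hunk above replaces the single cached EMBEDDING_MODEL with a cache keyed by model name. A standalone sketch (not part of the diff) of what that caching implies; the load helper below is hypothetical, and the only API assumed is the sentence-transformers package the module already imports lazily:

# Sketch of the per-model cache behaviour: each distinct model name loads one
# SentenceTransformer and reuses it, and different models can produce vectors
# of different sizes, so ALL_MINILM_L6_V2_DIMENSION (384) only describes the
# default all-MiniLM-L6-v2 model.
from sentence_transformers import SentenceTransformer

_cache: dict = {}  # model name -> loaded SentenceTransformer

def load(name: str) -> "SentenceTransformer":
    if name not in _cache:
        _cache[name] = SentenceTransformer(name)  # heavy load happens once per name
    return _cache[name]

assert load("all-MiniLM-L6-v2") is load("all-MiniLM-L6-v2")          # cached instance reused
print(load("all-MiniLM-L6-v2").get_sentence_embedding_dimension())   # 384
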
@@ -151,7 +153,7 @@ class BankWithIndex:
         self,
         documents: List[MemoryBankDocument],
     ) -> None:
-        model = get_embedding_model()
+        model = get_embedding_model(self.bank.config.embedding_model)
         for doc in documents:
             content = await content_from_doc(doc)
             chunks = make_overlapped_chunks(
@@ -187,6 +189,6 @@ class BankWithIndex:
         else:
             query_str = _process(query)
 
-        model = get_embedding_model()
+        model = get_embedding_model(self.bank.config.embedding_model)
         query_vector = model.encode([query_str])[0].astype(np.float32)
         return await self.index.query(query_vector, k)
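
Both call sites in BankWithIndex now look up the model from self.bank.config.embedding_model, so ingestion and query use the same embeddings for a given bank. A small sketch of the query-side encoding this relies on (the query string is made up; the printed shape assumes the default model):

# Sketch of the query path after this change: the bank's configured model
# embeds the query, and the resulting float32 vector is what the index sees.
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")  # i.e. bank.config.embedding_model
query_vector = model.encode(["how do I reduce memory usage?"])[0].astype(np.float32)
print(query_vector.shape, query_vector.dtype)  # (384,) float32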