forked from phoenix-oss/llama-stack-mirror
Respect passed in embedding model
This commit is contained in:
parent
bda974e660
commit
00352bd251
2 changed files with 16 additions and 14 deletions
|
@@ -25,20 +25,22 @@ from llama_stack.apis.memory import * # noqa: F403
|
|||
|
||||
ALL_MINILM_L6_V2_DIMENSION = 384
|
||||
|
||||
EMBEDDING_MODEL = None
|
||||
EMBEDDING_MODELS = {}
|
||||
|
||||
|
||||
def get_embedding_model() -> "SentenceTransformer":
|
||||
global EMBEDDING_MODEL
|
||||
def get_embedding_model(model: str) -> "SentenceTransformer":
|
||||
global EMBEDDING_MODELS
|
||||
|
||||
if EMBEDDING_MODEL is None:
|
||||
print("Loading sentence transformer")
|
||||
loaded_model = EMBEDDING_MODELS.get(model)
|
||||
if loaded_model is not None:
|
||||
return loaded_model
|
||||
|
||||
from sentence_transformers import SentenceTransformer
|
||||
print(f"Loading sentence transformer for {model}...")
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
|
||||
|
||||
return EMBEDDING_MODEL
|
||||
loaded_model = SentenceTransformer(model)
|
||||
EMBEDDING_MODELS[model] = loaded_model
|
||||
return loaded_model
|
||||
|
||||
|
||||
def parse_data_url(data_url: str):
|
||||
|
@@ -151,7 +153,7 @@ class BankWithIndex:
|
|||
self,
|
||||
documents: List[MemoryBankDocument],
|
||||
) -> None:
|
||||
model = get_embedding_model()
|
||||
model = get_embedding_model(self.bank.config.embedding_model)
|
||||
for doc in documents:
|
||||
content = await content_from_doc(doc)
|
||||
chunks = make_overlapped_chunks(
|
||||
|
@@ -187,6 +189,6 @@ class BankWithIndex:
|
|||
else:
|
||||
query_str = _process(query)
|
||||
|
||||
model = get_embedding_model()
|
||||
model = get_embedding_model(self.bank.config.embedding_model)
|
||||
query_vector = model.encode([query_str])[0].astype(np.float32)
|
||||
return await self.index.query(query_vector, k)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue