Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-19 04:29:40 +00:00)
Vector store inference api (#598)
# What does this PR do?

Moves all the memory providers to use the inference API, and improves the memory tests to set up the inference stack correctly and use the embedding models.

## Test Plan

```bash
torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" \
  --inference-model="Llama3.2-3B-Instruct" \
  --embedding-model="sentence-transformers/all-MiniLM-L6-v2" \
  llama_stack/providers/tests/inference/test_embeddings.py \
  --env EMBEDDING_DIMENSION=384

pytest -v -s llama_stack/providers/tests/memory/test_memory.py \
  --providers="inference=together,memory=weaviate" \
  --embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" \
  --env EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY> \
  --env WEAVIATE_API_KEY=foo --env WEAVIATE_CLUSTER_URL=bar

pytest -v -s llama_stack/providers/tests/memory/test_memory.py \
  --providers="inference=together,memory=chroma" \
  --embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" \
  --env EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY> \
  --env CHROMA_HOST=localhost --env CHROMA_PORT=8000

pytest -v -s llama_stack/providers/tests/memory/test_memory.py \
  --providers="inference=together,memory=pgvector" \
  --embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" \
  --env PGVECTOR_DB=postgres --env PGVECTOR_USER=postgres \
  --env PGVECTOR_PASSWORD=mysecretpassword --env PGVECTOR_HOST=0.0.0.0 \
  --env EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY>

pytest -v -s llama_stack/providers/tests/memory/test_memory.py \
  --providers="inference=together,memory=faiss" \
  --embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" \
  --env EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY>
```
parent db7b26a8c9, commit 4f8b73b9e1
15 changed files with 235 additions and 118 deletions
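The substance of the change, repeated across the hunks below, is that embeddings are now produced through the routed inference API instead of an in-process SentenceTransformer. A minimal sketch of the new insert-path pattern; `embed_chunks` is a hypothetical helper, while the `inference_api.embeddings(model, contents)` call and the `.embeddings` response field come straight from the diff:

```python
import numpy as np


async def embed_chunks(inference_api, embedding_model: str, chunks) -> np.ndarray:
    # Hypothetical helper: batch the chunk texts through the inference API
    # instead of encoding them locally with sentence-transformers.
    response = await inference_api.embeddings(
        embedding_model, [chunk.content for chunk in chunks]
    )
    # The API returns plain lists of floats; the vector indexes want a numpy array.
    return np.array(response.embeddings, dtype=np.float32)
```

One knock-on effect visible in the test plan: with the hard-coded ALL_MINILM_L6_V2_DIMENSION constant deleted (first hunk below), the embedding dimension must be supplied explicitly via --env EMBEDDING_DIMENSION.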
```diff
@@ -22,28 +22,10 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
 from llama_models.llama3.api.tokenizer import Tokenizer
 
 from llama_stack.apis.memory import * # noqa: F403
+from llama_stack.providers.datatypes import Api
 
 log = logging.getLogger(__name__)
 
-ALL_MINILM_L6_V2_DIMENSION = 384
-
-EMBEDDING_MODELS = {}
-
-
-def get_embedding_model(model: str) -> "SentenceTransformer":
-    global EMBEDDING_MODELS
-
-    loaded_model = EMBEDDING_MODELS.get(model)
-    if loaded_model is not None:
-        return loaded_model
-
-    log.info(f"Loading sentence transformer for {model}...")
-    from sentence_transformers import SentenceTransformer
-
-    loaded_model = SentenceTransformer(model)
-    EMBEDDING_MODELS[model] = loaded_model
-    return loaded_model
-
-
 def parse_pdf(data: bytes) -> str:
     # For PDF and DOC/DOCX files, we can't reliably convert to string
```
```diff
@@ -166,12 +148,12 @@ class EmbeddingIndex(ABC):
 class BankWithIndex:
     bank: VectorMemoryBank
     index: EmbeddingIndex
+    inference_api: Api.inference
 
     async def insert_documents(
         self,
         documents: List[MemoryBankDocument],
     ) -> None:
-        model = get_embedding_model(self.bank.embedding_model)
         for doc in documents:
             content = await content_from_doc(doc)
             chunks = make_overlapped_chunks(
```
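The new `inference_api` field means any provider constructing a `BankWithIndex` has to thread that dependency through. A hedged wiring sketch, assuming the dataclass-style field list above and an import path that is not shown in this excerpt:

```python
# Assumed import path; the diff excerpt does not name the file.
from llama_stack.providers.utils.memory.vector_store import BankWithIndex


async def make_bank_with_index(bank, index, inference_api) -> BankWithIndex:
    # Hypothetical factory: the only new requirement on callers is passing
    # the routed inference API through, so insert/query can embed with it.
    return BankWithIndex(bank=bank, index=index, inference_api=inference_api)
```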
```diff
@@ -183,7 +165,10 @@ class BankWithIndex:
             )
             if not chunks:
                 continue
-            embeddings = model.encode([x.content for x in chunks]).astype(np.float32)
+            embeddings_response = await self.inference_api.embeddings(
+                self.bank.embedding_model, [x.content for x in chunks]
+            )
+            embeddings = np.array(embeddings_response.embeddings)
 
             await self.index.add_chunks(chunks, embeddings)
```
```diff
@@ -208,6 +193,8 @@ class BankWithIndex:
         else:
             query_str = _process(query)
 
-        model = get_embedding_model(self.bank.embedding_model)
-        query_vector = model.encode([query_str])[0].astype(np.float32)
+        embeddings_response = await self.inference_api.embeddings(
+            self.bank.embedding_model, [query_str]
+        )
+        query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
         return await self.index.query(query_vector, k, score_threshold)
```
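The query path mirrors the insert path: the batched API gets a single-element list, and the first embedding comes back as the query vector, explicitly converted to float32 since that is what vector indexes such as FAISS operate on. A standalone sketch of just that step (`embed_query` is a hypothetical name; the call shape is taken from the hunk above):

```python
import numpy as np


async def embed_query(inference_api, embedding_model: str, query_str: str) -> np.ndarray:
    # The embeddings API is batched, so wrap the single query string in a
    # list and take the first (and only) embedding from the response.
    response = await inference_api.embeddings(embedding_model, [query_str])
    return np.array(response.embeddings[0], dtype=np.float32)
```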