mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-19 13:59:40 +00:00
Vector store inference api (#598)
# What does this PR do? Moves all the memory providers to use the inference API and improved the memory tests to setup the inference stack correctly and use the embedding models ## Test Plan torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="Llama3.2-3B-Instruct" --embedding-model="sentence-transformers/all-MiniLM-L6-v2" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=384 pytest -v -s llama_stack/providers/tests/memory/test_memory.py --providers="inference=together,memory=weaviate" --embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" --env EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY> --env WEAVIATE_API_KEY=foo --env WEAVIATE_CLUSTER_URL=bar pytest -v -s llama_stack/providers/tests/memory/test_memory.py --providers="inference=together,memory=chroma" --embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" --env EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY>--env CHROMA_HOST=localhost --env CHROMA_PORT=8000 pytest -v -s llama_stack/providers/tests/memory/test_memory.py --providers="inference=together,memory=pgvector" --embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" --env PGVECTOR_DB=postgres --env PGVECTOR_USER=postgres --env PGVECTOR_PASSWORD=mysecretpassword --env PGVECTOR_HOST=0.0.0.0 --env EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY> pytest -v -s llama_stack/providers/tests/memory/test_memory.py --providers="inference=together,memory=faiss" --embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" --env EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY>
This commit is contained in:
parent
db7b26a8c9
commit
4f8b73b9e1
15 changed files with 235 additions and 118 deletions
|
|
@ -15,8 +15,7 @@ from numpy.typing import NDArray
|
|||
from pydantic import parse_obj_as
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
||||
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
BankWithIndex,
|
||||
EmbeddingIndex,
|
||||
|
|
@ -72,7 +71,7 @@ class ChromaIndex(EmbeddingIndex):
|
|||
|
||||
|
||||
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||
def __init__(self, url: str) -> None:
|
||||
def __init__(self, url: str, inference_api: Api.inference) -> None:
|
||||
log.info(f"Initializing ChromaMemoryAdapter with url: {url}")
|
||||
url = url.rstrip("/")
|
||||
parsed = urlparse(url)
|
||||
|
|
@ -82,6 +81,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
|
||||
self.host = parsed.hostname
|
||||
self.port = parsed.port
|
||||
self.inference_api = inference_api
|
||||
|
||||
self.client = None
|
||||
self.cache = {}
|
||||
|
|
@ -109,10 +109,9 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
name=memory_bank.identifier,
|
||||
metadata={"bank": memory_bank.model_dump_json()},
|
||||
)
|
||||
bank_index = BankWithIndex(
|
||||
bank=memory_bank, index=ChromaIndex(self.client, collection)
|
||||
self.cache[memory_bank.identifier] = BankWithIndex(
|
||||
memory_bank, ChromaIndex(self.client, collection), self.inference_api
|
||||
)
|
||||
self.cache[memory_bank.identifier] = bank_index
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||
collections = await self.client.list_collections()
|
||||
|
|
@ -124,11 +123,11 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
log.exception(f"Failed to parse bank: {collection.metadata}")
|
||||
continue
|
||||
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=ChromaIndex(self.client, collection),
|
||||
self.cache[bank.identifier] = BankWithIndex(
|
||||
bank,
|
||||
ChromaIndex(self.client, collection),
|
||||
self.inference_api,
|
||||
)
|
||||
self.cache[bank.identifier] = index
|
||||
|
||||
return [i.bank for i in self.cache.values()]
|
||||
|
||||
|
|
@ -166,6 +165,8 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
collection = await self.client.get_collection(bank_id)
|
||||
if not collection:
|
||||
raise ValueError(f"Bank {bank_id} not found in Chroma")
|
||||
index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection))
|
||||
index = BankWithIndex(
|
||||
bank, ChromaIndex(self.client, collection), self.inference_api
|
||||
)
|
||||
self.cache[bank_id] = index
|
||||
return index
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue