mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-18 13:29:47 +00:00
Vector store inference api (#598)
# What does this PR do? Moves all the memory providers to use the inference API and improved the memory tests to setup the inference stack correctly and use the embedding models ## Test Plan torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="Llama3.2-3B-Instruct" --embedding-model="sentence-transformers/all-MiniLM-L6-v2" llama_stack/providers/tests/inference/test_embeddings.py --env EMBEDDING_DIMENSION=384 pytest -v -s llama_stack/providers/tests/memory/test_memory.py --providers="inference=together,memory=weaviate" --embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" --env EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY> --env WEAVIATE_API_KEY=foo --env WEAVIATE_CLUSTER_URL=bar pytest -v -s llama_stack/providers/tests/memory/test_memory.py --providers="inference=together,memory=chroma" --embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" --env EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY>--env CHROMA_HOST=localhost --env CHROMA_PORT=8000 pytest -v -s llama_stack/providers/tests/memory/test_memory.py --providers="inference=together,memory=pgvector" --embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" --env PGVECTOR_DB=postgres --env PGVECTOR_USER=postgres --env PGVECTOR_PASSWORD=mysecretpassword --env PGVECTOR_HOST=0.0.0.0 --env EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY> pytest -v -s llama_stack/providers/tests/memory/test_memory.py --providers="inference=together,memory=faiss" --embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" --env EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY>
This commit is contained in:
parent
db7b26a8c9
commit
4f8b73b9e1
15 changed files with 235 additions and 118 deletions
|
|
@ -4,12 +4,16 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import PGVectorConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: PGVectorConfig, _deps):
|
||||
async def get_adapter_impl(config: PGVectorConfig, deps: Dict[Api, ProviderSpec]):
|
||||
from .pgvector import PGVectorMemoryAdapter
|
||||
|
||||
impl = PGVectorMemoryAdapter(config)
|
||||
impl = PGVectorMemoryAdapter(config, deps[Api.inference])
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -16,9 +16,9 @@ from pydantic import BaseModel, parse_obj_as
|
|||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
||||
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
|
||||
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ALL_MINILM_L6_V2_DIMENSION,
|
||||
BankWithIndex,
|
||||
EmbeddingIndex,
|
||||
)
|
||||
|
|
@ -120,8 +120,9 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
|
||||
|
||||
class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||
def __init__(self, config: PGVectorConfig) -> None:
|
||||
def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.cursor = None
|
||||
self.conn = None
|
||||
self.cache = {}
|
||||
|
|
@ -160,27 +161,17 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
memory_bank: MemoryBank,
|
||||
) -> None:
|
||||
async def register_memory_bank(self, memory_bank: MemoryBank) -> None:
|
||||
assert (
|
||||
memory_bank.memory_bank_type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
|
||||
|
||||
upsert_models(
|
||||
self.cursor,
|
||||
[
|
||||
(memory_bank.identifier, memory_bank),
|
||||
],
|
||||
upsert_models(self.cursor, [(memory_bank.identifier, memory_bank)])
|
||||
index = PGVectorIndex(memory_bank, memory_bank.embedding_dimension, self.cursor)
|
||||
self.cache[memory_bank.identifier] = BankWithIndex(
|
||||
memory_bank, index, self.inference_api
|
||||
)
|
||||
|
||||
index = BankWithIndex(
|
||||
bank=memory_bank,
|
||||
index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
|
||||
)
|
||||
self.cache[memory_bank.identifier] = index
|
||||
|
||||
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
|
||||
await self.cache[memory_bank_id].index.delete()
|
||||
del self.cache[memory_bank_id]
|
||||
|
|
@ -190,8 +181,9 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
for bank in banks:
|
||||
if bank.identifier not in self.cache:
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
|
||||
bank,
|
||||
PGVectorIndex(bank, bank.embedding_dimension, self.cursor),
|
||||
self.inference_api,
|
||||
)
|
||||
self.cache[bank.identifier] = index
|
||||
return banks
|
||||
|
|
@ -214,14 +206,13 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
index = await self._get_and_cache_bank_index(bank_id)
|
||||
return await index.query_documents(query, params)
|
||||
|
||||
self.inference_api = inference_api
|
||||
|
||||
async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
|
||||
if bank_id in self.cache:
|
||||
return self.cache[bank_id]
|
||||
|
||||
bank = await self.memory_bank_store.get_memory_bank(bank_id)
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
|
||||
)
|
||||
self.cache[bank_id] = index
|
||||
return index
|
||||
index = PGVectorIndex(bank, bank.embedding_dimension, self.cursor)
|
||||
self.cache[bank_id] = BankWithIndex(bank, index, self.inference_api)
|
||||
return self.cache[bank_id]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue