Vector store inference API (#598)

# What does this PR do?
Moves all the memory providers to use the inference API, and improves the
memory tests to set up the inference stack correctly and use the
embedding models.
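
In practice this means a memory provider no longer loads its own embedding model; it routes embedding generation through the shared inference stack. A minimal sketch of the call shape, assuming an `embeddings(model, contents)` method on the injected inference provider (the name is inferred from this diff, not quoted from the implementation):

```python
# Sketch only: the embeddings() signature is an assumption inferred from
# this PR, not the verbatim llama_stack inference API.
from typing import List


async def embed_chunks(inference_api, model: str, chunks: List[str]) -> List[List[float]]:
    # All memory providers now route embedding generation through the
    # shared inference stack instead of a bundled sentence-transformers model.
    response = await inference_api.embeddings(model=model, contents=chunks)
    return response.embeddings  # one vector per input chunk
```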


## Test Plan
```bash
torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" \
  --inference-model="Llama3.2-3B-Instruct" \
  --embedding-model="sentence-transformers/all-MiniLM-L6-v2" \
  llama_stack/providers/tests/inference/test_embeddings.py \
  --env EMBEDDING_DIMENSION=384
```


```bash
pytest -v -s llama_stack/providers/tests/memory/test_memory.py \
  --providers="inference=together,memory=weaviate" \
  --embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" \
  --env EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY> \
  --env WEAVIATE_API_KEY=foo --env WEAVIATE_CLUSTER_URL=bar
```
 
```bash
pytest -v -s llama_stack/providers/tests/memory/test_memory.py \
  --providers="inference=together,memory=chroma" \
  --embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" \
  --env EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY> \
  --env CHROMA_HOST=localhost --env CHROMA_PORT=8000
```

```bash
pytest -v -s llama_stack/providers/tests/memory/test_memory.py \
  --providers="inference=together,memory=pgvector" \
  --embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" \
  --env PGVECTOR_DB=postgres --env PGVECTOR_USER=postgres \
  --env PGVECTOR_PASSWORD=mysecretpassword --env PGVECTOR_HOST=0.0.0.0 \
  --env EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY>
```

```bash
pytest -v -s llama_stack/providers/tests/memory/test_memory.py \
  --providers="inference=together,memory=faiss" \
  --embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" \
  --env EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY>
```
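
Each run pins `EMBEDDING_DIMENSION` to the width of the chosen model (384 for `all-MiniLM-L6-v2`, 768 for `m2-bert-80M-2k-retrieval`), since stores such as pgvector allocate fixed-width vector columns. A hypothetical fixture showing how a test could pick the dimension up from the environment (illustrative; not the actual test code):

```python
# Hypothetical fixture: reads the vector width from the environment so the
# same memory tests can run against 384- or 768-dimensional models.
import os

import pytest


@pytest.fixture
def embedding_dimension() -> int:
    return int(os.environ.get("EMBEDDING_DIMENSION", "384"))
```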
Author: Dinesh Yeduguru · 2024-12-12 11:16:54 -08:00 · committed by GitHub
parent db7b26a8c9 · commit 4f8b73b9e1
15 changed files with 235 additions and 118 deletions


```diff
@@ -4,12 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Dict
+
 from llama_stack.distribution.datatypes import RemoteProviderConfig
+from llama_stack.providers.datatypes import Api, ProviderSpec
 
 
-async def get_adapter_impl(config: RemoteProviderConfig, _deps):
+async def get_adapter_impl(config: RemoteProviderConfig, deps: Dict[Api, ProviderSpec]):
     from .chroma import ChromaMemoryAdapter
 
-    impl = ChromaMemoryAdapter(config.url)
+    impl = ChromaMemoryAdapter(config.url, deps[Api.inference])
     await impl.initialize()
     return impl
```
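
The new `deps` parameter is how the adapter gets hold of the inference provider: the stack resolves each declared API dependency and passes the implementations in a dict keyed by `Api`. A simplified sketch of that lookup pattern (types and names illustrative; the real `Api` enum and provider specs live in `llama_stack.providers.datatypes`):

```python
# Simplified illustration of the dependency-injection shape; not the real
# registry plumbing from llama_stack.providers.datatypes.
from enum import Enum
from typing import Any, Dict


class Api(Enum):
    inference = "inference"
    memory = "memory"


class ExampleMemoryAdapter:
    def __init__(self, config: Any, inference_api: Any) -> None:
        # Store the resolved inference provider for later embedding calls.
        self.config = config
        self.inference_api = inference_api


async def get_adapter_impl(config: Any, deps: Dict[Api, Any]) -> ExampleMemoryAdapter:
    # deps maps each declared API dependency to its resolved implementation.
    return ExampleMemoryAdapter(config, deps[Api.inference])
```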


```diff
@@ -15,8 +15,7 @@ from numpy.typing import NDArray
 from pydantic import parse_obj_as
 
 from llama_stack.apis.memory import *  # noqa: F403
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
@@ -72,7 +71,7 @@ class ChromaIndex(EmbeddingIndex):
 
 
 class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
-    def __init__(self, url: str) -> None:
+    def __init__(self, url: str, inference_api: Api.inference) -> None:
         log.info(f"Initializing ChromaMemoryAdapter with url: {url}")
         url = url.rstrip("/")
         parsed = urlparse(url)
@@ -82,6 +81,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
 
         self.host = parsed.hostname
         self.port = parsed.port
+        self.inference_api = inference_api
 
         self.client = None
         self.cache = {}
@@ -109,10 +109,9 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
             name=memory_bank.identifier,
             metadata={"bank": memory_bank.model_dump_json()},
         )
-        bank_index = BankWithIndex(
-            bank=memory_bank, index=ChromaIndex(self.client, collection)
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank, ChromaIndex(self.client, collection), self.inference_api
         )
-        self.cache[memory_bank.identifier] = bank_index
 
     async def list_memory_banks(self) -> List[MemoryBank]:
         collections = await self.client.list_collections()
@@ -124,11 +123,11 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
                 log.exception(f"Failed to parse bank: {collection.metadata}")
                 continue
 
-            index = BankWithIndex(
-                bank=bank,
-                index=ChromaIndex(self.client, collection),
+            self.cache[bank.identifier] = BankWithIndex(
+                bank,
+                ChromaIndex(self.client, collection),
+                self.inference_api,
             )
-            self.cache[bank.identifier] = index
 
         return [i.bank for i in self.cache.values()]
@@ -166,6 +165,8 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         collection = await self.client.get_collection(bank_id)
         if not collection:
             raise ValueError(f"Bank {bank_id} not found in Chroma")
-        index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection))
+        index = BankWithIndex(
+            bank, ChromaIndex(self.client, collection), self.inference_api
+        )
         self.cache[bank_id] = index
         return index
```
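
The net effect is that `BankWithIndex` now carries the inference API alongside the bank and its index, so embedding happens in one shared place for every provider. A hedged sketch of the shape this diff implies (the real class lives in `llama_stack.providers.utils.memory.vector_store`; the method and field names below are assumptions):

```python
# Hedged sketch of BankWithIndex as implied by this diff; the actual class
# lives in llama_stack.providers.utils.memory.vector_store.
from dataclasses import dataclass
from typing import Any, List


@dataclass
class BankWithIndex:
    bank: Any           # registered memory bank (knows its embedding model)
    index: Any          # provider-specific EmbeddingIndex (Chroma, FAISS, ...)
    inference_api: Any  # injected inference provider

    async def insert_chunks(self, chunks: List[Any]) -> None:
        # Embeddings come from the shared inference API, so Chroma, pgvector,
        # Weaviate, and FAISS all embed text the same way.
        response = await self.inference_api.embeddings(
            model=self.bank.embedding_model,
            contents=[c.content for c in chunks],
        )
        await self.index.add_chunks(chunks, response.embeddings)
```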