Vector store inference api (#598)

# What does this PR do?
Moves all the memory providers to use the inference API and improved the
memory tests to setup the inference stack correctly and use the
embedding models


## Test Plan
torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference"
--inference-model="Llama3.2-3B-Instruct"
--embedding-model="sentence-transformers/all-MiniLM-L6-v2"
llama_stack/providers/tests/inference/test_embeddings.py --env
EMBEDDING_DIMENSION=384


pytest -v -s llama_stack/providers/tests/memory/test_memory.py
--providers="inference=together,memory=weaviate"
--embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" --env
EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY> --env
WEAVIATE_API_KEY=foo --env WEAVIATE_CLUSTER_URL=bar
 
pytest -v -s llama_stack/providers/tests/memory/test_memory.py
--providers="inference=together,memory=chroma"
--embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" --env
EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY>--env
CHROMA_HOST=localhost --env CHROMA_PORT=8000

pytest -v -s llama_stack/providers/tests/memory/test_memory.py
--providers="inference=together,memory=pgvector"
--embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" --env
PGVECTOR_DB=postgres --env PGVECTOR_USER=postgres --env
PGVECTOR_PASSWORD=mysecretpassword --env PGVECTOR_HOST=0.0.0.0 --env
EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY>

pytest -v -s llama_stack/providers/tests/memory/test_memory.py
--providers="inference=together,memory=faiss"
--embedding-model="togethercomputer/m2-bert-80M-2k-retrieval" --env
EMBEDDING_DIMENSION=768 --env TOGETHER_API_KEY=<API-KEY>
This commit is contained in:
Dinesh Yeduguru 2024-12-12 11:16:54 -08:00 committed by GitHub
parent db7b26a8c9
commit 4f8b73b9e1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 235 additions and 118 deletions

View file

@ -39,6 +39,7 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.inline.memory.faiss",
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
deprecation_warning="Please use the `inline::faiss` provider instead.",
api_dependencies=[Api.inference],
),
InlineProviderSpec(
api=Api.memory,
@ -46,6 +47,7 @@ def available_providers() -> List[ProviderSpec]:
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
module="llama_stack.providers.inline.memory.faiss",
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
api_dependencies=[Api.inference],
),
remote_provider_spec(
Api.memory,
@ -55,6 +57,7 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.remote.memory.chroma",
config_class="llama_stack.distribution.datatypes.RemoteProviderConfig",
),
api_dependencies=[Api.inference],
),
remote_provider_spec(
Api.memory,
@ -64,6 +67,7 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.remote.memory.pgvector",
config_class="llama_stack.providers.remote.memory.pgvector.PGVectorConfig",
),
api_dependencies=[Api.inference],
),
remote_provider_spec(
Api.memory,
@ -74,6 +78,7 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.remote.memory.weaviate.WeaviateConfig",
provider_data_validator="llama_stack.providers.remote.memory.weaviate.WeaviateRequestProviderData",
),
api_dependencies=[Api.inference],
),
remote_provider_spec(
api=Api.memory,
@ -83,6 +88,7 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.remote.memory.sample",
config_class="llama_stack.providers.remote.memory.sample.SampleConfig",
),
api_dependencies=[],
),
remote_provider_spec(
Api.memory,
@ -92,5 +98,6 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.remote.memory.qdrant",
config_class="llama_stack.providers.remote.memory.qdrant.QdrantConfig",
),
api_dependencies=[Api.inference],
),
]