mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
This PR does the following: 1) adds the ability to generate embeddings in all supported inference providers; 2) moves all the memory providers to use the inference API and improves the memory tests to set up the inference stack correctly and use the embedding models. This is a merge from #589 and #598.
136 lines
4.3 KiB
Python
136 lines
4.3 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import os
|
|
import tempfile
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
|
|
from llama_stack.apis.inference import ModelInput, ModelType
|
|
|
|
from llama_stack.distribution.datatypes import Api, Provider
|
|
from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
|
|
from llama_stack.providers.inline.memory.faiss import FaissImplConfig
|
|
from llama_stack.providers.remote.memory.chroma import ChromaRemoteImplConfig
|
|
from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
|
|
from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
|
|
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
|
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
|
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
|
from ..env import get_env_or_fail
|
|
|
|
|
|
@pytest.fixture(scope="session")
def memory_remote() -> ProviderFixture:
    """Session-scoped memory fixture that delegates to the shared remote-stack helper.

    All configuration comes from ``remote_stack_fixture`` (imported from the
    parent conftest); nothing memory-specific is added here.
    """
    remote_fixture = remote_stack_fixture()
    return remote_fixture
|
|
|
|
|
|
@pytest.fixture(scope="session")
def memory_faiss() -> ProviderFixture:
    """Provide an inline FAISS memory provider whose kvstore is SQLite-backed.

    The SQLite database lives in a freshly created temporary ``.db`` file.
    Only the file's path is needed, so the handle opened by
    ``NamedTemporaryFile`` is closed immediately. ``delete=False`` keeps the
    file on disk after closing, so the provider can open it by path.

    NOTE: the temporary file is not removed at session teardown.
    """
    # Close the handle right away. Previously it stayed open (and leaked) for
    # the whole session. On some platforms (notably Windows) a path may not be
    # reopenable while a NamedTemporaryFile handle is still open.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as temp_file:
        db_path = temp_file.name

    return ProviderFixture(
        providers=[
            Provider(
                provider_id="faiss",
                provider_type="inline::faiss",
                config=FaissImplConfig(
                    kvstore=SqliteKVStoreConfig(db_path=db_path).model_dump(),
                ).model_dump(),
            )
        ],
    )
|
|
|
|
|
|
@pytest.fixture(scope="session")
def memory_pgvector() -> ProviderFixture:
    """Provide a remote pgvector memory provider configured from the environment.

    ``PGVECTOR_HOST`` (default ``"localhost"``) and ``PGVECTOR_PORT``
    (default ``5432``) are optional. ``PGVECTOR_DB``, ``PGVECTOR_USER`` and
    ``PGVECTOR_PASSWORD`` are read through ``get_env_or_fail``.
    """
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="pgvector",
                provider_type="remote::pgvector",
                config=PGVectorConfig(
                    host=os.getenv("PGVECTOR_HOST", "localhost"),
                    # os.getenv returns a str when the variable is set but the
                    # default otherwise. Normalize so the port is always an int.
                    port=int(os.getenv("PGVECTOR_PORT", "5432")),
                    db=get_env_or_fail("PGVECTOR_DB"),
                    user=get_env_or_fail("PGVECTOR_USER"),
                    password=get_env_or_fail("PGVECTOR_PASSWORD"),
                ).model_dump(),
            )
        ],
    )
|
|
|
|
|
|
@pytest.fixture(scope="session")
def memory_weaviate() -> ProviderFixture:
    """Provide a remote Weaviate memory provider.

    The provider itself uses a default ``WeaviateConfig``. The API key and
    cluster URL are passed separately as provider data, read from
    ``WEAVIATE_API_KEY`` and ``WEAVIATE_CLUSTER_URL`` via ``get_env_or_fail``.
    """
    weaviate_provider = Provider(
        provider_id="weaviate",
        provider_type="remote::weaviate",
        config=WeaviateConfig().model_dump(),
    )
    weaviate_credentials = {
        "weaviate_api_key": get_env_or_fail("WEAVIATE_API_KEY"),
        "weaviate_cluster_url": get_env_or_fail("WEAVIATE_CLUSTER_URL"),
    }
    return ProviderFixture(
        providers=[weaviate_provider],
        provider_data=weaviate_credentials,
    )
|
|
|
|
|
|
@pytest.fixture(scope="session")
def memory_chroma() -> ProviderFixture:
    """Provide a Chroma memory provider, remote or inline depending on the environment.

    If ``CHROMA_URL`` is set (non-empty), the remote ``remote::chromadb``
    provider is used. Otherwise ``CHROMA_DB_PATH`` must be set and the inline
    ``inline::chromadb`` provider is used.

    Raises:
        ValueError: if neither ``CHROMA_URL`` nor ``CHROMA_DB_PATH`` is set.
    """
    chroma_url = os.getenv("CHROMA_URL")
    if not chroma_url:
        chroma_db_path = os.getenv("CHROMA_DB_PATH")
        if not chroma_db_path:
            raise ValueError("CHROMA_DB_PATH or CHROMA_URL must be set")
        chroma_config = ChromaInlineImplConfig(db_path=chroma_db_path)
        chroma_provider_type = "inline::chromadb"
    else:
        chroma_config = ChromaRemoteImplConfig(url=chroma_url)
        chroma_provider_type = "remote::chromadb"

    chroma_provider = Provider(
        provider_id="chroma",
        provider_type=chroma_provider_type,
        config=chroma_config.model_dump(),
    )
    return ProviderFixture(providers=[chroma_provider])
|
|
|
|
|
|
# Name suffixes of the ``memory_<name>`` fixtures defined in this module.
MEMORY_FIXTURES = [
    "faiss",
    "pgvector",
    "weaviate",
    "remote",
    "chroma",
]
|
|
|
|
|
|
@pytest_asyncio.fixture(scope="session")
async def memory_stack(inference_model, request):
    """Build a test stack that wires a memory provider to an inference provider.

    ``request.param`` is a dict whose ``"inference"`` and ``"memory"`` entries
    name which ``<api>_<name>`` fixtures to load, e.g. ``memory_faiss``.
    (It is presumably supplied via indirect parametrization; verify against
    the conftest.) The providers and any provider data from those fixtures
    are merged into a stack serving ``Api.memory`` and ``Api.inference``.

    ``inference_model`` is registered as an *embedding* model. Its
    ``embedding_dimension`` metadata comes from the ``EMBEDDING_DIMENSION``
    environment variable.

    Returns:
        A ``(memory_impl, memory_banks_impl)`` tuple taken from the stack.
    """
    fixture_dict = request.param

    providers = {}
    provider_data = {}
    for key in ["inference", "memory"]:
        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
        providers[key] = fixture.providers
        if fixture.provider_data:
            provider_data.update(fixture.provider_data)

    # Environment variables are always strings, but the embedding dimension is
    # a vector size. Pass it as an int rather than relying on downstream
    # coercion.
    embedding_dimension = int(get_env_or_fail("EMBEDDING_DIMENSION"))

    test_stack = await construct_stack_for_test(
        [Api.memory, Api.inference],
        providers,
        provider_data,
        models=[
            ModelInput(
                model_id=inference_model,
                model_type=ModelType.embedding_model,
                metadata={
                    "embedding_dimension": embedding_dimension,
                },
            )
        ],
    )

    return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks]
|