mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-31 03:50:00 +00:00
Updated mongodb.py to use logging instead of print; added documentation for the MongoDB vector search module; changed insert to update in MongoDB; fixed a MongoDB JSON object conversion error.
196 lines
6.4 KiB
Python
196 lines
6.4 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import os
|
|
import tempfile
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
|
|
from llama_stack.apis.models import ModelInput, ModelType
|
|
from llama_stack.distribution.datatypes import Api, Provider
|
|
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
|
|
from llama_stack.providers.inline.vector_io.faiss import FaissVectorIOConfig
|
|
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
|
|
from llama_stack.providers.remote.vector_io.chroma import ChromaVectorIOConfig
|
|
from llama_stack.providers.remote.vector_io.pgvector import PGVectorVectorIOConfig
|
|
from llama_stack.providers.remote.vector_io.qdrant import QdrantVectorIOConfig
|
|
from llama_stack.providers.remote.vector_io.weaviate import WeaviateVectorIOConfig
|
|
from llama_stack.providers.remote.vector_io.mongodb import MongoDBVectorIOConfig
|
|
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
|
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
|
|
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
|
from ..env import get_env_or_fail
|
|
|
|
|
|
@pytest.fixture(scope="session")
def embedding_model(request):
    """Resolve the embedding model id used throughout the test session.

    A value supplied through fixture parametrization (``request.param``) takes
    precedence; otherwise the ``--embedding-model`` command-line option is
    consulted, with ``None`` as the fallback default.
    """
    if not hasattr(request, "param"):
        return request.config.getoption("--embedding-model", None)
    return request.param
|
|
|
|
|
|
@pytest.fixture(scope="session")
def vector_io_remote() -> ProviderFixture:
    """Provider fixture delegating to the shared ``remote_stack_fixture`` helper."""
    fixture = remote_stack_fixture()
    return fixture
|
|
|
|
|
|
def _temp_db_path(suffix: str = ".db") -> str:
    """Create a persistent temporary file and return its path.

    The file handle is closed immediately rather than left for the garbage
    collector (avoiding a leaked descriptor / ResourceWarning). Because of
    ``delete=False`` the file itself survives the close; consumers below only
    need its path. The file is not removed when the session ends.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        return temp_file.name


@pytest.fixture(scope="session")
def vector_io_faiss() -> ProviderFixture:
    """Inline FAISS provider whose kvstore is a fresh temporary SQLite file."""
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="faiss",
                provider_type="inline::faiss",
                config=FaissVectorIOConfig(
                    kvstore=SqliteKVStoreConfig(db_path=_temp_db_path()).model_dump(),
                ).model_dump(),
            )
        ],
    )


@pytest.fixture(scope="session")
def vector_io_sqlite_vec() -> ProviderFixture:
    """Inline sqlite_vec provider whose kvstore is a fresh temporary SQLite file."""
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="sqlite_vec",
                provider_type="inline::sqlite_vec",
                config=SQLiteVectorIOConfig(
                    kvstore=SqliteKVStoreConfig(db_path=_temp_db_path()).model_dump(),
                ).model_dump(),
            )
        ],
    )
|
|
|
|
|
|
@pytest.fixture(scope="session")
def vector_io_pgvector() -> ProviderFixture:
    """Remote pgvector provider configured from ``PGVECTOR_*`` environment variables.

    ``PGVECTOR_HOST`` and ``PGVECTOR_PORT`` are optional and default to
    ``localhost`` and ``5432``. ``PGVECTOR_DB``, ``PGVECTOR_USER`` and
    ``PGVECTOR_PASSWORD`` are read via ``get_env_or_fail`` (presumably failing
    the test when unset).
    """
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="pgvector",
                provider_type="remote::pgvector",
                config=PGVectorVectorIOConfig(
                    host=os.getenv("PGVECTOR_HOST", "localhost"),
                    # os.getenv returns a str whenever the variable is set, so an
                    # int default produced mixed types; normalize to int here.
                    # A non-numeric PGVECTOR_PORT raises ValueError.
                    port=int(os.getenv("PGVECTOR_PORT", "5432")),
                    db=get_env_or_fail("PGVECTOR_DB"),
                    user=get_env_or_fail("PGVECTOR_USER"),
                    password=get_env_or_fail("PGVECTOR_PASSWORD"),
                ).model_dump(),
            )
        ],
    )
|
|
|
|
|
|
@pytest.fixture(scope="session")
def vector_io_weaviate() -> ProviderFixture:
    """Remote Weaviate provider; credentials travel as per-request provider data.

    The provider config itself is the default ``WeaviateVectorIOConfig``; the
    API key and cluster URL come from ``WEAVIATE_API_KEY`` and
    ``WEAVIATE_CLUSTER_URL`` via ``get_env_or_fail``.
    """
    weaviate_provider = Provider(
        provider_id="weaviate",
        provider_type="remote::weaviate",
        config=WeaviateVectorIOConfig().model_dump(),
    )
    credentials = {
        "weaviate_api_key": get_env_or_fail("WEAVIATE_API_KEY"),
        "weaviate_cluster_url": get_env_or_fail("WEAVIATE_CLUSTER_URL"),
    }
    return ProviderFixture(providers=[weaviate_provider], provider_data=credentials)
|
|
|
|
|
|
@pytest.fixture(scope="session")
def vector_io_chroma() -> ProviderFixture:
    """Chroma provider: remote when ``CHROMA_URL`` is set, otherwise inline.

    The inline variant requires ``CHROMA_DB_PATH``; if neither variable is set
    (or both are empty), a ``ValueError`` is raised.
    """
    chroma_url = os.getenv("CHROMA_URL")
    chroma_db_path = os.getenv("CHROMA_DB_PATH")

    if chroma_url:
        chroma_config = ChromaVectorIOConfig(url=chroma_url)
        chroma_type = "remote::chromadb"
    elif chroma_db_path:
        chroma_config = InlineChromaVectorIOConfig(db_path=chroma_db_path)
        chroma_type = "inline::chromadb"
    else:
        raise ValueError("CHROMA_DB_PATH or CHROMA_URL must be set")

    chroma_provider = Provider(
        provider_id="chroma",
        provider_type=chroma_type,
        config=chroma_config.model_dump(),
    )
    return ProviderFixture(providers=[chroma_provider])
|
|
|
|
|
|
@pytest.fixture(scope="session")
def vector_io_qdrant() -> ProviderFixture:
    """Remote Qdrant provider; ``QDRANT_URL`` must be set (non-empty)."""
    qdrant_url = os.getenv("QDRANT_URL")
    if not qdrant_url:
        raise ValueError("QDRANT_URL must be set")

    qdrant_config = QdrantVectorIOConfig(url=qdrant_url)
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="qdrant",
                provider_type="remote::qdrant",
                config=qdrant_config.model_dump(),
            )
        ]
    )
|
|
|
|
@pytest.fixture(scope="session")
def vector_io_mongodb() -> ProviderFixture:
    """Remote MongoDB provider configured from ``MONGODB_*`` environment variables.

    ``MONGODB_CONNECTION_STR`` and ``MONGODB_NAMESPACE`` are both read via
    ``get_env_or_fail``.
    """
    mongodb_config = MongoDBVectorIOConfig(
        connection_str=get_env_or_fail("MONGODB_CONNECTION_STR"),
        namespace=get_env_or_fail("MONGODB_NAMESPACE"),
    )
    mongodb_provider = Provider(
        provider_id="mongodb",
        provider_type="remote::mongodb",
        config=mongodb_config.model_dump(),
    )
    return ProviderFixture(providers=[mongodb_provider])
|
|
|
|
# Suffixes of the ``vector_io_<name>`` provider fixtures defined in this module
# (``vector_io_stack`` resolves fixtures as f"vector_io_{name}"). Note that
# "remote" is deliberately or accidentally absent — confirm with callers.
VECTOR_IO_FIXTURES = [
    "faiss",
    "pgvector",
    "weaviate",
    "chroma",
    "qdrant",
    "sqlite_vec",
    "mongodb",
]
|
|
|
|
|
|
@pytest_asyncio.fixture(scope="session")
async def vector_io_stack(embedding_model, request):
    """Build a test stack wiring the requested inference and vector_io providers.

    ``request.param`` must be a mapping with ``"inference"`` and ``"vector_io"``
    keys; each value is a fixture-name suffix, so ``{"vector_io": "faiss"}``
    resolves the ``vector_io_faiss`` fixture. Provider data from both fixtures
    is merged with ``dict.update`` (vector_io entries override inference ones).
    The embedding model is registered with its dimension taken from the
    ``EMBEDDING_DIMENSION`` environment variable.

    Returns:
        A ``(vector_io_impl, vector_dbs_impl)`` tuple from the constructed stack.

    Raises:
        ValueError: if ``EMBEDDING_DIMENSION`` is not an integer string.
    """
    fixture_dict = request.param

    providers = {}
    provider_data = {}
    for key in ["inference", "vector_io"]:
        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
        providers[key] = fixture.providers
        if fixture.provider_data:
            provider_data.update(fixture.provider_data)

    test_stack = await construct_stack_for_test(
        [Api.vector_io, Api.inference],
        providers,
        provider_data,
        models=[
            ModelInput(
                model_id=embedding_model,
                model_type=ModelType.embedding,
                metadata={
                    # Environment variables are always strings, but the metadata
                    # dict is untyped; convert here so the dimension is an int
                    # rather than relying on downstream coercion.
                    "embedding_dimension": int(get_env_or_fail("EMBEDDING_DIMENSION")),
                },
            )
        ],
    )

    # Api.vector_dbs is not requested above; it is assumed to be added to
    # impls as the routing table paired with Api.vector_io — confirm in
    # construct_stack_for_test.
    return test_stack.impls[Api.vector_io], test_stack.impls[Api.vector_dbs]
|