get multiple providers working for meta-reference (for inference + safety)

This commit is contained in:
Ashwin Bharambe 2024-11-04 16:33:42 -08:00
parent 60800bc09b
commit 6c7ea6e904
10 changed files with 136 additions and 95 deletions

View file

@ -22,39 +22,45 @@ from ..env import get_env_or_fail
@pytest.fixture(scope="session")
def memory_meta_reference() -> ProviderFixture:
return ProviderFixture(
provider=Provider(
provider_id="meta-reference",
provider_type="meta-reference",
config=FaissImplConfig().model_dump(),
),
providers=[
Provider(
provider_id="meta-reference",
provider_type="meta-reference",
config=FaissImplConfig().model_dump(),
)
],
)
@pytest.fixture(scope="session")
def memory_pgvector() -> ProviderFixture:
return ProviderFixture(
provider=Provider(
provider_id="pgvector",
provider_type="remote::pgvector",
config=PGVectorConfig(
host=os.getenv("PGVECTOR_HOST", "localhost"),
port=os.getenv("PGVECTOR_PORT", 5432),
db=get_env_or_fail("PGVECTOR_DB"),
user=get_env_or_fail("PGVECTOR_USER"),
password=get_env_or_fail("PGVECTOR_PASSWORD"),
).model_dump(),
),
providers=[
Provider(
provider_id="pgvector",
provider_type="remote::pgvector",
config=PGVectorConfig(
host=os.getenv("PGVECTOR_HOST", "localhost"),
port=os.getenv("PGVECTOR_PORT", 5432),
db=get_env_or_fail("PGVECTOR_DB"),
user=get_env_or_fail("PGVECTOR_USER"),
password=get_env_or_fail("PGVECTOR_PASSWORD"),
).model_dump(),
)
],
)
@pytest.fixture(scope="session")
def memory_weaviate() -> ProviderFixture:
return ProviderFixture(
provider=Provider(
provider_id="weaviate",
provider_type="remote::weaviate",
config=WeaviateConfig().model_dump(),
),
providers=[
Provider(
provider_id="weaviate",
provider_type="remote::weaviate",
config=WeaviateConfig().model_dump(),
)
],
provider_data=dict(
weaviate_api_key=get_env_or_fail("WEAVIATE_API_KEY"),
weaviate_cluster_url=get_env_or_fail("WEAVIATE_CLUSTER_URL"),
@ -72,7 +78,7 @@ async def memory_stack(request):
impls = await resolve_impls_for_test_v2(
[Api.memory],
{"memory": [fixture.provider.model_dump()]},
{"memory": fixture.providers},
fixture.provider_data,
)