address feedback

This commit is contained in:
Dinesh Yeduguru 2024-12-11 16:24:37 -08:00
parent e167e9eb93
commit 5821ec9ef3
12 changed files with 61 additions and 76 deletions

View file

@@ -37,13 +37,6 @@ def inference_model(request):
return request.config.getoption("--inference-model", None)
@pytest.fixture(scope="session")
def embedding_model(request):
    """Resolve the embedding model identifier for the test session.

    If ``request`` carries a ``param`` attribute, that value is returned
    as-is. Otherwise the ``--embedding-model`` command-line option is looked
    up on ``request.config``, with ``None`` passed as the fallback value.
    """
    return (
        request.param
        if hasattr(request, "param")
        else request.config.getoption("--embedding-model", None)
    )
@pytest.fixture(scope="session")
def inference_remote() -> ProviderFixture:
    """Session-scoped fixture that returns the result of ``remote_stack_fixture()``."""
    remote_fixture = remote_stack_fixture()
    return remote_fixture
@@ -239,31 +232,21 @@ INFERENCE_FIXTURES = [
async def inference_stack(request, inference_model):
fixture_name = request.param
inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
test_stack = await construct_stack_for_test(
[Api.inference],
{"inference": inference_fixture.providers},
inference_fixture.provider_data,
models=[ModelInput(model_id=inference_model)],
)
model_type = ModelType.llm
metadata = {}
if os.getenv("EMBEDDING_DIMENSION"):
model_type = ModelType.embedding_model
metadata["embedding_dimension"] = get_env_or_fail("EMBEDDING_DIMENSION")
return test_stack.impls[Api.inference], test_stack.impls[Api.models]
@pytest_asyncio.fixture(scope="session")
async def embedding_stack(request, embedding_model):
fixture_name = request.param
inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
test_stack = await construct_stack_for_test(
[Api.inference],
{"inference": inference_fixture.providers},
inference_fixture.provider_data,
models=[
ModelInput(
model_id=embedding_model,
model_type=ModelType.embedding_model,
metadata={
"embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
},
model_id=inference_model,
model_type=model_type,
metadata=metadata,
)
],
)