use provider resource id to validate for models

This commit is contained in:
Dinesh Yeduguru 2024-11-12 08:21:37 -08:00
parent e4f14eafe2
commit 95b7f57d92
7 changed files with 75 additions and 46 deletions

View file

@ -142,6 +142,31 @@ def inference_bedrock() -> ProviderFixture:
)
def get_model_short_name(model_name: str) -> str:
"""Convert model name to a short test identifier.
Args:
model_name: Full model name like "Llama3.1-8B-Instruct"
Returns:
Short name like "llama_8b" suitable for test markers
"""
model_name = model_name.lower()
if "vision" in model_name:
return "llama_vision"
elif "3b" in model_name:
return "llama_3b"
elif "8b" in model_name:
return "llama_8b"
else:
return model_name.replace(".", "_").replace("-", "_")
@pytest.fixture(scope="session")
def model_id(inference_model) -> str:
return get_model_short_name(inference_model)
INFERENCE_FIXTURES = [
"meta_reference",
"ollama",
@ -154,7 +179,7 @@ INFERENCE_FIXTURES = [
@pytest_asyncio.fixture(scope="session")
async def inference_stack(request, inference_model):
async def inference_stack(request, inference_model, model_id):
fixture_name = request.param
inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
impls = await resolve_impls_for_test_v2(
@ -163,7 +188,7 @@ async def inference_stack(request, inference_model):
inference_fixture.provider_data,
models=[
ModelInput(
model_id=inference_model,
model_id=model_id,
)
],
)