fix scoring fixture

This commit is contained in:
Ashwin Bharambe 2024-11-12 11:14:38 -08:00
parent d0ad198be9
commit 04d8660247

View file

@ -7,6 +7,8 @@
import pytest
import pytest_asyncio
from llama_stack.apis.models import ModelInput
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
@ -76,20 +78,14 @@ async def scoring_stack(request, inference_model):
[Api.scoring, Api.datasetio, Api.inference],
providers,
provider_data,
)
provider_id = providers["inference"][0].provider_id
await impls[Api.models].register_model(
model_id=inference_model,
provider_id=provider_id,
)
await impls[Api.models].register_model(
model_id="Llama3.1-405B-Instruct",
provider_id=provider_id,
)
await impls[Api.models].register_model(
model_id="Llama3.1-8B-Instruct",
provider_id=provider_id,
models=[
ModelInput(model_id=model)
for model in [
inference_model,
"Llama3.1-405B-Instruct",
"Llama3.1-8B-Instruct",
]
],
)
return impls