Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-01 16:24:44 +00:00)
fix scoring fixture

commit 04d8660247
parent d0ad198be9

1 changed file with 10 additions and 14 deletions
@@ -7,6 +7,8 @@
 import pytest
 import pytest_asyncio
 
+from llama_stack.apis.models import ModelInput
+
 from llama_stack.distribution.datatypes import Api, Provider
 
 from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
@@ -76,20 +78,14 @@ async def scoring_stack(request, inference_model):
         [Api.scoring, Api.datasetio, Api.inference],
         providers,
         provider_data,
-    )
-
-    provider_id = providers["inference"][0].provider_id
-    await impls[Api.models].register_model(
-        model_id=inference_model,
-        provider_id=provider_id,
-    )
-    await impls[Api.models].register_model(
-        model_id="Llama3.1-405B-Instruct",
-        provider_id=provider_id,
-    )
-    await impls[Api.models].register_model(
-        model_id="Llama3.1-8B-Instruct",
-        provider_id=provider_id,
+        models=[
+            ModelInput(model_id=model)
+            for model in [
+                inference_model,
+                "Llama3.1-405B-Instruct",
+                "Llama3.1-8B-Instruct",
+            ]
+        ],
     )
 
     return impls
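With this change, the scoring_stack fixture registers its models up front by passing ModelInput entries to resolve_impls_for_test_v2, instead of calling impls[Api.models].register_model on each model after resolution. The following is a minimal, hypothetical sketch (not part of this commit) of a test consuming the fixture; the test class, test name, and assertions are assumptions for illustration only.

# Hypothetical usage sketch (not part of this commit): the scoring_stack
# fixture returns the resolved impls dict keyed by Api, so a test can pull
# out the providers it needs directly.
import pytest

from llama_stack.distribution.datatypes import Api


class TestScoringStack:
    @pytest.mark.asyncio
    async def test_impls_are_resolved(self, scoring_stack):
        impls = scoring_stack
        # The fixture was built with [Api.scoring, Api.datasetio, Api.inference],
        # so those impls should be present after resolution.
        assert Api.scoring in impls
        assert Api.datasetio in impls
        assert Api.inference in impls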