mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 07:39:38 +00:00
fix scoring register
This commit is contained in:
parent
def6d5d8ad
commit
0351072531
3 changed files with 19 additions and 7 deletions
|
@ -248,7 +248,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
|
||||
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring):
|
||||
async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]:
|
||||
return await self.get_all_with_type("scoring_function")
|
||||
return await self.get_all_with_type("scoring_fn")
|
||||
|
||||
async def get_scoring_function(
|
||||
self, name: str
|
||||
|
|
|
@ -53,5 +53,8 @@ async def scoring_stack(request):
|
|||
provider_data,
|
||||
)
|
||||
|
||||
print(impls)
|
||||
return impls[Api.scoring], impls[Api.scoring_functions]
|
||||
return (
|
||||
impls[Api.scoring],
|
||||
impls[Api.scoring_functions],
|
||||
impls[Api.datasets],
|
||||
)
|
||||
|
|
|
@ -7,6 +7,8 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
# pytest llama_stack/providers/tests/scoring/test_scoring.py
|
||||
|
@ -19,7 +21,14 @@ class TestScoring:
|
|||
async def test_scoring_functions_list(self, scoring_stack):
|
||||
# NOTE: this needs you to ensure that you are starting from a clean state
|
||||
# but so far we don't have an unregister API unfortunately, so be careful
|
||||
_, scoring_functions_impl = scoring_stack
|
||||
# response = await datasets_impl.list_datasets()
|
||||
# assert isinstance(response, list)
|
||||
# assert len(response) == 0
|
||||
_, scoring_functions_impl, _ = scoring_stack
|
||||
response = await scoring_functions_impl.list_scoring_functions()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scoring_score(self, scoring_stack):
|
||||
scoring_impl, scoring_functions_impl, datasets_impl = scoring_stack
|
||||
await register_dataset(datasets_impl)
|
||||
response = await datasets_impl.list_datasets()
|
||||
assert len(response) == 1
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue