diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index bcf125bec..4c5bdf654 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -248,7 +248,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring): async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]: - return await self.get_all_with_type("scoring_function") + return await self.get_all_with_type("scoring_fn") async def get_scoring_function( self, name: str diff --git a/llama_stack/providers/tests/scoring/fixtures.py b/llama_stack/providers/tests/scoring/fixtures.py index 470fad215..337838d8e 100644 --- a/llama_stack/providers/tests/scoring/fixtures.py +++ b/llama_stack/providers/tests/scoring/fixtures.py @@ -53,5 +53,8 @@ async def scoring_stack(request): provider_data, ) - print(impls) - return impls[Api.scoring], impls[Api.scoring_functions] + return ( + impls[Api.scoring], + impls[Api.scoring_functions], + impls[Api.datasets], + ) diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py index 1b50cbc38..d1518d2e3 100644 --- a/llama_stack/providers/tests/scoring/test_scoring.py +++ b/llama_stack/providers/tests/scoring/test_scoring.py @@ -7,6 +7,8 @@ import pytest +from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset + # How to run this test: # # pytest llama_stack/providers/tests/scoring/test_scoring.py @@ -19,7 +21,14 @@ class TestScoring: async def test_scoring_functions_list(self, scoring_stack): # NOTE: this needs you to ensure that you are starting from a clean state # but so far we don't have an unregister API unfortunately, so be careful - _, scoring_functions_impl = scoring_stack - # response = await datasets_impl.list_datasets() - # assert isinstance(response, list) - # assert len(response) == 0 + _, scoring_functions_impl, _ = scoring_stack + response = await scoring_functions_impl.list_scoring_functions() + assert isinstance(response, list) + assert len(response) > 0 + + @pytest.mark.asyncio + async def test_scoring_score(self, scoring_stack): + scoring_impl, scoring_functions_impl, datasets_impl = scoring_stack + await register_dataset(datasets_impl) + response = await datasets_impl.list_datasets() + assert len(response) == 1