diff --git a/tests/integration/scoring/test_scoring.py b/tests/integration/scoring/test_scoring.py
index b695c2ef7..a08664990 100644
--- a/tests/integration/scoring/test_scoring.py
+++ b/tests/integration/scoring/test_scoring.py
@@ -15,14 +15,68 @@ def sample_judge_prompt_template():
     return "Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9."
 
 
+@pytest.fixture
+def sample_scoring_fn_id():
+    return "llm-as-judge-test-prompt"
+
+
+def register_scoring_function(
+    llama_stack_client,
+    provider_id,
+    scoring_fn_id,
+    judge_model_id,
+    judge_prompt_template,
+):
+    llama_stack_client.scoring_functions.register(
+        scoring_fn_id=scoring_fn_id,
+        provider_id=provider_id,
+        description="LLM as judge scoring function with test prompt",
+        return_type={
+            "type": "string",
+        },
+        params={
+            "type": "llm_as_judge",
+            "judge_model": judge_model_id,
+            "prompt_template": judge_prompt_template,
+        },
+    )
+
+
 def test_scoring_functions_list(llama_stack_client):
-    # NOTE: this needs you to ensure that you are starting from a clean state
-    # but so far we don't have an unregister API unfortunately, so be careful
     response = llama_stack_client.scoring_functions.list()
     assert isinstance(response, list)
     assert len(response) > 0
 
 
+def test_scoring_functions_register(
+    llama_stack_client,
+    sample_scoring_fn_id,
+    judge_model_id,
+    sample_judge_prompt_template,
+):
+    llm_as_judge_provider = [
+        x
+        for x in llama_stack_client.providers.list()
+        if x.api == "scoring" and x.provider_type == "inline::llm-as-judge"
+    ]
+    if len(llm_as_judge_provider) == 0:
+        pytest.skip("No llm-as-judge provider found, cannot test registration")
+
+    llm_as_judge_provider_id = llm_as_judge_provider[0].provider_id
+    register_scoring_function(
+        llama_stack_client,
+        llm_as_judge_provider_id,
+        sample_scoring_fn_id,
+        judge_model_id,
+        sample_judge_prompt_template,
+    )
+
+    list_response = llama_stack_client.scoring_functions.list()
+    assert isinstance(list_response, list)
+    assert len(list_response) > 0
+    assert any(x.identifier == sample_scoring_fn_id for x in list_response)
+
+
 def test_scoring_score(llama_stack_client):
     register_dataset(llama_stack_client, for_rag=True)
     response = llama_stack_client.datasets.list()
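
Note: the new test only asserts that the registered function appears in the list response. A natural follow-up would be to actually invoke it through the scoring API. The fragment below is a minimal sketch of that, not part of the diff above; it reuses the scoring.score() call shape already present in test_scoring_score in this file, and the input row keys (input_query, generated_answer) are illustrative assumptions:

    # Sketch only: exercise the scoring function registered above.
    # The input row shown here is hypothetical; real rows would come from
    # a registered dataset as in test_scoring_score.
    response = llama_stack_client.scoring.score(
        input_rows=[{"input_query": "What is 2 + 2?", "generated_answer": "4"}],
        scoring_functions={"llm-as-judge-test-prompt": None},  # None = params given at registration
    )
    assert "llm-as-judge-test-prompt" in response.results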