diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py index 457151c04..f4e8ab0aa 100644 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py @@ -6,7 +6,7 @@ import re from typing import Any, Dict, Optional -from llama_stack.apis.inference.inference import Inference +from llama_stack.apis.inference.inference import Inference, UserMessage from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring_functions import ScoringFnParams from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn @@ -58,10 +58,9 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn): judge_response = await self.inference_api.chat_completion( model_id=fn_def.params.judge_model, messages=[ - { - "role": "user", - "content": judge_input_msg, - } + UserMessage( + content=judge_input_msg, + ), ], ) content = judge_response.completion_message.content diff --git a/tests/integration/scoring/test_scoring.py b/tests/integration/scoring/test_scoring.py index a08664990..2cb61303b 100644 --- a/tests/integration/scoring/test_scoring.py +++ b/tests/integration/scoring/test_scoring.py @@ -76,6 +76,8 @@ def test_scoring_functions_register( assert len(list_response) > 0 assert any(x.identifier == sample_scoring_fn_id for x in list_response) + # TODO: add unregister to make clean state + def test_scoring_score(llama_stack_client): register_dataset(llama_stack_client, for_rag=True)