fix scoring

This commit is contained in:
Xi Yan 2025-03-05 16:05:39 -08:00
parent f2464050c7
commit 546a417b09
2 changed files with 6 additions and 5 deletions

View file

@ -6,7 +6,7 @@
import re import re
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from llama_stack.apis.inference.inference import Inference from llama_stack.apis.inference.inference import Inference, UserMessage
from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
@ -58,10 +58,9 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
judge_response = await self.inference_api.chat_completion( judge_response = await self.inference_api.chat_completion(
model_id=fn_def.params.judge_model, model_id=fn_def.params.judge_model,
messages=[ messages=[
{ UserMessage(
"role": "user", content=judge_input_msg,
"content": judge_input_msg, ),
}
], ],
) )
content = judge_response.completion_message.content content = judge_response.completion_message.content

View file

@ -76,6 +76,8 @@ def test_scoring_functions_register(
assert len(list_response) > 0 assert len(list_response) > 0
assert any(x.identifier == sample_scoring_fn_id for x in list_response) assert any(x.identifier == sample_scoring_fn_id for x in list_response)
# TODO: add unregister to make clean state
def test_scoring_score(llama_stack_client): def test_scoring_score(llama_stack_client):
register_dataset(llama_stack_client, for_rag=True) register_dataset(llama_stack_client, for_rag=True)