Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-12 04:50:39 +00:00

commit 546a417b09: fix scoring
parent f2464050c7
2 changed files with 6 additions and 5 deletions
@@ -6,7 +6,7 @@
 import re
 from typing import Any, Dict, Optional
 
-from llama_stack.apis.inference.inference import Inference
+from llama_stack.apis.inference.inference import Inference, UserMessage
 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams
 from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
@@ -58,10 +58,9 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
         judge_response = await self.inference_api.chat_completion(
             model_id=fn_def.params.judge_model,
             messages=[
-                {
-                    "role": "user",
-                    "content": judge_input_msg,
-                }
+                UserMessage(
+                    content=judge_input_msg,
+                ),
             ],
         )
         content = judge_response.completion_message.content
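The substance of the fix: this provider's chat_completion path expects typed message objects, so the raw {"role": "user", ...} dict is replaced with a UserMessage. Below is a minimal sketch of the corrected call pattern, using only names that appear in the diff; the helper function itself and its arguments are stand-ins for the values computed in the surrounding method, not part of the commit.

from llama_stack.apis.inference.inference import UserMessage

async def score_with_judge(inference_api, judge_model, judge_input_msg):
    # Build the judge prompt as a typed UserMessage rather than a raw
    # {"role": "user", "content": ...} dict, mirroring the change above.
    judge_response = await inference_api.chat_completion(
        model_id=judge_model,
        messages=[
            UserMessage(content=judge_input_msg),
        ],
    )
    # The judge's verdict is read from the completion message, as in the diff.
    return judge_response.completion_message.content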
@@ -76,6 +76,8 @@ def test_scoring_functions_register(
     assert len(list_response) > 0
     assert any(x.identifier == sample_scoring_fn_id for x in list_response)
 
+    # TODO: add unregister to make clean state
+
 
 def test_scoring_score(llama_stack_client):
     register_dataset(llama_stack_client, for_rag=True)
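The new TODO is about test isolation: test_scoring_functions_register leaves the sample scoring function registered, so later tests do not start from a clean state. A hypothetical teardown fixture is sketched below; scoring_functions.unregister is an assumed call, not an API llama-stack exposes at this commit, which is exactly what the TODO points out.

import pytest

@pytest.fixture
def clean_scoring_state(llama_stack_client, sample_scoring_fn_id):
    # Run the test first, then tear down whatever it registered.
    yield
    # Hypothetical cleanup: no such unregister endpoint exists at this commit,
    # hence the "# TODO: add unregister to make clean state" in the diff.
    llama_stack_client.scoring_functions.unregister(
        scoring_fn_id=sample_scoring_fn_id,
    )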