mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
add registeration test
This commit is contained in:
parent
0385309862
commit
f2464050c7
1 changed files with 56 additions and 2 deletions
|
@ -15,14 +15,68 @@ def sample_judge_prompt_template():
|
||||||
return "Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9."
|
return "Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9."
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_scoring_fn_id():
|
||||||
|
return "llm-as-judge-test-prompt"
|
||||||
|
|
||||||
|
|
||||||
|
def register_scoring_function(
|
||||||
|
llama_stack_client,
|
||||||
|
provider_id,
|
||||||
|
scoring_fn_id,
|
||||||
|
judge_model_id,
|
||||||
|
judge_prompt_template,
|
||||||
|
):
|
||||||
|
llama_stack_client.scoring_functions.register(
|
||||||
|
scoring_fn_id=scoring_fn_id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
description="LLM as judge scoring function with test prompt",
|
||||||
|
return_type={
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
params={
|
||||||
|
"type": "llm_as_judge",
|
||||||
|
"judge_model": judge_model_id,
|
||||||
|
"prompt_template": judge_prompt_template,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_scoring_functions_list(llama_stack_client):
|
def test_scoring_functions_list(llama_stack_client):
|
||||||
# NOTE: this needs you to ensure that you are starting from a clean state
|
|
||||||
# but so far we don't have an unregister API unfortunately, so be careful
|
|
||||||
response = llama_stack_client.scoring_functions.list()
|
response = llama_stack_client.scoring_functions.list()
|
||||||
assert isinstance(response, list)
|
assert isinstance(response, list)
|
||||||
assert len(response) > 0
|
assert len(response) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_scoring_functions_register(
|
||||||
|
llama_stack_client,
|
||||||
|
sample_scoring_fn_id,
|
||||||
|
judge_model_id,
|
||||||
|
sample_judge_prompt_template,
|
||||||
|
):
|
||||||
|
llm_as_judge_provider = [
|
||||||
|
x
|
||||||
|
for x in llama_stack_client.providers.list()
|
||||||
|
if x.api == "scoring" and x.provider_type == "inline::llm-as-judge"
|
||||||
|
]
|
||||||
|
if len(llm_as_judge_provider) == 0:
|
||||||
|
pytest.skip("No llm-as-judge provider found, cannot test registeration")
|
||||||
|
|
||||||
|
llm_as_judge_provider_id = llm_as_judge_provider[0].provider_id
|
||||||
|
register_scoring_function(
|
||||||
|
llama_stack_client,
|
||||||
|
llm_as_judge_provider_id,
|
||||||
|
sample_scoring_fn_id,
|
||||||
|
judge_model_id,
|
||||||
|
sample_judge_prompt_template,
|
||||||
|
)
|
||||||
|
|
||||||
|
list_response = llama_stack_client.scoring_functions.list()
|
||||||
|
assert isinstance(list_response, list)
|
||||||
|
assert len(list_response) > 0
|
||||||
|
assert any(x.identifier == sample_scoring_fn_id for x in list_response)
|
||||||
|
|
||||||
|
|
||||||
def test_scoring_score(llama_stack_client):
|
def test_scoring_score(llama_stack_client):
|
||||||
register_dataset(llama_stack_client, for_rag=True)
|
register_dataset(llama_stack_client, for_rag=True)
|
||||||
response = llama_stack_client.datasets.list()
|
response = llama_stack_client.datasets.list()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue