llama-stack/tests/integration/scoring/test_scoring.py
Xi Yan 5287b437ae
feat(api): (1/n) datasets api clean up (#1573)
## PR Stack
- https://github.com/meta-llama/llama-stack/pull/1573
- https://github.com/meta-llama/llama-stack/pull/1625
- https://github.com/meta-llama/llama-stack/pull/1656
- https://github.com/meta-llama/llama-stack/pull/1657
- https://github.com/meta-llama/llama-stack/pull/1658
- https://github.com/meta-llama/llama-stack/pull/1659
- https://github.com/meta-llama/llama-stack/pull/1660

**Client SDK**
- https://github.com/meta-llama/llama-stack-client-python/pull/203

**CI**
- 1391130488
<img width="1042" alt="image"
src="https://github.com/user-attachments/assets/69636067-376d-436b-9204-896e2dd490ca"
/>
-- `test_rag_agent_with_attachments` is flaky and unrelated to this PR

## Doc
<img width="789" alt="image"
src="https://github.com/user-attachments/assets/b88390f3-73d6-4483-b09a-a192064e32d9"
/>


## Client Usage
```python
client.datasets.register(
    source={
        "type": "uri",
        "uri": "lsfs://mydata.jsonl",
    },
    schema="jsonl_messages",
    # optional
    dataset_id="my_first_train_data",
)

# quick prototype debugging
client.datasets.register(
    data_reference={
        "type": "rows",
        "rows": [
                "messages": [...],
        ],
    },
    schema="jsonl_messages",
)
```
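
To double-check a registration from the client side, something like the sketch below should work (hedged: it assumes `client.datasets.list()` is exposed by the client SDK and that registered datasets carry an `identifier` field):

```python
# Hedged sketch: confirm the dataset registered above is visible to the client.
# Assumes client.datasets.list() exists and each entry exposes `identifier`.
datasets = client.datasets.list()
assert any(d.identifier == "my_first_train_data" for d in datasets)
```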

## Test Plan
- CI: 1387805545

```
LLAMA_STACK_CONFIG=fireworks pytest -v tests/integration/datasets/test_datasets.py
```

```
LLAMA_STACK_CONFIG=fireworks pytest -v tests/integration/scoring/test_scoring.py
```

```
pytest -v -s --nbval-lax ./docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb
```
2025-03-17 16:55:45 -07:00


# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path

import pandas as pd
import pytest


@pytest.fixture
def sample_judge_prompt_template():
    return "Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9."


@pytest.fixture
def sample_scoring_fn_id():
    return "llm-as-judge-test-prompt"


def register_scoring_function(
    llama_stack_client,
    provider_id,
    scoring_fn_id,
    judge_model_id,
    judge_prompt_template,
):
    llama_stack_client.scoring_functions.register(
        scoring_fn_id=scoring_fn_id,
        provider_id=provider_id,
        description="LLM as judge scoring function with test prompt",
        return_type={
            "type": "string",
        },
        params={
            "type": "llm_as_judge",
            "judge_model": judge_model_id,
            "prompt_template": judge_prompt_template,
        },
    )


def test_scoring_functions_list(llama_stack_client):
    response = llama_stack_client.scoring_functions.list()
    assert isinstance(response, list)
    assert len(response) > 0


def test_scoring_functions_register(
    llama_stack_client,
    sample_scoring_fn_id,
    judge_model_id,
    sample_judge_prompt_template,
):
    llm_as_judge_provider = [
        x
        for x in llama_stack_client.providers.list()
        if x.api == "scoring" and x.provider_type == "inline::llm-as-judge"
    ]
    if len(llm_as_judge_provider) == 0:
        pytest.skip("No llm-as-judge provider found, cannot test registration")

    llm_as_judge_provider_id = llm_as_judge_provider[0].provider_id
    register_scoring_function(
        llama_stack_client,
        llm_as_judge_provider_id,
        sample_scoring_fn_id,
        judge_model_id,
        sample_judge_prompt_template,
    )

    list_response = llama_stack_client.scoring_functions.list()
    assert isinstance(list_response, list)
    assert len(list_response) > 0
    assert any(x.identifier == sample_scoring_fn_id for x in list_response)

    # TODO: add unregister api for scoring functions
@pytest.mark.parametrize("scoring_fn_id", ["basic::equality"])
def test_scoring_score(llama_stack_client, scoring_fn_id):
# scoring individual rows
df = pd.read_csv(Path(__file__).parent.parent / "datasets" / "test_dataset.csv")
rows = df.to_dict(orient="records")
scoring_functions = {
scoring_fn_id: None,
}
response = llama_stack_client.scoring.score(
input_rows=rows,
scoring_functions=scoring_functions,
)
assert len(response.results) == len(scoring_functions)
for x in scoring_functions:
assert x in response.results
assert len(response.results[x].score_rows) == len(rows)


def test_scoring_score_with_params_llm_as_judge(
    llama_stack_client,
    sample_judge_prompt_template,
    judge_model_id,
):
    # scoring individual rows
    df = pd.read_csv(Path(__file__).parent.parent / "datasets" / "test_dataset.csv")
    rows = df.to_dict(orient="records")

    scoring_functions = {
        "llm-as-judge::base": dict(
            type="llm_as_judge",
            judge_model=judge_model_id,
            prompt_template=sample_judge_prompt_template,
            judge_score_regexes=[r"Score: (\d+)"],
            aggregation_functions=[
                "categorical_count",
            ],
        )
    }

    response = llama_stack_client.scoring.score(
        input_rows=rows,
        scoring_functions=scoring_functions,
    )
    assert len(response.results) == len(scoring_functions)
    for x in scoring_functions:
        assert x in response.results
        assert len(response.results[x].score_rows) == len(rows)
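
# Hedged sketch (not asserted above): individual judge outputs could be inspected per
# row, assuming each entry in score_rows is a dict carrying a "score" key:
#
#   for score_row in response.results["llm-as-judge::base"].score_rows:
#       print(score_row["score"])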


@pytest.mark.parametrize(
    "provider_id",
    [
        "basic",
        "llm-as-judge",
        "braintrust",
    ],
)
def test_scoring_score_with_aggregation_functions(
    llama_stack_client,
    sample_judge_prompt_template,
    judge_model_id,
    provider_id,
    rag_dataset_for_test,
):
    df = pd.read_csv(Path(__file__).parent.parent / "datasets" / "test_dataset.csv")
    rows = df.to_dict(orient="records")

    scoring_fns_list = [x for x in llama_stack_client.scoring_functions.list() if x.provider_id == provider_id]
    if len(scoring_fns_list) == 0:
        pytest.skip(f"No scoring functions found for provider {provider_id}, skipping")

    scoring_functions = {}
    aggr_fns = [
        "accuracy",
        "median",
        "categorical_count",
        "average",
    ]

    scoring_fn = scoring_fns_list[0]
    if scoring_fn.provider_id == "llm-as-judge":
        aggr_fns = ["categorical_count"]
        scoring_functions[scoring_fn.identifier] = dict(
            type="llm_as_judge",
            judge_model=judge_model_id,
            prompt_template=sample_judge_prompt_template,
            judge_score_regexes=[r"Score: (\d+)"],
            aggregation_functions=aggr_fns,
        )
    elif scoring_fn.provider_id == "basic" or scoring_fn.provider_id == "braintrust":
        if "regex_parser" in scoring_fn.identifier:
            scoring_functions[scoring_fn.identifier] = dict(
                type="regex_parser",
                parsing_regexes=[r"Score: (\d+)"],
                aggregation_functions=aggr_fns,
            )
        else:
            scoring_functions[scoring_fn.identifier] = dict(
                type="basic",
                aggregation_functions=aggr_fns,
            )
    else:
        scoring_functions[scoring_fn.identifier] = None

    response = llama_stack_client.scoring.score(
        input_rows=rows,
        scoring_functions=scoring_functions,
    )
    assert len(response.results) == len(scoring_functions)
    for x in scoring_functions:
        assert x in response.results
        assert len(response.results[x].score_rows) == len(rows)
        assert len(response.results[x].aggregated_results) == len(aggr_fns)