[rag evals] refactor & add ability to eval retrieval + generation in agentic eval pipeline (#664)

# What does this PR do?

- See https://github.com/meta-llama/llama-stack/pull/666 &
https://github.com/meta-llama/llama-stack/pull/668

- Refactor BaseScoringFn to be just a minimal interface, add new
RegistrableBaseScoring
- Refactor data schema check
- To separately evaluate retrieval component in RAG, we will have
scoring functions needing "context" column additionally.
- Refactor braintrust eval (more scoring fn added & tested in following
PR)

## Test Plan

```
pytest -v -s -m llm_as_judge_scoring_together_inference scoring/test_scoring.py --judge-model meta-llama/Llama-3.2-3B-Instruct
pytest -v -s -m basic_scoring_together_inference scoring/test_scoring.py
pytest -v -s -m braintrust_scoring_together_inference scoring/test_scoring.py
```

<img width="847" alt="image"
src="https://github.com/user-attachments/assets/d099cb2d-6f9c-4bdf-9d0d-f388cf758c0f"
/>

```
pytest -v -s -m meta_reference_eval_together_inference eval/test_eval.py
pytest -v -s -m meta_reference_eval_together_inference_huggingface_datasetio eval/test_eval.py
```
<img width="850" alt="image"
src="https://github.com/user-attachments/assets/dce28fc3-0493-4d34-820a-567260873cc8"
/>



## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
This commit is contained in:
Xi Yan 2025-01-02 11:21:33 -08:00 committed by GitHub
parent 8e5b336792
commit 3a269c4635
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
24 changed files with 544 additions and 139 deletions

View file

@ -14,8 +14,13 @@ from llama_stack.apis.scoring import (
ScoringResult,
)
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
DataSchemaValidatorMixin,
get_valid_schemas,
)
from .config import BasicScoringConfig
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
@ -24,7 +29,9 @@ from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
class BasicScoringImpl(
Scoring, ScoringFunctionsProtocolPrivate, DataSchemaValidatorMixin
):
def __init__(
self,
config: BasicScoringConfig,
@ -61,30 +68,17 @@ class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def register_scoring_function(self, function_def: ScoringFn) -> None:
raise NotImplementedError("Register scoring function not implemented yet")
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
)
for required_column in ["generated_answer", "expected_answer", "input_query"]:
if required_column not in dataset_def.dataset_schema:
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column."
)
if dataset_def.dataset_schema[required_column].type != "string":
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)
async def score_batch(
self,
dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
self.validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,

View file

@ -9,12 +9,12 @@ from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.equality import equality
class EqualityScoringFn(BaseScoringFn):
class EqualityScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
"""

View file

@ -9,14 +9,14 @@ from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.regex_parser_multiple_choice_answer import (
regex_parser_multiple_choice_answer,
)
class RegexParserScoringFn(BaseScoringFn):
class RegexParserScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn that parses answer from generated response according to context and check match with expected_answer.
"""

View file

@ -8,12 +8,12 @@ from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.subset_of import subset_of
class SubsetOfScoringFn(BaseScoringFn):
class SubsetOfScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise.
"""