[/scoring] add ability to define aggregation functions for scoring functions & refactors (#597)

# What does this PR do?

- Add ability to define aggregation functions for scoring functions via
`ScoringFnParams`
- Supported by `basic` / `regex_parser` / `llm_as_judge` scoring
functions


## Test Plan

```
pytest -v -s -m basic_scoring_together_inference scoring/test_scoring.py
```
<img width="855" alt="image"
src="https://github.com/user-attachments/assets/12db8e6e-2ad4-462e-b9b9-70ba6c050a6c">


```
pytest -v -s -m llm_as_judge_scoring_together_inference scoring/test_scoring.py
```
<img width="858" alt="image"
src="https://github.com/user-attachments/assets/bf806676-6f5e-456d-be9f-f81a26d1df19">



**Example Response** (`basic`)
<img width="863" alt="image"
src="https://github.com/user-attachments/assets/0e57a49c-8386-45cc-8fa9-3e61aaa9a3be">

**Example Response** (`llm-as-judge`)
<img width="854" alt="image"
src="https://github.com/user-attachments/assets/38065bc2-b724-47ed-9535-79b6099c4362">


## Sources

Please link relevant resources if necessary.


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
This commit is contained in:
Xi Yan 2024-12-11 10:03:42 -08:00 committed by GitHub
parent e128f2547a
commit a4bcfb8bba
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 323 additions and 55 deletions

View file

@ -31,6 +31,15 @@ from llama_stack.apis.resource import Resource, ResourceType
class ScoringFnParamsType(Enum):
llm_as_judge = "llm_as_judge"
regex_parser = "regex_parser"
basic = "basic"
@json_schema_type
class AggregationFunctionType(Enum):
average = "average"
median = "median"
categorical_count = "categorical_count"
accuracy = "accuracy"
@json_schema_type
@ -44,6 +53,10 @@ class LLMAsJudgeScoringFnParams(BaseModel):
description="Regexes to extract the answer from generated response",
default_factory=list,
)
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
@json_schema_type
@ -55,12 +68,26 @@ class RegexParserScoringFnParams(BaseModel):
description="Regex to extract the answer from generated response",
default_factory=list,
)
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
@json_schema_type
class BasicScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
ScoringFnParams = Annotated[
Union[
LLMAsJudgeScoringFnParams,
RegexParserScoringFnParams,
BasicScoringFnParams,
],
Field(discriminator="type"),
]