[/scoring] add ability to define aggregation functions for scoring functions & refactors (#597)

# What does this PR do?

- Add ability to define aggregation functions for scoring functions via
`ScoringFnParams`
- Supported by `basic` / `regex_parser` / `llm_as_judge` scoring
functions


## Test Plan

```
pytest -v -s -m basic_scoring_together_inference scoring/test_scoring.py
```
<img width="855" alt="image"
src="https://github.com/user-attachments/assets/12db8e6e-2ad4-462e-b9b9-70ba6c050a6c">


```
pytest -v -s -m llm_as_judge_scoring_together_inference scoring/test_scoring.py
```
<img width="858" alt="image"
src="https://github.com/user-attachments/assets/bf806676-6f5e-456d-be9f-f81a26d1df19">



**Example Response** (`basic`)
<img width="863" alt="image"
src="https://github.com/user-attachments/assets/0e57a49c-8386-45cc-8fa9-3e61aaa9a3be">

**Example Response** (`llm-as-judge`)
<img width="854" alt="image"
src="https://github.com/user-attachments/assets/38065bc2-b724-47ed-9535-79b6099c4362">


## Sources

Please link relevant resources if necessary.


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
This commit is contained in:
Xi Yan 2024-12-11 10:03:42 -08:00 committed by GitHub
parent e128f2547a
commit a4bcfb8bba
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 323 additions and 55 deletions

View file

@ -4926,6 +4926,15 @@
"config"
]
},
"AggregationFunctionType": {
"type": "string",
"enum": [
"average",
"median",
"categorical_count",
"accuracy"
]
},
"AppEvalTaskConfig": {
"type": "object",
"properties": {
@ -4953,6 +4962,9 @@
},
{
"$ref": "#/components/schemas/RegexParserScoringFnParams"
},
{
"$ref": "#/components/schemas/BasicScoringFnParams"
}
]
}
@ -4968,6 +4980,26 @@
"scoring_params"
]
},
"BasicScoringFnParams": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "basic",
"default": "basic"
},
"aggregation_functions": {
"type": "array",
"items": {
"$ref": "#/components/schemas/AggregationFunctionType"
}
}
},
"additionalProperties": false,
"required": [
"type"
]
},
"BenchmarkEvalTaskConfig": {
"type": "object",
"properties": {
@ -5015,6 +5047,12 @@
"items": {
"type": "string"
}
},
"aggregation_functions": {
"type": "array",
"items": {
"$ref": "#/components/schemas/AggregationFunctionType"
}
}
},
"additionalProperties": false,
@ -5061,6 +5099,12 @@
"items": {
"type": "string"
}
},
"aggregation_functions": {
"type": "array",
"items": {
"$ref": "#/components/schemas/AggregationFunctionType"
}
}
},
"additionalProperties": false,
@ -6014,6 +6058,9 @@
},
{
"$ref": "#/components/schemas/RegexParserScoringFnParams"
},
{
"$ref": "#/components/schemas/BasicScoringFnParams"
}
]
}
@ -7771,6 +7818,9 @@
},
{
"$ref": "#/components/schemas/RegexParserScoringFnParams"
},
{
"$ref": "#/components/schemas/BasicScoringFnParams"
}
]
}
@ -7998,6 +8048,9 @@
},
{
"$ref": "#/components/schemas/RegexParserScoringFnParams"
},
{
"$ref": "#/components/schemas/BasicScoringFnParams"
}
]
},
@ -8046,6 +8099,9 @@
},
{
"$ref": "#/components/schemas/RegexParserScoringFnParams"
},
{
"$ref": "#/components/schemas/BasicScoringFnParams"
}
]
},
@ -8491,6 +8547,10 @@
{
"name": "Agents"
},
{
"name": "AggregationFunctionType",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/AggregationFunctionType\" />"
},
{
"name": "AppEvalTaskConfig",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/AppEvalTaskConfig\" />"
@ -8503,6 +8563,10 @@
"name": "Attachment",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/Attachment\" />"
},
{
"name": "BasicScoringFnParams",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/BasicScoringFnParams\" />"
},
{
"name": "BatchChatCompletionRequest",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/BatchChatCompletionRequest\" />"
@ -9146,9 +9210,11 @@
"AgentTurnResponseStreamChunk",
"AgentTurnResponseTurnCompletePayload",
"AgentTurnResponseTurnStartPayload",
"AggregationFunctionType",
"AppEvalTaskConfig",
"AppendRowsRequest",
"Attachment",
"BasicScoringFnParams",
"BatchChatCompletionRequest",
"BatchChatCompletionResponse",
"BatchCompletionRequest",

View file

@ -216,6 +216,13 @@ components:
- event_type
- turn_id
type: object
AggregationFunctionType:
enum:
- average
- median
- categorical_count
- accuracy
type: string
AppEvalTaskConfig:
additionalProperties: false
properties:
@ -230,6 +237,7 @@ components:
oneOf:
- $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
- $ref: '#/components/schemas/RegexParserScoringFnParams'
- $ref: '#/components/schemas/BasicScoringFnParams'
type: object
type:
const: app
@ -280,6 +288,20 @@ components:
- content
- mime_type
type: object
BasicScoringFnParams:
additionalProperties: false
properties:
aggregation_functions:
items:
$ref: '#/components/schemas/AggregationFunctionType'
type: array
type:
const: basic
default: basic
type: string
required:
- type
type: object
BatchChatCompletionRequest:
additionalProperties: false
properties:
@ -1280,6 +1302,10 @@ components:
LLMAsJudgeScoringFnParams:
additionalProperties: false
properties:
aggregation_functions:
items:
$ref: '#/components/schemas/AggregationFunctionType'
type: array
judge_model:
type: string
judge_score_regexes:
@ -1984,6 +2010,10 @@ components:
RegexParserScoringFnParams:
additionalProperties: false
properties:
aggregation_functions:
items:
$ref: '#/components/schemas/AggregationFunctionType'
type: array
parsing_regexes:
items:
type: string
@ -2195,6 +2225,7 @@ components:
oneOf:
- $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
- $ref: '#/components/schemas/RegexParserScoringFnParams'
- $ref: '#/components/schemas/BasicScoringFnParams'
provider_id:
type: string
provider_scoring_fn_id:
@ -2515,6 +2546,7 @@ components:
- oneOf:
- $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
- $ref: '#/components/schemas/RegexParserScoringFnParams'
- $ref: '#/components/schemas/BasicScoringFnParams'
- type: 'null'
type: object
required:
@ -2555,6 +2587,7 @@ components:
- oneOf:
- $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
- $ref: '#/components/schemas/RegexParserScoringFnParams'
- $ref: '#/components/schemas/BasicScoringFnParams'
- type: 'null'
type: object
required:
@ -2592,6 +2625,7 @@ components:
oneOf:
- $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
- $ref: '#/components/schemas/RegexParserScoringFnParams'
- $ref: '#/components/schemas/BasicScoringFnParams'
provider_id:
type: string
provider_resource_id:
@ -5161,6 +5195,9 @@ tags:
/>
name: AgentTurnResponseTurnStartPayload
- name: Agents
- description: <SchemaDefinition schemaRef="#/components/schemas/AggregationFunctionType"
/>
name: AggregationFunctionType
- description: <SchemaDefinition schemaRef="#/components/schemas/AppEvalTaskConfig"
/>
name: AppEvalTaskConfig
@ -5169,6 +5206,9 @@ tags:
name: AppendRowsRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/Attachment" />
name: Attachment
- description: <SchemaDefinition schemaRef="#/components/schemas/BasicScoringFnParams"
/>
name: BasicScoringFnParams
- description: <SchemaDefinition schemaRef="#/components/schemas/BatchChatCompletionRequest"
/>
name: BatchChatCompletionRequest
@ -5636,9 +5676,11 @@ x-tagGroups:
- AgentTurnResponseStreamChunk
- AgentTurnResponseTurnCompletePayload
- AgentTurnResponseTurnStartPayload
- AggregationFunctionType
- AppEvalTaskConfig
- AppendRowsRequest
- Attachment
- BasicScoringFnParams
- BatchChatCompletionRequest
- BatchChatCompletionResponse
- BatchCompletionRequest