This commit is contained in:
Xi Yan 2025-03-11 22:45:48 -07:00
parent 11e57e17e6
commit f9ea90c4f7
3 changed files with 90 additions and 20 deletions

View file

@ -6348,7 +6348,8 @@
"categorical_count",
"accuracy"
],
"title": "AggregationFunctionType"
"title": "AggregationFunctionType",
"description": "A type of aggregation function."
},
"BasicScoringFnParams": {
"type": "object",
@ -6362,14 +6363,16 @@
"type": "array",
"items": {
"$ref": "#/components/schemas/AggregationFunctionType"
}
},
"description": "(Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided."
}
},
"additionalProperties": false,
"required": [
"type"
],
"title": "BasicScoringFnParams"
"title": "BasicScoringFnParams",
"description": "Parameters for a non-parameterized scoring function."
},
"BenchmarkConfig": {
"type": "object",
@ -6420,26 +6423,30 @@
"properties": {
"type": {
"type": "string",
"const": "llm_as_judge",
"default": "llm_as_judge"
"const": "custom_llm_as_judge",
"default": "custom_llm_as_judge"
},
"judge_model": {
"type": "string"
"type": "string",
"description": "The model to use for scoring."
},
"prompt_template": {
"type": "string"
"type": "string",
"description": "(Optional) The prompt template to use for scoring."
},
"judge_score_regexes": {
"type": "array",
"items": {
"type": "string"
}
},
"description": "(Optional) Regexes to extract the score from the judge model's response."
},
"aggregation_functions": {
"type": "array",
"items": {
"$ref": "#/components/schemas/AggregationFunctionType"
}
},
"description": "(Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided."
}
},
"additionalProperties": false,
@ -6447,7 +6454,8 @@
"type",
"judge_model"
],
"title": "LLMAsJudgeScoringFnParams"
"title": "LLMAsJudgeScoringFnParams",
"description": "Parameters for a scoring function that uses a judge model to score the answer."
},
"ModelCandidate": {
"type": "object",
@ -6491,20 +6499,23 @@
"type": "array",
"items": {
"type": "string"
}
},
"description": "Regexes to extract the answer from generated response"
},
"aggregation_functions": {
"type": "array",
"items": {
"$ref": "#/components/schemas/AggregationFunctionType"
}
},
"description": "(Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided."
}
},
"additionalProperties": false,
"required": [
"type"
],
"title": "RegexParserScoringFnParams"
"title": "RegexParserScoringFnParams",
"description": "Parameters for a scoring function that parses the answer from the generated response using regexes, and checks against the expected answer."
},
"ScoringFnParams": {
"oneOf": [
@ -6521,7 +6532,7 @@
"discriminator": {
"propertyName": "type",
"mapping": {
"llm_as_judge": "#/components/schemas/LLMAsJudgeScoringFnParams",
"custom_llm_as_judge": "#/components/schemas/LLMAsJudgeScoringFnParams",
"regex_parser": "#/components/schemas/RegexParserScoringFnParams",
"basic": "#/components/schemas/BasicScoringFnParams"
}

View file

@ -4419,6 +4419,7 @@ components:
- categorical_count
- accuracy
title: AggregationFunctionType
description: A type of aggregation function.
BasicScoringFnParams:
type: object
properties:
@ -4430,10 +4431,15 @@ components:
type: array
items:
$ref: '#/components/schemas/AggregationFunctionType'
description: >-
(Optional) Aggregation functions to apply to the scores of each row. No
aggregation for results is calculated if not provided.
additionalProperties: false
required:
- type
title: BasicScoringFnParams
description: >-
Parameters for a non-parameterized scoring function.
BenchmarkConfig:
type: object
properties:
@ -4473,25 +4479,35 @@ components:
properties:
type:
type: string
const: llm_as_judge
default: llm_as_judge
const: custom_llm_as_judge
default: custom_llm_as_judge
judge_model:
type: string
description: The model to use for scoring.
prompt_template:
type: string
description: >-
(Optional) The prompt template to use for scoring.
judge_score_regexes:
type: array
items:
type: string
description: >-
(Optional) Regexes to extract the score from the judge model's response.
aggregation_functions:
type: array
items:
$ref: '#/components/schemas/AggregationFunctionType'
description: >-
(Optional) Aggregation functions to apply to the scores of each row. No
aggregation for results is calculated if not provided.
additionalProperties: false
required:
- type
- judge_model
title: LLMAsJudgeScoringFnParams
description: >-
Parameters for a scoring function that uses a judge model to score the answer.
ModelCandidate:
type: object
properties:
@ -4528,14 +4544,22 @@ components:
type: array
items:
type: string
description: >-
Regexes to extract the answer from generated response
aggregation_functions:
type: array
items:
$ref: '#/components/schemas/AggregationFunctionType'
description: >-
(Optional) Aggregation functions to apply to the scores of each row. No
aggregation for results is calculated if not provided.
additionalProperties: false
required:
- type
title: RegexParserScoringFnParams
description: >-
Parameters for a scoring function that parses the answer from the generated
response using regexes, and checks against the expected answer.
ScoringFnParams:
oneOf:
- $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
@ -4544,7 +4568,7 @@ components:
discriminator:
propertyName: type
mapping:
llm_as_judge: '#/components/schemas/LLMAsJudgeScoringFnParams'
custom_llm_as_judge: '#/components/schemas/LLMAsJudgeScoringFnParams'
regex_parser: '#/components/schemas/RegexParserScoringFnParams'
basic: '#/components/schemas/BasicScoringFnParams'
EvaluateRowsRequest:

View file

@ -12,8 +12,8 @@ from typing import (
Literal,
Optional,
Protocol,
Union,
runtime_checkable,
Union,
)
from pydantic import BaseModel, Field
@ -65,6 +65,15 @@ class ScoringFunctionType(Enum):
@json_schema_type
class AggregationFunctionType(Enum):
"""
A type of aggregation function.
:cvar average: Average the scores of each row.
:cvar median: Median the scores of each row.
:cvar categorical_count: Count the number of rows that match each category.
:cvar accuracy: Number of correct results over total results.
"""
average = "average"
median = "median"
categorical_count = "categorical_count"
@ -73,6 +82,15 @@ class AggregationFunctionType(Enum):
@json_schema_type
class LLMAsJudgeScoringFnParams(BaseModel):
"""
Parameters for a scoring function that uses a judge model to score the answer.
:param judge_model: The model to use for scoring.
:param prompt_template: (Optional) The prompt template to use for scoring.
:param judge_score_regexes: (Optional) Regexes to extract the score from the judge model's response.
:param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided.
"""
type: Literal["custom_llm_as_judge"] = "custom_llm_as_judge"
judge_model: str
prompt_template: Optional[str] = None
@ -88,6 +106,13 @@ class LLMAsJudgeScoringFnParams(BaseModel):
@json_schema_type
class RegexParserScoringFnParams(BaseModel):
"""
Parameters for a scoring function that parses the answer from the generated response using regexes, and checks against the expected answer.
:param parsing_regexes: Regexes to extract the answer from generated response
:param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided.
"""
type: Literal["regex_parser"] = "regex_parser"
parsing_regexes: Optional[List[str]] = Field(
description="Regexes to extract the answer from generated response",
@ -101,6 +126,12 @@ class RegexParserScoringFnParams(BaseModel):
@json_schema_type
class BasicScoringFnParams(BaseModel):
"""
Parameters for a non-parameterized scoring function.
:param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided.
"""
type: Literal["basic"] = "basic"
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
description="Aggregation functions to apply to the scores of each row",
@ -135,7 +166,9 @@ class CommonScoringFnFields(BaseModel):
@json_schema_type
class ScoringFn(CommonScoringFnFields, Resource):
type: Literal[ResourceType.scoring_function.value] = ResourceType.scoring_function.value
type: Literal[ResourceType.scoring_function.value] = (
ResourceType.scoring_function.value
)
@property
def scoring_fn_id(self) -> str:
@ -162,7 +195,9 @@ class ScoringFunctions(Protocol):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ...
async def get_scoring_function(
self, scoring_fn_id: str, /
) -> Optional[ScoringFn]: ...
@webmethod(route="/scoring-functions", method="POST")
async def register_scoring_function(