This commit is contained in:
Xi Yan 2025-03-11 22:45:48 -07:00
parent 11e57e17e6
commit f9ea90c4f7
3 changed files with 90 additions and 20 deletions

View file

@ -6348,7 +6348,8 @@
"categorical_count", "categorical_count",
"accuracy" "accuracy"
], ],
"title": "AggregationFunctionType" "title": "AggregationFunctionType",
"description": "A type of aggregation function."
}, },
"BasicScoringFnParams": { "BasicScoringFnParams": {
"type": "object", "type": "object",
@ -6362,14 +6363,16 @@
"type": "array", "type": "array",
"items": { "items": {
"$ref": "#/components/schemas/AggregationFunctionType" "$ref": "#/components/schemas/AggregationFunctionType"
} },
"description": "(Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided."
} }
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"type" "type"
], ],
"title": "BasicScoringFnParams" "title": "BasicScoringFnParams",
"description": "Parameters for a non-parameterized scoring function."
}, },
"BenchmarkConfig": { "BenchmarkConfig": {
"type": "object", "type": "object",
@ -6420,26 +6423,30 @@
"properties": { "properties": {
"type": { "type": {
"type": "string", "type": "string",
"const": "llm_as_judge", "const": "custom_llm_as_judge",
"default": "llm_as_judge" "default": "custom_llm_as_judge"
}, },
"judge_model": { "judge_model": {
"type": "string" "type": "string",
"description": "The model to use for scoring."
}, },
"prompt_template": { "prompt_template": {
"type": "string" "type": "string",
"description": "(Optional) The prompt template to use for scoring."
}, },
"judge_score_regexes": { "judge_score_regexes": {
"type": "array", "type": "array",
"items": { "items": {
"type": "string" "type": "string"
} },
"description": "(Optional) Regexes to extract the score from the judge model's response."
}, },
"aggregation_functions": { "aggregation_functions": {
"type": "array", "type": "array",
"items": { "items": {
"$ref": "#/components/schemas/AggregationFunctionType" "$ref": "#/components/schemas/AggregationFunctionType"
} },
"description": "(Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided."
} }
}, },
"additionalProperties": false, "additionalProperties": false,
@ -6447,7 +6454,8 @@
"type", "type",
"judge_model" "judge_model"
], ],
"title": "LLMAsJudgeScoringFnParams" "title": "LLMAsJudgeScoringFnParams",
"description": "Parameters for a scoring function that uses a judge model to score the answer."
}, },
"ModelCandidate": { "ModelCandidate": {
"type": "object", "type": "object",
@ -6491,20 +6499,23 @@
"type": "array", "type": "array",
"items": { "items": {
"type": "string" "type": "string"
} },
"description": "Regexes to extract the answer from generated response"
}, },
"aggregation_functions": { "aggregation_functions": {
"type": "array", "type": "array",
"items": { "items": {
"$ref": "#/components/schemas/AggregationFunctionType" "$ref": "#/components/schemas/AggregationFunctionType"
} },
"description": "(Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided."
} }
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"type" "type"
], ],
"title": "RegexParserScoringFnParams" "title": "RegexParserScoringFnParams",
"description": "Parameters for a scoring function that parses the answer from the generated response using regexes, and checks against the expected answer."
}, },
"ScoringFnParams": { "ScoringFnParams": {
"oneOf": [ "oneOf": [
@ -6521,7 +6532,7 @@
"discriminator": { "discriminator": {
"propertyName": "type", "propertyName": "type",
"mapping": { "mapping": {
"llm_as_judge": "#/components/schemas/LLMAsJudgeScoringFnParams", "custom_llm_as_judge": "#/components/schemas/LLMAsJudgeScoringFnParams",
"regex_parser": "#/components/schemas/RegexParserScoringFnParams", "regex_parser": "#/components/schemas/RegexParserScoringFnParams",
"basic": "#/components/schemas/BasicScoringFnParams" "basic": "#/components/schemas/BasicScoringFnParams"
} }

View file

@ -4419,6 +4419,7 @@ components:
- categorical_count - categorical_count
- accuracy - accuracy
title: AggregationFunctionType title: AggregationFunctionType
description: A type of aggregation function.
BasicScoringFnParams: BasicScoringFnParams:
type: object type: object
properties: properties:
@ -4430,10 +4431,15 @@ components:
type: array type: array
items: items:
$ref: '#/components/schemas/AggregationFunctionType' $ref: '#/components/schemas/AggregationFunctionType'
description: >-
(Optional) Aggregation functions to apply to the scores of each row. No
aggregation for results is calculated if not provided.
additionalProperties: false additionalProperties: false
required: required:
- type - type
title: BasicScoringFnParams title: BasicScoringFnParams
description: >-
Parameters for a non-parameterized scoring function.
BenchmarkConfig: BenchmarkConfig:
type: object type: object
properties: properties:
@ -4473,25 +4479,35 @@ components:
properties: properties:
type: type:
type: string type: string
const: llm_as_judge const: custom_llm_as_judge
default: llm_as_judge default: custom_llm_as_judge
judge_model: judge_model:
type: string type: string
description: The model to use for scoring.
prompt_template: prompt_template:
type: string type: string
description: >-
(Optional) The prompt template to use for scoring.
judge_score_regexes: judge_score_regexes:
type: array type: array
items: items:
type: string type: string
description: >-
(Optional) Regexes to extract the score from the judge model's response.
aggregation_functions: aggregation_functions:
type: array type: array
items: items:
$ref: '#/components/schemas/AggregationFunctionType' $ref: '#/components/schemas/AggregationFunctionType'
description: >-
(Optional) Aggregation functions to apply to the scores of each row. No
aggregation for results is calculated if not provided.
additionalProperties: false additionalProperties: false
required: required:
- type - type
- judge_model - judge_model
title: LLMAsJudgeScoringFnParams title: LLMAsJudgeScoringFnParams
description: >-
Parameters for a scoring function that uses a judge model to score the answer.
ModelCandidate: ModelCandidate:
type: object type: object
properties: properties:
@ -4528,14 +4544,22 @@ components:
type: array type: array
items: items:
type: string type: string
description: >-
Regexes to extract the answer from generated response
aggregation_functions: aggregation_functions:
type: array type: array
items: items:
$ref: '#/components/schemas/AggregationFunctionType' $ref: '#/components/schemas/AggregationFunctionType'
description: >-
(Optional) Aggregation functions to apply to the scores of each row. No
aggregation for results is calculated if not provided.
additionalProperties: false additionalProperties: false
required: required:
- type - type
title: RegexParserScoringFnParams title: RegexParserScoringFnParams
description: >-
Parameters for a scoring function that parses the answer from the generated
response using regexes, and checks against the expected answer.
ScoringFnParams: ScoringFnParams:
oneOf: oneOf:
- $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams'
@ -4544,7 +4568,7 @@ components:
discriminator: discriminator:
propertyName: type propertyName: type
mapping: mapping:
llm_as_judge: '#/components/schemas/LLMAsJudgeScoringFnParams' custom_llm_as_judge: '#/components/schemas/LLMAsJudgeScoringFnParams'
regex_parser: '#/components/schemas/RegexParserScoringFnParams' regex_parser: '#/components/schemas/RegexParserScoringFnParams'
basic: '#/components/schemas/BasicScoringFnParams' basic: '#/components/schemas/BasicScoringFnParams'
EvaluateRowsRequest: EvaluateRowsRequest:

View file

@ -12,8 +12,8 @@ from typing import (
Literal, Literal,
Optional, Optional,
Protocol, Protocol,
Union,
runtime_checkable, runtime_checkable,
Union,
) )
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -65,6 +65,15 @@ class ScoringFunctionType(Enum):
@json_schema_type @json_schema_type
class AggregationFunctionType(Enum): class AggregationFunctionType(Enum):
"""
A type of aggregation function.
:cvar average: Average the scores of each row.
:cvar median: Compute the median of the scores of each row.
:cvar categorical_count: Count the number of rows that match each category.
:cvar accuracy: Number of correct results over total results.
"""
average = "average" average = "average"
median = "median" median = "median"
categorical_count = "categorical_count" categorical_count = "categorical_count"
@ -73,6 +82,15 @@ class AggregationFunctionType(Enum):
@json_schema_type @json_schema_type
class LLMAsJudgeScoringFnParams(BaseModel): class LLMAsJudgeScoringFnParams(BaseModel):
"""
Parameters for a scoring function that uses a judge model to score the answer.
:param judge_model: The model to use for scoring.
:param prompt_template: (Optional) The prompt template to use for scoring.
:param judge_score_regexes: (Optional) Regexes to extract the score from the judge model's response.
:param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided.
"""
type: Literal["custom_llm_as_judge"] = "custom_llm_as_judge" type: Literal["custom_llm_as_judge"] = "custom_llm_as_judge"
judge_model: str judge_model: str
prompt_template: Optional[str] = None prompt_template: Optional[str] = None
@ -88,6 +106,13 @@ class LLMAsJudgeScoringFnParams(BaseModel):
@json_schema_type @json_schema_type
class RegexParserScoringFnParams(BaseModel): class RegexParserScoringFnParams(BaseModel):
"""
Parameters for a scoring function that parses the answer from the generated response using regexes, and checks against the expected answer.
:param parsing_regexes: Regexes to extract the answer from generated response
:param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided.
"""
type: Literal["regex_parser"] = "regex_parser" type: Literal["regex_parser"] = "regex_parser"
parsing_regexes: Optional[List[str]] = Field( parsing_regexes: Optional[List[str]] = Field(
description="Regexes to extract the answer from generated response", description="Regexes to extract the answer from generated response",
@ -101,6 +126,12 @@ class RegexParserScoringFnParams(BaseModel):
@json_schema_type @json_schema_type
class BasicScoringFnParams(BaseModel): class BasicScoringFnParams(BaseModel):
"""
Parameters for a non-parameterized scoring function.
:param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided.
"""
type: Literal["basic"] = "basic" type: Literal["basic"] = "basic"
aggregation_functions: Optional[List[AggregationFunctionType]] = Field( aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
description="Aggregation functions to apply to the scores of each row", description="Aggregation functions to apply to the scores of each row",
@ -135,7 +166,9 @@ class CommonScoringFnFields(BaseModel):
@json_schema_type @json_schema_type
class ScoringFn(CommonScoringFnFields, Resource): class ScoringFn(CommonScoringFnFields, Resource):
type: Literal[ResourceType.scoring_function.value] = ResourceType.scoring_function.value type: Literal[ResourceType.scoring_function.value] = (
ResourceType.scoring_function.value
)
@property @property
def scoring_fn_id(self) -> str: def scoring_fn_id(self) -> str:
@ -162,7 +195,9 @@ class ScoringFunctions(Protocol):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ... async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET") @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ... async def get_scoring_function(
self, scoring_fn_id: str, /
) -> Optional[ScoringFn]: ...
@webmethod(route="/scoring-functions", method="POST") @webmethod(route="/scoring-functions", method="POST")
async def register_scoring_function( async def register_scoring_function(