From bec5a469158ad81aef4c685056b9fb05d7208d2b Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Tue, 11 Mar 2025 23:20:16 -0700 Subject: [PATCH] single type --- docs/_static/llama-stack-spec.html | 353 ++++++++++++++++-- docs/_static/llama-stack-spec.yaml | 280 ++++++++++++-- .../scoring_functions/scoring_functions.py | 172 ++++----- 3 files changed, 639 insertions(+), 166 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index a3ff2c181..a698c2c9c 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -6351,28 +6351,71 @@ "title": "AggregationFunctionType", "description": "A type of aggregation function." }, - "BasicScoringFnParams": { + "AnswerCorrectnessScoringFnParams": { "type": "object", "properties": { - "type": { - "type": "string", - "const": "basic", - "default": "basic" - }, "aggregation_functions": { "type": "array", "items": { "$ref": "#/components/schemas/AggregationFunctionType" }, - "description": "(Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided." + "description": "(Optional) Aggregation functions to apply to the scores of each row. If not provided, no aggregation will be performed." + }, + "type": { + "type": "string", + "const": "answer_correctness", + "default": "answer_correctness" } }, "additionalProperties": false, "required": [ "type" ], - "title": "BasicScoringFnParams", - "description": "Parameters for a non-parameterized scoring function." + "title": "AnswerCorrectnessScoringFnParams" + }, + "AnswerRelevancyScoringFnParams": { + "type": "object", + "properties": { + "aggregation_functions": { + "type": "array", + "items": { + "$ref": "#/components/schemas/AggregationFunctionType" + }, + "description": "(Optional) Aggregation functions to apply to the scores of each row. If not provided, no aggregation will be performed." + }, + "type": { + "type": "string", + "const": "answer_relevancy", + "default": "answer_relevancy" + } + }, + "additionalProperties": false, + "required": [ + "type" + ], + "title": "AnswerRelevancyScoringFnParams" + }, + "AnswerSimilarityScoringFnParams": { + "type": "object", + "properties": { + "aggregation_functions": { + "type": "array", + "items": { + "$ref": "#/components/schemas/AggregationFunctionType" + }, + "description": "(Optional) Aggregation functions to apply to the scores of each row. If not provided, no aggregation will be performed." + }, + "type": { + "type": "string", + "const": "answer_similarity", + "default": "answer_similarity" + } + }, + "additionalProperties": false, + "required": [ + "type" + ], + "title": "AnswerSimilarityScoringFnParams" }, "BenchmarkConfig": { "type": "object", @@ -6401,6 +6444,116 @@ "title": "BenchmarkConfig", "description": "A benchmark configuration for evaluation." }, + "ContextEntityRecallScoringFnParams": { + "type": "object", + "properties": { + "aggregation_functions": { + "type": "array", + "items": { + "$ref": "#/components/schemas/AggregationFunctionType" + }, + "description": "(Optional) Aggregation functions to apply to the scores of each row. If not provided, no aggregation will be performed." + }, + "type": { + "type": "string", + "const": "context_entity_recall", + "default": "context_entity_recall" + } + }, + "additionalProperties": false, + "required": [ + "type" + ], + "title": "ContextEntityRecallScoringFnParams" + }, + "ContextPrecisionScoringFnParams": { + "type": "object", + "properties": { + "aggregation_functions": { + "type": "array", + "items": { + "$ref": "#/components/schemas/AggregationFunctionType" + }, + "description": "(Optional) Aggregation functions to apply to the scores of each row. If not provided, no aggregation will be performed." + }, + "type": { + "type": "string", + "const": "context_precision", + "default": "context_precision" + } + }, + "additionalProperties": false, + "required": [ + "type" + ], + "title": "ContextPrecisionScoringFnParams" + }, + "ContextRecallScoringFnParams": { + "type": "object", + "properties": { + "aggregation_functions": { + "type": "array", + "items": { + "$ref": "#/components/schemas/AggregationFunctionType" + }, + "description": "(Optional) Aggregation functions to apply to the scores of each row. If not provided, no aggregation will be performed." + }, + "type": { + "type": "string", + "const": "context_recall", + "default": "context_recall" + } + }, + "additionalProperties": false, + "required": [ + "type" + ], + "title": "ContextRecallScoringFnParams" + }, + "ContextRelevancyScoringFnParams": { + "type": "object", + "properties": { + "aggregation_functions": { + "type": "array", + "items": { + "$ref": "#/components/schemas/AggregationFunctionType" + }, + "description": "(Optional) Aggregation functions to apply to the scores of each row. If not provided, no aggregation will be performed." + }, + "type": { + "type": "string", + "const": "context_relevancy", + "default": "context_relevancy" + } + }, + "additionalProperties": false, + "required": [ + "type" + ], + "title": "ContextRelevancyScoringFnParams" + }, + "EqualityScoringFnParams": { + "type": "object", + "properties": { + "aggregation_functions": { + "type": "array", + "items": { + "$ref": "#/components/schemas/AggregationFunctionType" + }, + "description": "(Optional) Aggregation functions to apply to the scores of each row. If not provided, no aggregation will be performed." + }, + "type": { + "type": "string", + "const": "equality", + "default": "equality" + } + }, + "additionalProperties": false, + "required": [ + "type" + ], + "title": "EqualityScoringFnParams" + }, "EvalCandidate": { "oneOf": [ { @@ -6418,6 +6571,50 @@ } } }, + "FactualityScoringFnParams": { + "type": "object", + "properties": { + "aggregation_functions": { + "type": "array", + "items": { + "$ref": "#/components/schemas/AggregationFunctionType" + }, + "description": "(Optional) Aggregation functions to apply to the scores of each row. If not provided, no aggregation will be performed." + }, + "type": { + "type": "string", + "const": "factuality", + "default": "factuality" + } + }, + "additionalProperties": false, + "required": [ + "type" + ], + "title": "FactualityScoringFnParams" + }, + "FaithfulnessScoringFnParams": { + "type": "object", + "properties": { + "aggregation_functions": { + "type": "array", + "items": { + "$ref": "#/components/schemas/AggregationFunctionType" + }, + "description": "(Optional) Aggregation functions to apply to the scores of each row. If not provided, no aggregation will be performed." + }, + "type": { + "type": "string", + "const": "faithfulness", + "default": "faithfulness" + } + }, + "additionalProperties": false, + "required": [ + "type" + ], + "title": "FaithfulnessScoringFnParams" + }, "LLMAsJudgeScoringFnParams": { "type": "object", "properties": { @@ -6427,26 +6624,22 @@ "default": "custom_llm_as_judge" }, "judge_model": { - "type": "string", - "description": "The model to use for scoring." + "type": "string" }, "prompt_template": { - "type": "string", - "description": "(Optional) The prompt template to use for scoring." + "type": "string" }, "judge_score_regexes": { "type": "array", "items": { "type": "string" - }, - "description": "(Optional) Regexes to extract the score from the judge model's response." + } }, "aggregation_functions": { "type": "array", "items": { "$ref": "#/components/schemas/AggregationFunctionType" - }, - "description": "(Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided." + } } }, "additionalProperties": false, @@ -6454,8 +6647,7 @@ "type", "judge_model" ], - "title": "LLMAsJudgeScoringFnParams", - "description": "Parameters for a scoring function that uses a judge model to score the answer." + "title": "LLMAsJudgeScoringFnParams" }, "ModelCandidate": { "type": "object", @@ -6487,35 +6679,65 @@ "title": "ModelCandidate", "description": "A model candidate for evaluation." }, - "RegexParserScoringFnParams": { + "RegexParserMathScoringFnParams": { "type": "object", "properties": { - "type": { - "type": "string", - "const": "regex_parser", - "default": "regex_parser" - }, "parsing_regexes": { "type": "array", "items": { "type": "string" }, - "description": "Regexes to extract the answer from generated response" + "description": "(Optional) Regexes to extract the answer from generated response." }, "aggregation_functions": { "type": "array", "items": { "$ref": "#/components/schemas/AggregationFunctionType" }, - "description": "(Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided." + "description": "(Optional) Aggregation functions to apply to the scores of each row. If not provided, no aggregation will be performed." + }, + "type": { + "type": "string", + "const": "regex_parser_math_response", + "default": "regex_parser_math_response" } }, "additionalProperties": false, "required": [ + "parsing_regexes", "type" ], - "title": "RegexParserScoringFnParams", - "description": "Parameters for a scoring function that parses the answer from the generated response using regexes, and checks against the expected answer." + "title": "RegexParserMathScoringFnParams" + }, + "RegexParserScoringFnParams": { + "type": "object", + "properties": { + "parsing_regexes": { + "type": "array", + "items": { + "type": "string" + }, + "description": "(Optional) Regexes to extract the answer from generated response." + }, + "aggregation_functions": { + "type": "array", + "items": { + "$ref": "#/components/schemas/AggregationFunctionType" + }, + "description": "(Optional) Aggregation functions to apply to the scores of each row. If not provided, no aggregation will be performed." + }, + "type": { + "type": "string", + "const": "regex_parser", + "default": "regex_parser" + } + }, + "additionalProperties": false, + "required": [ + "parsing_regexes", + "type" + ], + "title": "RegexParserScoringFnParams" }, "ScoringFnParams": { "oneOf": [ @@ -6526,7 +6748,40 @@ "$ref": "#/components/schemas/RegexParserScoringFnParams" }, { - "$ref": "#/components/schemas/BasicScoringFnParams" + "$ref": "#/components/schemas/RegexParserMathScoringFnParams" + }, + { + "$ref": "#/components/schemas/EqualityScoringFnParams" + }, + { + "$ref": "#/components/schemas/SubsetOfcoringFnParams" + }, + { + "$ref": "#/components/schemas/FactualityScoringFnParams" + }, + { + "$ref": "#/components/schemas/FaithfulnessScoringFnParams" + }, + { + "$ref": "#/components/schemas/AnswerCorrectnessScoringFnParams" + }, + { + "$ref": "#/components/schemas/AnswerRelevancyScoringFnParams" + }, + { + "$ref": "#/components/schemas/AnswerSimilarityScoringFnParams" + }, + { + "$ref": "#/components/schemas/ContextEntityRecallScoringFnParams" + }, + { + "$ref": "#/components/schemas/ContextPrecisionScoringFnParams" + }, + { + "$ref": "#/components/schemas/ContextRecallScoringFnParams" + }, + { + "$ref": "#/components/schemas/ContextRelevancyScoringFnParams" } ], "discriminator": { @@ -6534,10 +6789,43 @@ "mapping": { "custom_llm_as_judge": "#/components/schemas/LLMAsJudgeScoringFnParams", "regex_parser": "#/components/schemas/RegexParserScoringFnParams", - "basic": "#/components/schemas/BasicScoringFnParams" + "regex_parser_math_response": "#/components/schemas/RegexParserMathScoringFnParams", + "equality": "#/components/schemas/EqualityScoringFnParams", + "subset_of": "#/components/schemas/SubsetOfcoringFnParams", + "factuality": "#/components/schemas/FactualityScoringFnParams", + "faithfulness": "#/components/schemas/FaithfulnessScoringFnParams", + "answer_correctness": "#/components/schemas/AnswerCorrectnessScoringFnParams", + "answer_relevancy": "#/components/schemas/AnswerRelevancyScoringFnParams", + "answer_similarity": "#/components/schemas/AnswerSimilarityScoringFnParams", + "context_entity_recall": "#/components/schemas/ContextEntityRecallScoringFnParams", + "context_precision": "#/components/schemas/ContextPrecisionScoringFnParams", + "context_recall": "#/components/schemas/ContextRecallScoringFnParams", + "context_relevancy": "#/components/schemas/ContextRelevancyScoringFnParams" } } }, + "SubsetOfcoringFnParams": { + "type": "object", + "properties": { + "aggregation_functions": { + "type": "array", + "items": { + "$ref": "#/components/schemas/AggregationFunctionType" + }, + "description": "(Optional) Aggregation functions to apply to the scores of each row. If not provided, no aggregation will be performed." + }, + "type": { + "type": "string", + "const": "subset_of", + "default": "subset_of" + } + }, + "additionalProperties": false, + "required": [ + "type" + ], + "title": "SubsetOfcoringFnParams" + }, "EvaluateRowsRequest": { "type": "object", "properties": { @@ -9371,7 +9659,8 @@ }, "additionalProperties": false, "required": [ - "scoring_fn_type" + "scoring_fn_type", + "params" ], "title": "RegisterScoringFunctionRequest" }, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index fed0e4a85..2fe35cc2c 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -4420,26 +4420,60 @@ components: - accuracy title: AggregationFunctionType description: A type of aggregation function. - BasicScoringFnParams: + AnswerCorrectnessScoringFnParams: type: object properties: - type: - type: string - const: basic - default: basic aggregation_functions: type: array items: $ref: '#/components/schemas/AggregationFunctionType' description: >- - (Optional) Aggregation functions to apply to the scores of each row. No - aggregation for results is calculated if not provided. + (Optional) Aggregation functions to apply to the scores of each row. If + not provided, no aggregation will be performed. + type: + type: string + const: answer_correctness + default: answer_correctness additionalProperties: false required: - type - title: BasicScoringFnParams - description: >- - Parameters for a non-parameterized scoring function. + title: AnswerCorrectnessScoringFnParams + AnswerRelevancyScoringFnParams: + type: object + properties: + aggregation_functions: + type: array + items: + $ref: '#/components/schemas/AggregationFunctionType' + description: >- + (Optional) Aggregation functions to apply to the scores of each row. If + not provided, no aggregation will be performed. + type: + type: string + const: answer_relevancy + default: answer_relevancy + additionalProperties: false + required: + - type + title: AnswerRelevancyScoringFnParams + AnswerSimilarityScoringFnParams: + type: object + properties: + aggregation_functions: + type: array + items: + $ref: '#/components/schemas/AggregationFunctionType' + description: >- + (Optional) Aggregation functions to apply to the scores of each row. If + not provided, no aggregation will be performed. + type: + type: string + const: answer_similarity + default: answer_similarity + additionalProperties: false + required: + - type + title: AnswerSimilarityScoringFnParams BenchmarkConfig: type: object properties: @@ -4465,6 +4499,96 @@ components: title: BenchmarkConfig description: >- A benchmark configuration for evaluation. + ContextEntityRecallScoringFnParams: + type: object + properties: + aggregation_functions: + type: array + items: + $ref: '#/components/schemas/AggregationFunctionType' + description: >- + (Optional) Aggregation functions to apply to the scores of each row. If + not provided, no aggregation will be performed. + type: + type: string + const: context_entity_recall + default: context_entity_recall + additionalProperties: false + required: + - type + title: ContextEntityRecallScoringFnParams + ContextPrecisionScoringFnParams: + type: object + properties: + aggregation_functions: + type: array + items: + $ref: '#/components/schemas/AggregationFunctionType' + description: >- + (Optional) Aggregation functions to apply to the scores of each row. If + not provided, no aggregation will be performed. + type: + type: string + const: context_precision + default: context_precision + additionalProperties: false + required: + - type + title: ContextPrecisionScoringFnParams + ContextRecallScoringFnParams: + type: object + properties: + aggregation_functions: + type: array + items: + $ref: '#/components/schemas/AggregationFunctionType' + description: >- + (Optional) Aggregation functions to apply to the scores of each row. If + not provided, no aggregation will be performed. + type: + type: string + const: context_recall + default: context_recall + additionalProperties: false + required: + - type + title: ContextRecallScoringFnParams + ContextRelevancyScoringFnParams: + type: object + properties: + aggregation_functions: + type: array + items: + $ref: '#/components/schemas/AggregationFunctionType' + description: >- + (Optional) Aggregation functions to apply to the scores of each row. If + not provided, no aggregation will be performed. + type: + type: string + const: context_relevancy + default: context_relevancy + additionalProperties: false + required: + - type + title: ContextRelevancyScoringFnParams + EqualityScoringFnParams: + type: object + properties: + aggregation_functions: + type: array + items: + $ref: '#/components/schemas/AggregationFunctionType' + description: >- + (Optional) Aggregation functions to apply to the scores of each row. If + not provided, no aggregation will be performed. + type: + type: string + const: equality + default: equality + additionalProperties: false + required: + - type + title: EqualityScoringFnParams EvalCandidate: oneOf: - $ref: '#/components/schemas/ModelCandidate' @@ -4474,6 +4598,42 @@ components: mapping: model: '#/components/schemas/ModelCandidate' agent: '#/components/schemas/AgentCandidate' + FactualityScoringFnParams: + type: object + properties: + aggregation_functions: + type: array + items: + $ref: '#/components/schemas/AggregationFunctionType' + description: >- + (Optional) Aggregation functions to apply to the scores of each row. If + not provided, no aggregation will be performed. + type: + type: string + const: factuality + default: factuality + additionalProperties: false + required: + - type + title: FactualityScoringFnParams + FaithfulnessScoringFnParams: + type: object + properties: + aggregation_functions: + type: array + items: + $ref: '#/components/schemas/AggregationFunctionType' + description: >- + (Optional) Aggregation functions to apply to the scores of each row. If + not provided, no aggregation will be performed. + type: + type: string + const: faithfulness + default: faithfulness + additionalProperties: false + required: + - type + title: FaithfulnessScoringFnParams LLMAsJudgeScoringFnParams: type: object properties: @@ -4483,31 +4643,21 @@ components: default: custom_llm_as_judge judge_model: type: string - description: The model to use for scoring. prompt_template: type: string - description: >- - (Optional) The prompt template to use for scoring. judge_score_regexes: type: array items: type: string - description: >- - (Optional) Regexes to extract the score from the judge model's response. aggregation_functions: type: array items: $ref: '#/components/schemas/AggregationFunctionType' - description: >- - (Optional) Aggregation functions to apply to the scores of each row. No - aggregation for results is calculated if not provided. additionalProperties: false required: - type - judge_model title: LLMAsJudgeScoringFnParams - description: >- - Parameters for a scoring function that uses a judge model to score the answer. ModelCandidate: type: object properties: @@ -4533,44 +4683,107 @@ components: - sampling_params title: ModelCandidate description: A model candidate for evaluation. - RegexParserScoringFnParams: + RegexParserMathScoringFnParams: type: object properties: - type: - type: string - const: regex_parser - default: regex_parser parsing_regexes: type: array items: type: string description: >- - Regexes to extract the answer from generated response + (Optional) Regexes to extract the answer from generated response. aggregation_functions: type: array items: $ref: '#/components/schemas/AggregationFunctionType' description: >- - (Optional) Aggregation functions to apply to the scores of each row. No - aggregation for results is calculated if not provided. + (Optional) Aggregation functions to apply to the scores of each row. If + not provided, no aggregation will be performed. + type: + type: string + const: regex_parser_math_response + default: regex_parser_math_response additionalProperties: false required: + - parsing_regexes + - type + title: RegexParserMathScoringFnParams + RegexParserScoringFnParams: + type: object + properties: + parsing_regexes: + type: array + items: + type: string + description: >- + (Optional) Regexes to extract the answer from generated response. + aggregation_functions: + type: array + items: + $ref: '#/components/schemas/AggregationFunctionType' + description: >- + (Optional) Aggregation functions to apply to the scores of each row. If + not provided, no aggregation will be performed. + type: + type: string + const: regex_parser + default: regex_parser + additionalProperties: false + required: + - parsing_regexes - type title: RegexParserScoringFnParams - description: >- - Parameters for a scoring function that parses the answer from the generated - response using regexes, and checks against the expected answer. ScoringFnParams: oneOf: - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' - $ref: '#/components/schemas/RegexParserScoringFnParams' - - $ref: '#/components/schemas/BasicScoringFnParams' + - $ref: '#/components/schemas/RegexParserMathScoringFnParams' + - $ref: '#/components/schemas/EqualityScoringFnParams' + - $ref: '#/components/schemas/SubsetOfcoringFnParams' + - $ref: '#/components/schemas/FactualityScoringFnParams' + - $ref: '#/components/schemas/FaithfulnessScoringFnParams' + - $ref: '#/components/schemas/AnswerCorrectnessScoringFnParams' + - $ref: '#/components/schemas/AnswerRelevancyScoringFnParams' + - $ref: '#/components/schemas/AnswerSimilarityScoringFnParams' + - $ref: '#/components/schemas/ContextEntityRecallScoringFnParams' + - $ref: '#/components/schemas/ContextPrecisionScoringFnParams' + - $ref: '#/components/schemas/ContextRecallScoringFnParams' + - $ref: '#/components/schemas/ContextRelevancyScoringFnParams' discriminator: propertyName: type mapping: custom_llm_as_judge: '#/components/schemas/LLMAsJudgeScoringFnParams' regex_parser: '#/components/schemas/RegexParserScoringFnParams' - basic: '#/components/schemas/BasicScoringFnParams' + regex_parser_math_response: '#/components/schemas/RegexParserMathScoringFnParams' + equality: '#/components/schemas/EqualityScoringFnParams' + subset_of: '#/components/schemas/SubsetOfcoringFnParams' + factuality: '#/components/schemas/FactualityScoringFnParams' + faithfulness: '#/components/schemas/FaithfulnessScoringFnParams' + answer_correctness: '#/components/schemas/AnswerCorrectnessScoringFnParams' + answer_relevancy: '#/components/schemas/AnswerRelevancyScoringFnParams' + answer_similarity: '#/components/schemas/AnswerSimilarityScoringFnParams' + context_entity_recall: '#/components/schemas/ContextEntityRecallScoringFnParams' + context_precision: '#/components/schemas/ContextPrecisionScoringFnParams' + context_recall: '#/components/schemas/ContextRecallScoringFnParams' + context_relevancy: '#/components/schemas/ContextRelevancyScoringFnParams' + SubsetOfcoringFnParams: + type: object + properties: + aggregation_functions: + type: array + items: + $ref: '#/components/schemas/AggregationFunctionType' + description: >- + (Optional) Aggregation functions to apply to the scores of each row. If + not provided, no aggregation will be performed. + type: + type: string + const: subset_of + default: subset_of + additionalProperties: false + required: + - type + title: SubsetOfcoringFnParams EvaluateRowsRequest: type: object properties: @@ -6364,6 +6577,7 @@ components: additionalProperties: false required: - scoring_fn_type + - params title: RegisterScoringFunctionRequest RegisterShieldRequest: type: object diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 85c5ad403..b2aa04a5f 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -68,110 +68,27 @@ class AggregationFunctionType(Enum): accuracy = "accuracy" -# TODO(xiyan): -# ============= OPTION 1: SEPARATE ScoringFnParamsType + ScoringFunctionType ============= -# class ScoringFnParamsType(Enum): -# """ -# A type of scoring function parameters. +class BasicScoringFnParamsCommon(BaseModel): + """ + :param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. If not provided, no aggregation will be performed. + """ -# :cvar llm_as_judge: Provide judge model and prompt template. -# :cvar regex_parser: Provide regexes to parse the answer from the generated response. -# :cvar basic: Parameters for basic non-parameterized scoring function. -# """ - -# custom_llm_as_judge = "custom_llm_as_judge" -# regex_parser = "regex_parser" -# basic = "basic" - - -# @json_schema_type -# class LLMAsJudgeScoringFnParams(BaseModel): -# """ -# Parameters for a scoring function that uses a judge model to score the answer. - -# :param judge_model: The model to use for scoring. -# :param prompt_template: (Optional) The prompt template to use for scoring. -# :param judge_score_regexes: (Optional) Regexes to extract the score from the judge model's response. -# :param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided. -# """ - -# type: Literal["custom_llm_as_judge"] = "custom_llm_as_judge" -# judge_model: str -# prompt_template: Optional[str] = None -# judge_score_regexes: Optional[List[str]] = Field( -# description="Regexes to extract the answer from generated response", -# default_factory=list, -# ) -# aggregation_functions: Optional[List[AggregationFunctionType]] = Field( -# description="Aggregation functions to apply to the scores of each row", -# default_factory=list, -# ) - - -# @json_schema_type -# class RegexParserScoringFnParams(BaseModel): -# """ -# Parameters for a scoring function that parses the answer from the generated response using regexes, and checks against the expected answer. - -# :param parsing_regexes: Regexes to extract the answer from generated response -# :param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided. -# """ - -# type: Literal["regex_parser"] = "regex_parser" -# parsing_regexes: Optional[List[str]] = Field( -# description="Regexes to extract the answer from generated response", -# default_factory=list, -# ) -# aggregation_functions: Optional[List[AggregationFunctionType]] = Field( -# description="Aggregation functions to apply to the scores of each row", -# default_factory=list, -# ) - - -# @json_schema_type -# class BasicScoringFnParams(BaseModel): -# """ -# Parameters for a non-parameterized scoring function. - -# :param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. No aggregation for results is calculated if not provided. -# """ - -# type: Literal["basic"] = "basic" -# aggregation_functions: Optional[List[AggregationFunctionType]] = Field( -# description="Aggregation functions to apply to the scores of each row", -# default_factory=list, -# ) - - -# ScoringFnParams = register_schema( -# Annotated[ -# Union[ -# LLMAsJudgeScoringFnParams, -# RegexParserScoringFnParams, -# BasicScoringFnParams, -# ], -# Field(discriminator="type"), -# ], -# name="ScoringFnParams", -# ) - -# ============= END OF OPTION 1 ============= - - -# TODO(xiyan): -# ============= OPTION 2: MERGE ScoringFnParamsType + ScoringFunctionType into ScoringFunctionType ============= -class RegexParserScoringFnParamsCommon(BaseModel): - parsing_regexes: Optional[List[str]] = Field( - description="Regexes to extract the answer from generated response", - default_factory=list, - ) aggregation_functions: Optional[List[AggregationFunctionType]] = Field( description="Aggregation functions to apply to the scores of each row", default_factory=list, ) -class BasicScoringFnParamsCommon(BaseModel): +class RegexParserScoringFnParamsCommon(BaseModel): + """ + :param parsing_regexes: (Optional) Regexes to extract the answer from generated response. + :param aggregation_functions: (Optional) Aggregation functions to apply to the scores of each row. If not provided, no aggregation will be performed. + """ + + parsing_regexes: List[str] = Field( + description="Regexes to extract the answer from generated response", + default_factory=list, + ) aggregation_functions: Optional[List[AggregationFunctionType]] = Field( description="Aggregation functions to apply to the scores of each row", default_factory=list, @@ -198,6 +115,51 @@ class SubsetOfcoringFnParams(BasicScoringFnParamsCommon): type: Literal["subset_of"] = "subset_of" +@json_schema_type +class FactualityScoringFnParams(BasicScoringFnParamsCommon): + type: Literal["factuality"] = "factuality" + + +@json_schema_type +class FaithfulnessScoringFnParams(BasicScoringFnParamsCommon): + type: Literal["faithfulness"] = "faithfulness" + + +@json_schema_type +class AnswerCorrectnessScoringFnParams(BasicScoringFnParamsCommon): + type: Literal["answer_correctness"] = "answer_correctness" + + +@json_schema_type +class AnswerRelevancyScoringFnParams(BasicScoringFnParamsCommon): + type: Literal["answer_relevancy"] = "answer_relevancy" + + +@json_schema_type +class AnswerSimilarityScoringFnParams(BasicScoringFnParamsCommon): + type: Literal["answer_similarity"] = "answer_similarity" + + +@json_schema_type +class ContextEntityRecallScoringFnParams(BasicScoringFnParamsCommon): + type: Literal["context_entity_recall"] = "context_entity_recall" + + +@json_schema_type +class ContextPrecisionScoringFnParams(BasicScoringFnParamsCommon): + type: Literal["context_precision"] = "context_precision" + + +@json_schema_type +class ContextRecallScoringFnParams(BasicScoringFnParamsCommon): + type: Literal["context_recall"] = "context_recall" + + +@json_schema_type +class ContextRelevancyScoringFnParams(BasicScoringFnParamsCommon): + type: Literal["context_relevancy"] = "context_relevancy" + + @json_schema_type class LLMAsJudgeScoringFnParams(BaseModel): type: Literal["custom_llm_as_judge"] = "custom_llm_as_judge" @@ -221,6 +183,15 @@ ScoringFnParams = register_schema( RegexParserMathScoringFnParams, EqualityScoringFnParams, SubsetOfcoringFnParams, + FactualityScoringFnParams, + FaithfulnessScoringFnParams, + AnswerCorrectnessScoringFnParams, + AnswerRelevancyScoringFnParams, + AnswerSimilarityScoringFnParams, + ContextEntityRecallScoringFnParams, + ContextPrecisionScoringFnParams, + ContextRecallScoringFnParams, + ContextRelevancyScoringFnParams, ], Field(discriminator="type"), ], @@ -284,9 +255,8 @@ class ScoringFunctions(Protocol): @webmethod(route="/scoring-functions", method="POST") async def register_scoring_function( self, - # TODO(xiyan): scoring_fn_type will not be needed for OPTION 2 - # scoring_fn_type: ScoringFunctionType, - params: Optional[ScoringFnParams] = None, + scoring_fn_type: ScoringFunctionType, + params: ScoringFnParams = None, scoring_fn_id: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, ) -> ScoringFn: @@ -294,7 +264,7 @@ class ScoringFunctions(Protocol): Register a new scoring function with given parameters. Only valid scoring function type that can be parameterized can be registered. - # :param scoring_fn_type: The type of scoring function to register. + :param scoring_fn_type: The type of scoring function to register. :param params: The parameters for the scoring function. :param scoring_fn_id: (Optional) The ID of the scoring function to register. If not provided, a random ID will be generated. :param metadata: (Optional) Any additional metadata to be associated with the scoring function.