This commit is contained in:
Xi Yan 2025-03-13 15:35:09 -07:00
parent 2cf769e05e
commit 819ffe0518
3 changed files with 190 additions and 176 deletions

View file

@ -12,16 +12,17 @@ from typing import (
Literal,
Optional,
Protocol,
Union,
runtime_checkable,
Union,
)
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.datasets import DatasetPurpose
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
from llama_stack.apis.datasets import DatasetPurpose
# Perhaps more structure can be imposed on these functions. Maybe they could be associated
# with standard metrics so they can be rolled up?
@ -93,6 +94,7 @@ class RegexParserScoringFnParams(BaseModel):
default_factory=list,
)
class CustomLLMAsJudgeScoringFnParams(BaseModel):
type: Literal["custom_llm_as_judge"] = "custom_llm_as_judge"
judge_model: str
@ -102,6 +104,7 @@ class CustomLLMAsJudgeScoringFnParams(BaseModel):
default_factory=list,
)
@json_schema_type
class RegexParserScoringFn(BaseModel):
type: Literal["regex_parser"] = "regex_parser"
@ -113,36 +116,43 @@ class RegexParserMathScoringFn(BaseModel):
type: Literal["regex_parser_math_response"] = "regex_parser_math_response"
regex_parser_math_response: RegexParserScoringFnParams
# Pydantic model describing the "equality" scoring function variant.
@json_schema_type
class EqualityScoringFn(BaseModel):
# Tag pinned to the literal "equality" (also the default value).
# NOTE(review): presumably the discriminator within ScoringFnDefinition - confirm.
type: Literal["equality"] = "equality"
# Required parameter payload for this variant (no default is declared).
equality: BasicScoringFnParams
# Pydantic model describing the "subset_of" scoring function variant.
@json_schema_type
class SubsetOfScoringFn(BaseModel):
# Tag pinned to the literal "subset_of" (also the default value).
# NOTE(review): presumably the discriminator within ScoringFnDefinition - confirm.
type: Literal["subset_of"] = "subset_of"
# Required parameter payload for this variant (no default is declared).
subset_of: BasicScoringFnParams
# Pydantic model describing the "factuality" scoring function variant.
@json_schema_type
class FactualityScoringFn(BaseModel):
# Tag pinned to the literal "factuality" (also the default value).
# NOTE(review): presumably the discriminator within ScoringFnDefinition - confirm.
type: Literal["factuality"] = "factuality"
# Required parameter payload for this variant (no default is declared).
factuality: BasicScoringFnParams
# Pydantic model describing the "faithfulness" scoring function variant.
@json_schema_type
class FaithfulnessScoringFn(BaseModel):
# Tag pinned to the literal "faithfulness" (also the default value).
# NOTE(review): presumably the discriminator within ScoringFnDefinition - confirm.
type: Literal["faithfulness"] = "faithfulness"
# Required parameter payload for this variant (no default is declared).
faithfulness: BasicScoringFnParams
# Pydantic model describing the "answer_correctness" scoring function variant.
@json_schema_type
class AnswerCorrectnessScoringFn(BaseModel):
# Tag pinned to the literal "answer_correctness" (also the default value).
# NOTE(review): presumably the discriminator within ScoringFnDefinition - confirm.
type: Literal["answer_correctness"] = "answer_correctness"
# Required parameter payload for this variant (no default is declared).
answer_correctness: BasicScoringFnParams
# Pydantic model describing the "answer_relevancy" scoring function variant.
@json_schema_type
class AnswerRelevancyScoringFn(BaseModel):
# Tag pinned to the literal "answer_relevancy" (also the default value).
# NOTE(review): presumably the discriminator within ScoringFnDefinition - confirm.
type: Literal["answer_relevancy"] = "answer_relevancy"
# Required parameter payload for this variant (no default is declared).
answer_relevancy: BasicScoringFnParams
@json_schema_type
class AnswerSimilarityScoringFn(BaseModel):
type: Literal["answer_similarity"] = "answer_similarity"
@ -205,9 +215,10 @@ ScoringFnDefinition = register_schema(
class CommonScoringFnFields(BaseModel):
"""
:param fn: The scoring function type and parameters.
:param fn: The scoring function type and parameters.
:param metadata: (Optional) Any additional metadata for this definition (e.g. description).
"""
fn: ScoringFnDefinition
metadata: Dict[str, Any] = Field(
default_factory=dict,
@ -217,7 +228,9 @@ class CommonScoringFnFields(BaseModel):
@json_schema_type
class ScoringFn(CommonScoringFnFields, Resource):
type: Literal[ResourceType.scoring_function.value] = ResourceType.scoring_function.value
type: Literal[ResourceType.scoring_function.value] = (
ResourceType.scoring_function.value
)
@property
def scoring_fn_id(self) -> str:
@ -231,14 +244,15 @@ class ScoringFn(CommonScoringFnFields, Resource):
@json_schema_type
class ScoringFnTypeInfo(BaseModel):
"""
:param type: The type of scoring function.
:param description: A description of the scoring function type.
- E.g. Write your custom judge prompt to score the answer.
:param supported_purposes: The purposes that this scoring function can be used for.
:param type: The type of scoring function.
:param description: A description of the scoring function type.
- E.g. Write your custom judge prompt to score the answer.
:param supported_dataset_purposes: The purposes that this scoring function can be used for.
"""
type: ScoringFunctionType
description: str
supported_purposes: List[DatasetPurpose] = Field(
supported_dataset_purposes: List[DatasetPurpose] = Field(
description="The supported purposes (supported dataset schema) that this scoring function can be used for. E.g. eval/question-answer",
default_factory=list,
)
@ -261,16 +275,16 @@ class ListScoringFunctionTypesResponse(BaseModel):
@runtime_checkable
class ScoringFunctions(Protocol):
@webmethod(route="/scoring-functions", method="GET")
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
"""
List all registered scoring functions.
"""
...
@webmethod(route="/scoring-functions/types", method="GET")
async def list_scoring_function_types(self) -> ListScoringFunctionTypesResponse:
async def list_scoring_function_types(self) -> ListScoringFunctionTypesResponse:
"""
List all available scoring function types information and how to use them.
List all available scoring function types information and how to use them.
"""
...
@ -278,7 +292,7 @@ class ScoringFunctions(Protocol):
async def get_scoring_function(
self,
scoring_fn_id: str,
) -> Optional[ScoringFn]:
) -> Optional[ScoringFn]:
"""
Get a scoring function by its ID.
:param scoring_fn_id: The ID of the scoring function to get.
@ -302,12 +316,12 @@ class ScoringFunctions(Protocol):
- E.g. {"description": "This scoring function is used for ..."}
"""
...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="DELETE")
async def unregister_scoring_function(
self,
scoring_fn_id: str,
) -> None:
) -> None:
"""
Unregister a scoring function by its ID.
:param scoring_fn_id: The ID of the scoring function to unregister.