From 20cdcd87a36198e28d954422df99942c96c15489 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 12 Mar 2025 21:23:20 -0700 Subject: [PATCH] purpose --- .../scoring_functions/scoring_functions.py | 63 +++++++++++++++---- 1 file changed, 52 insertions(+), 11 deletions(-) diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 70e33e4f4..35c0dc9d1 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -21,6 +21,7 @@ from typing_extensions import Annotated from llama_stack.apis.resource import Resource, ResourceType from llama_stack.schema_utils import json_schema_type, register_schema, webmethod +from llama_stack.apis.datasets import DatasetPurpose # Perhaps more structure can be imposed on these functions. Maybe they could be associated # with standard metrics so they can be rolled up? @@ -50,7 +51,6 @@ class ScoringFunctionType(Enum): context_relevancy = "context_relevancy" -@json_schema_type class AggregationFunctionType(Enum): """ A type of aggregation function. @@ -200,16 +200,10 @@ ScoringFnParams = register_schema( class CommonScoringFnFields(BaseModel): """ - :param scoring_fn_type: The type of scoring function. - :param params: (Optional) The parameters for the scoring function. + :param fn: The scoring function type and parameters. :param metadata: (Optional) Any additional metadata for this definition (e.g. description). """ - - scoring_fn_type: ScoringFunctionType - params: Optional[ScoringFnParams] = Field( - description="The parameters for the scoring function.", - default=None, - ) + fn: ScoringFnParams metadata: Dict[str, Any] = Field( default_factory=dict, description="Any additional metadata for this definition (e.g. description)", @@ -229,6 +223,22 @@ class ScoringFn(CommonScoringFnFields, Resource): return self.provider_resource_id +@json_schema_type +class ScoringFnTypeInfo(BaseModel): + """ + :param type: The type of scoring function. + :param description: A description of the scoring function type. + - E.g. Write your custom judge prompt to score the answer. + :param supported_purposes: The purposes that this scoring function can be used for. + """ + type: ScoringFunctionType + description: str + supported_purposes: List[DatasetPurpose] = Field( + description="The purposes that this scoring function can be used for", + default_factory=list, + ) + + class ScoringFnInput(CommonScoringFnFields, BaseModel): scoring_fn_id: str provider_id: Optional[str] = None @@ -239,16 +249,36 @@ class ListScoringFunctionsResponse(BaseModel): data: List[ScoringFn] +class ListScoringFunctionTypesResponse(BaseModel): + data: List[ScoringFnTypeInfo] + + @runtime_checkable class ScoringFunctions(Protocol): @webmethod(route="/scoring-functions", method="GET") - async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ... + async def list_scoring_functions(self) -> ListScoringFunctionsResponse: + """ + List all registered scoring functions. + """ + ... + + @webmethod(route="/scoring-functions/types", method="GET") + async def list_scoring_function_types(self) -> ListScoringFunctionTypesResponse: + """ + List all available scoring function types information and how to use them. + """ + ... @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET") async def get_scoring_function( self, scoring_fn_id: str, - ) -> Optional[ScoringFn]: ... + ) -> Optional[ScoringFn]: + """ + Get a scoring function by its ID. + :param scoring_fn_id: The ID of the scoring function to get. + """ + ... @webmethod(route="/scoring-functions", method="POST") async def register_scoring_function( @@ -267,3 +297,14 @@ class ScoringFunctions(Protocol): - E.g. {"description": "This scoring function is used for ..."} """ ... + + @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="DELETE") + async def unregister_scoring_function( + self, + scoring_fn_id: str, + ) -> None: + """ + Unregister a scoring function by its ID. + :param scoring_fn_id: The ID of the scoring function to unregister. + """ + ...