This commit is contained in:
Xi Yan 2025-03-12 21:23:20 -07:00
parent 6408bdbc9d
commit 20cdcd87a3

View file

@ -21,6 +21,7 @@ from typing_extensions import Annotated
from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
from llama_stack.apis.datasets import DatasetPurpose
# Perhaps more structure can be imposed on these functions. Maybe they could be associated # Perhaps more structure can be imposed on these functions. Maybe they could be associated
# with standard metrics so they can be rolled up? # with standard metrics so they can be rolled up?
@ -50,7 +51,6 @@ class ScoringFunctionType(Enum):
context_relevancy = "context_relevancy" context_relevancy = "context_relevancy"
@json_schema_type
class AggregationFunctionType(Enum): class AggregationFunctionType(Enum):
""" """
A type of aggregation function. A type of aggregation function.
@ -200,16 +200,10 @@ ScoringFnParams = register_schema(
class CommonScoringFnFields(BaseModel): class CommonScoringFnFields(BaseModel):
""" """
:param scoring_fn_type: The type of scoring function. :param fn: The scoring function type and parameters.
:param params: (Optional) The parameters for the scoring function.
:param metadata: (Optional) Any additional metadata for this definition (e.g. description). :param metadata: (Optional) Any additional metadata for this definition (e.g. description).
""" """
fn: ScoringFnParams
scoring_fn_type: ScoringFunctionType
params: Optional[ScoringFnParams] = Field(
description="The parameters for the scoring function.",
default=None,
)
metadata: Dict[str, Any] = Field( metadata: Dict[str, Any] = Field(
default_factory=dict, default_factory=dict,
description="Any additional metadata for this definition (e.g. description)", description="Any additional metadata for this definition (e.g. description)",
@ -229,6 +223,22 @@ class ScoringFn(CommonScoringFnFields, Resource):
return self.provider_resource_id return self.provider_resource_id
@json_schema_type
class ScoringFnTypeInfo(BaseModel):
"""
:param type: The type of scoring function.
:param description: A description of the scoring function type.
- E.g. Write your custom judge prompt to score the answer.
:param supported_purposes: The purposes that this scoring function can be used for.
"""
type: ScoringFunctionType
description: str
supported_purposes: List[DatasetPurpose] = Field(
description="The purposes that this scoring function can be used for",
default_factory=list,
)
class ScoringFnInput(CommonScoringFnFields, BaseModel): class ScoringFnInput(CommonScoringFnFields, BaseModel):
scoring_fn_id: str scoring_fn_id: str
provider_id: Optional[str] = None provider_id: Optional[str] = None
@ -239,16 +249,36 @@ class ListScoringFunctionsResponse(BaseModel):
data: List[ScoringFn] data: List[ScoringFn]
class ListScoringFunctionTypesResponse(BaseModel):
data: List[ScoringFnTypeInfo]
@runtime_checkable @runtime_checkable
class ScoringFunctions(Protocol): class ScoringFunctions(Protocol):
@webmethod(route="/scoring-functions", method="GET") @webmethod(route="/scoring-functions", method="GET")
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ... async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
"""
List all registered scoring functions.
"""
...
@webmethod(route="/scoring-functions/types", method="GET")
async def list_scoring_function_types(self) -> ListScoringFunctionTypesResponse:
"""
List all available scoring function types information and how to use them.
"""
...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET") @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
async def get_scoring_function( async def get_scoring_function(
self, self,
scoring_fn_id: str, scoring_fn_id: str,
) -> Optional[ScoringFn]: ... ) -> Optional[ScoringFn]:
"""
Get a scoring function by its ID.
:param scoring_fn_id: The ID of the scoring function to get.
"""
...
@webmethod(route="/scoring-functions", method="POST") @webmethod(route="/scoring-functions", method="POST")
async def register_scoring_function( async def register_scoring_function(
@ -267,3 +297,14 @@ class ScoringFunctions(Protocol):
- E.g. {"description": "This scoring function is used for ..."} - E.g. {"description": "This scoring function is used for ..."}
""" """
... ...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="DELETE")
async def unregister_scoring_function(
self,
scoring_fn_id: str,
) -> None:
"""
Unregister a scoring function by its ID.
:param scoring_fn_id: The ID of the scoring function to unregister.
"""
...