diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 5888f08f5..a242215c6 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -4,20 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import ( - Any, - Dict, - List, - Literal, - Optional, - Protocol, - runtime_checkable, - Union, -) +from typing import Any, Dict, List, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field -from typing_extensions import Annotated from llama_stack.apis.common.type_system import ParamType @@ -33,21 +23,19 @@ class Parameter(BaseModel): # with standard metrics so they can be rolled up? +class LLMAsJudgeContext(BaseModel): + judge_model: str + prompt_template: Optional[str] = None + + @json_schema_type -class CommonFunctionDef(BaseModel): +class ScoringFunctionDef(BaseModel): identifier: str description: Optional[str] = None metadata: Dict[str, Any] = Field( default_factory=dict, description="Any additional metadata for this definition", ) - # Hack: same with memory_banks for union defs - provider_id: str = "" - - -@json_schema_type -class DeterministicFunctionDef(CommonFunctionDef): - type: Literal["deterministic"] = "deterministic" parameters: List[Parameter] = Field( description="List of parameters for the deterministic function", default_factory=list, @@ -55,24 +43,17 @@ class DeterministicFunctionDef(CommonFunctionDef): return_type: ParamType = Field( description="The return type of the deterministic function", ) + context: Optional[LLMAsJudgeContext] = None # We can optionally add information here to support packaging of code, etc. @json_schema_type -class LLMJudgeFunctionDef(CommonFunctionDef): - type: Literal["judge"] = "judge" - model: str = Field( - description="The LLM model to use for the judge function", +class ScoringFunctionDefWithProvider(ScoringFunctionDef): + provider_id: str = Field( + description="ID of the provider which serves this dataset", ) -ScoringFunctionDef = Annotated[ - Union[DeterministicFunctionDef, LLMJudgeFunctionDef], Field(discriminator="type") -] - -ScoringFunctionDefWithProvider = ScoringFunctionDef - - @runtime_checkable class ScoringFunctions(Protocol): @webmethod(route="/scoring_functions/list", method="GET") diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 10b39e522..dcd588a9e 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -95,17 +95,15 @@ class CommonRoutingTableImpl(RoutingTable): for d in datasets: d.provider_id = pid - add_objects(datasets) - elif api == Api.scoring: p.scoring_function_store = self scoring_functions = await p.list_scoring_functions() - - # do in-memory updates due to pesky Annotated unions - for s in scoring_functions: - s.provider_id = pid - - add_objects(scoring_functions) + add_objects( + [ + ScoringFunctionDefWithProvider(**s.dict(), provider_id=pid) + for s in scoring_functions + ] + ) async def shutdown(self) -> None: for p in self.impls_by_provider_id.values(): diff --git a/llama_stack/providers/impls/meta_reference/scoring/scorer/base_scorer.py b/llama_stack/providers/impls/meta_reference/scoring/scorer/base_scorer.py index 9c982948e..ea8a3f063 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scorer/base_scorer.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scorer/base_scorer.py @@ -17,7 +17,7 @@ class BaseScorer(ABC): - aggregate(self, scorer_results) """ - scoring_function_def: DeterministicFunctionDef + scoring_function_def: ScoringFunctionDef def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) diff --git a/llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py b/llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py index b9b8b1eee..ce765bfb5 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py @@ -17,7 +17,7 @@ class EqualityScorer(BaseScorer): A scorer that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise. """ - scoring_function_def = DeterministicFunctionDef( + scoring_function_def = ScoringFunctionDef( identifier="equality", description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.", parameters=[],