diff --git a/llama_stack/apis/scoring/scoring.py b/llama_stack/apis/scoring/scoring.py
index c2bfdcd23..2c643a28e 100644
--- a/llama_stack/apis/scoring/scoring.py
+++ b/llama_stack/apis/scoring/scoring.py
@@ -37,7 +37,7 @@ class ScoreResponse(BaseModel):
 
 
 class ScoringFunctionStore(Protocol):
-    def get_scoring_function(self, name: str) -> ScoringFnDefWithProvider: ...
+    def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: ...
 
 
 @runtime_checkable
diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py
index 140376242..6b2408e0d 100644
--- a/llama_stack/apis/scoring_functions/scoring_functions.py
+++ b/llama_stack/apis/scoring_functions/scoring_functions.py
@@ -22,19 +22,21 @@ from typing_extensions import Annotated
 
 from llama_stack.apis.common.type_system import ParamType
 
+from llama_stack.apis.resource import Resource, ResourceType
+
 
 # Perhaps more structure can be imposed on these functions. Maybe they could be associated
 # with standard metrics so they can be rolled up?
 @json_schema_type
-class ScoringConfigType(Enum):
+class ScoringFnParamsType(Enum):
     llm_as_judge = "llm_as_judge"
     regex_parser = "regex_parser"
 
 
 @json_schema_type
 class LLMAsJudgeScoringFnParams(BaseModel):
-    type: Literal[ScoringConfigType.llm_as_judge.value] = (
-        ScoringConfigType.llm_as_judge.value
+    type: Literal[ScoringFnParamsType.llm_as_judge.value] = (
+        ScoringFnParamsType.llm_as_judge.value
     )
     judge_model: str
     prompt_template: Optional[str] = None
@@ -46,8 +48,8 @@ class LLMAsJudgeScoringFnParams(BaseModel):
 
 @json_schema_type
 class RegexParserScoringFnParams(BaseModel):
-    type: Literal[ScoringConfigType.regex_parser.value] = (
-        ScoringConfigType.regex_parser.value
+    type: Literal[ScoringFnParamsType.regex_parser.value] = (
+        ScoringFnParamsType.regex_parser.value
     )
     parsing_regexes: Optional[List[str]] = Field(
         description="Regex to extract the answer from generated response",
@@ -65,8 +67,10 @@ ScoringFnParams = Annotated[
 
 
 @json_schema_type
-class ScoringFnDef(BaseModel):
-    identifier: str
+class ScoringFn(Resource):
+    type: Literal[ResourceType.scoring_function.value] = (
+        ResourceType.scoring_function.value
+    )
     description: Optional[str] = None
     metadata: Dict[str, Any] = Field(
         default_factory=dict,
@@ -79,28 +83,23 @@
         description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
         default=None,
     )
-    # We can optionally add information here to support packaging of code, etc.
-
-
-@json_schema_type
-class ScoringFnDefWithProvider(ScoringFnDef):
-    type: Literal["scoring_fn"] = "scoring_fn"
-    provider_id: str = Field(
-        description="ID of the provider which serves this dataset",
-    )
 
 
 @runtime_checkable
 class ScoringFunctions(Protocol):
     @webmethod(route="/scoring_functions/list", method="GET")
-    async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]: ...
+    async def list_scoring_functions(self) -> List[ScoringFn]: ...
 
     @webmethod(route="/scoring_functions/get", method="GET")
-    async def get_scoring_function(
-        self, name: str
-    ) -> Optional[ScoringFnDefWithProvider]: ...
+    async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: ...
 
     @webmethod(route="/scoring_functions/register", method="POST")
     async def register_scoring_function(
-        self, function_def: ScoringFnDefWithProvider
+        self,
+        scoring_fn_id: str,
+        description: str,
+        return_type: ParamType,
+        provider_scoring_fn_id: Optional[str] = None,
+        provider_id: Optional[str] = None,
+        params: Optional[ScoringFnParams] = None,
    ) -> None: ...
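For reference, the flattened registration API above can be exercised like this. A minimal sketch; the `scoring_functions_impl` handle, the identifier, and the judge model name are illustrative, not part of this diff:

```python
# Sketch: registering a scoring function through the new flattened API.
# `scoring_functions_impl` is any object implementing the ScoringFunctions
# protocol; the ids below are made up for illustration.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams

await scoring_functions_impl.register_scoring_function(
    scoring_fn_id="my-app::helpfulness",
    description="LLM-as-judge helpfulness score",
    return_type=NumberType(),
    params=LLMAsJudgeScoringFnParams(
        judge_model="Llama3.1-405B-Instruct",
        prompt_template="Rate the response on a 1-5 scale ...",
    ),
    # provider_id / provider_scoring_fn_id may be omitted when only one
    # scoring provider is configured; the routing table fills them in.
)
```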
@webmethod(route="/scoring_functions/register", method="POST") async def register_scoring_function( - self, function_def: ScoringFnDefWithProvider + self, + scoring_fn_id: str, + description: str, + return_type: ParamType, + provider_scoring_fn_id: Optional[str] = None, + provider_id: Optional[str] = None, + params: Optional[ScoringFnParams] = None, ) -> None: ... diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 9098f4331..51b56dd5f 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -35,7 +35,7 @@ RoutableObject = Union[ Shield, MemoryBank, Dataset, - ScoringFnDef, + ScoringFn, ] @@ -45,7 +45,7 @@ RoutableObjectWithProvider = Annotated[ Shield, MemoryBank, Dataset, - ScoringFnDefWithProvider, + ScoringFn, ], Field(discriminator="type"), ] diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index b0091f5a0..efed54ab8 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -81,7 +81,10 @@ class CommonRoutingTableImpl(RoutingTable): # so we should just override the provider in-place obj.provider_id = provider_id else: - obj = cls(**obj.model_dump(), provider_id=provider_id) + # Create a copy of the model data and explicitly set provider_id + model_data = obj.model_dump() + model_data["provider_id"] = provider_id + obj = cls(**model_data) await self.dist_registry.register(obj) # Register all objects from providers @@ -101,7 +104,7 @@ class CommonRoutingTableImpl(RoutingTable): elif api == Api.scoring: p.scoring_function_store = self scoring_functions = await p.list_scoring_functions() - await add_objects(scoring_functions, pid, ScoringFnDefWithProvider) + await add_objects(scoring_functions, pid, ScoringFn) elif api == Api.eval: p.eval_task_store = self @@ -340,18 +343,41 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): - async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]: - return await self.get_all_with_type("scoring_fn") + async def list_scoring_functions(self) -> List[ScoringFn]: + return await self.get_all_with_type(ResourceType.scoring_function.value) - async def get_scoring_function( - self, name: str - ) -> Optional[ScoringFnDefWithProvider]: - return await self.get_object_by_identifier(name) + async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: + return await self.get_object_by_identifier(scoring_fn_id) async def register_scoring_function( - self, function_def: ScoringFnDefWithProvider + self, + scoring_fn_id: str, + description: str, + return_type: ParamType, + provider_scoring_fn_id: Optional[str] = None, + provider_id: Optional[str] = None, + params: Optional[ScoringFnParams] = None, ) -> None: - await self.register_object(function_def) + if params is None: + params = {} + if provider_scoring_fn_id is None: + provider_scoring_fn_id = scoring_fn_id + if provider_id is None: + if len(self.impls_by_provider_id) == 1: + provider_id = list(self.impls_by_provider_id.keys())[0] + else: + raise ValueError( + "No provider specified and multiple providers available. Please specify a provider_id." 
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index f065d4f33..5a259ae2d 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -15,7 +15,7 @@ from llama_stack.apis.datasets import Dataset
 from llama_stack.apis.eval_tasks import EvalTask
 from llama_stack.apis.memory_banks.memory_banks import MemoryBank
 from llama_stack.apis.models import Model
-from llama_stack.apis.scoring_functions import ScoringFnDef
+from llama_stack.apis.scoring_functions import ScoringFn
 from llama_stack.apis.shields import Shield
 
 
@@ -61,9 +61,9 @@ class DatasetsProtocolPrivate(Protocol):
 
 
 class ScoringFunctionsProtocolPrivate(Protocol):
-    async def list_scoring_functions(self) -> List[ScoringFnDef]: ...
+    async def list_scoring_functions(self) -> List[ScoringFn]: ...
 
-    async def register_scoring_function(self, function_def: ScoringFnDef) -> None: ...
+    async def register_scoring_function(self, scoring_fn: ScoringFn) -> None: ...
 
 
 class EvalTasksProtocolPrivate(Protocol):
diff --git a/llama_stack/providers/inline/scoring/braintrust/braintrust.py b/llama_stack/providers/inline/scoring/braintrust/braintrust.py
index 57723bb47..9105a4978 100644
--- a/llama_stack/providers/inline/scoring/braintrust/braintrust.py
+++ b/llama_stack/providers/inline/scoring/braintrust/braintrust.py
@@ -48,7 +48,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
 
     async def shutdown(self) -> None: ...
 
-    async def list_scoring_functions(self) -> List[ScoringFnDef]:
+    async def list_scoring_functions(self) -> List[ScoringFn]:
         scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()]
         for f in scoring_fn_defs_list:
             assert f.identifier.startswith(
@@ -57,7 +57,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
 
         return scoring_fn_defs_list
 
-    async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
+    async def register_scoring_function(self, scoring_fn: ScoringFn) -> None:
         raise NotImplementedError(
             "Registering scoring function not allowed for braintrust provider"
         )
diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py
index ca6a46d0e..554590f12 100644
--- a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py
+++ b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py
@@ -5,12 +5,14 @@
 # the root directory of this source tree.
 
 from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import ScoringFnDef
+from llama_stack.apis.scoring_functions import ScoringFn
 
 
-answer_correctness_fn_def = ScoringFnDef(
+answer_correctness_fn_def = ScoringFn(
     identifier="braintrust::answer-correctness",
     description="Test whether an output is factual, compared to an original (`expected`) value. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
-    parameters=[],
+    params=None,
+    provider_id="braintrust",
+    provider_resource_id="answer-correctness",
     return_type=NumberType(),
 )
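Under the new shape, `provider_resource_id` carries the provider-native name while `identifier` keeps the namespaced form that the assertion in `list_scoring_functions` checks. A sketch of how a braintrust-style provider might dispatch on it; the `EVALUATORS` table and `resolve_evaluator` helper are hypothetical, and the autoevals import paths are an assumption:

```python
# Sketch: resolving a ScoringFn to an autoevals scorer by its
# provider-native id. Mapping and helper are illustrative, not the
# shipped implementation.
from autoevals.llm import Factuality
from autoevals.ragas import AnswerCorrectness

EVALUATORS = {
    "factuality": Factuality(),
    "answer-correctness": AnswerCorrectness(),
}

def resolve_evaluator(scoring_fn):
    # "braintrust::factuality" resolves via provider_resource_id="factuality"
    return EVALUATORS[scoring_fn.provider_resource_id]
```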
diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py
index cbf9cd01c..b733f10c8 100644
--- a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py
+++ b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py
@@ -5,12 +5,14 @@
 # the root directory of this source tree.
 
 from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import ScoringFnDef
+from llama_stack.apis.scoring_functions import ScoringFn
 
 
-factuality_fn_def = ScoringFnDef(
+factuality_fn_def = ScoringFn(
     identifier="braintrust::factuality",
     description="Test whether an output is factual, compared to an original (`expected`) value. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
-    parameters=[],
+    params=None,
+    provider_id="braintrust",
+    provider_resource_id="factuality",
     return_type=NumberType(),
 )
diff --git a/llama_stack/providers/inline/scoring/meta_reference/scoring.py b/llama_stack/providers/inline/scoring/meta_reference/scoring.py
index 6370ea5e5..b78379062 100644
--- a/llama_stack/providers/inline/scoring/meta_reference/scoring.py
+++ b/llama_stack/providers/inline/scoring/meta_reference/scoring.py
@@ -52,7 +52,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
 
     async def shutdown(self) -> None: ...
 
-    async def list_scoring_functions(self) -> List[ScoringFnDef]:
+    async def list_scoring_functions(self) -> List[ScoringFn]:
         scoring_fn_defs_list = [
             fn_def
             for impl in self.scoring_fn_id_impls.values()
@@ -66,7 +66,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
 
         return scoring_fn_defs_list
 
-    async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
+    async def register_scoring_function(self, function_def: ScoringFn) -> None:
         raise NotImplementedError("Register scoring function not implemented yet")
 
     async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
diff --git a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/base_scoring_fn.py b/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/base_scoring_fn.py
index 532686ebd..e356bc289 100644
--- a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/base_scoring_fn.py
+++ b/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/base_scoring_fn.py
@@ -24,15 +24,15 @@ class BaseScoringFn(ABC):
     def __str__(self) -> str:
         return self.__class__.__name__
 
-    def get_supported_scoring_fn_defs(self) -> List[ScoringFnDef]:
+    def get_supported_scoring_fn_defs(self) -> List[ScoringFn]:
         return [x for x in self.supported_fn_defs_registry.values()]
 
-    def register_scoring_fn_def(self, scoring_fn_def: ScoringFnDef) -> None:
-        if scoring_fn_def.identifier in self.supported_fn_defs_registry:
+    def register_scoring_fn_def(self, scoring_fn: ScoringFn) -> None:
+        if scoring_fn.identifier in self.supported_fn_defs_registry:
             raise ValueError(
-                f"Scoring function def with identifier {scoring_fn_def.identifier} already exists."
+                f"Scoring function def with identifier {scoring_fn.identifier} already exists."
             )
-        self.supported_fn_defs_registry[scoring_fn_def.identifier] = scoring_fn_def
+        self.supported_fn_defs_registry[scoring_fn.identifier] = scoring_fn
 
     @abstractmethod
     async def score_row(
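For orientation, a concrete `BaseScoringFn` subclass wires a fn def into `supported_fn_defs_registry` through the renamed methods above, for example the `equality` def that follows. A sketch only: the class name is hypothetical, and `score_row`'s full signature is elided with `*args/**kwargs` because the abstract declaration is truncated in this hunk:

```python
# Sketch: a minimal BaseScoringFn subclass (hypothetical).
class EqualityScoringFn(BaseScoringFn):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Make the def below visible via get_supported_scoring_fn_defs().
        self.register_scoring_fn_def(equality)

    async def score_row(self, input_row, *args, **kwargs):
        matched = input_row["generated_answer"] == input_row["expected_answer"]
        return {"score": 1.0 if matched else 0.0}
```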
diff --git a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/equality.py b/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/equality.py
index b54bf7ae8..b3fbb5d2f 100644
--- a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/equality.py
+++ b/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/equality.py
@@ -5,11 +5,14 @@
 # the root directory of this source tree.
 
 from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import ScoringFnDef
+from llama_stack.apis.scoring_functions import ScoringFn
 
 
-equality = ScoringFnDef(
+equality = ScoringFn(
     identifier="meta-reference::equality",
     description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
+    params=None,
+    provider_id="meta-reference",
+    provider_resource_id="equality",
     return_type=NumberType(),
 )
diff --git a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/llm_as_judge_base.py b/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/llm_as_judge_base.py
index 69d96e1bf..ad07ea1b8 100644
--- a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/llm_as_judge_base.py
+++ b/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/llm_as_judge_base.py
@@ -5,11 +5,13 @@
 # the root directory of this source tree.
 
 from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import ScoringFnDef
+from llama_stack.apis.scoring_functions import ScoringFn
 
 
-llm_as_judge_base = ScoringFnDef(
+llm_as_judge_base = ScoringFn(
     identifier="meta-reference::llm_as_judge_base",
     description="Llm As Judge Scoring Function",
     return_type=NumberType(),
+    provider_id="meta-reference",
+    provider_resource_id="llm-as-judge-base",
 )
diff --git a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py b/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py
index 84e518887..20b59c273 100644
--- a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py
+++ b/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py
@@ -56,10 +56,12 @@ MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
     r"(?i){}\s*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])"
 )
 
-regex_parser_multiple_choice_answer = ScoringFnDef(
+regex_parser_multiple_choice_answer = ScoringFn(
     identifier="meta-reference::regex_parser_multiple_choice_answer",
     description="Extract answer from response matching Answer: [the_answer_letter], and compare with expected result",
     return_type=NumberType(),
+    provider_id="meta-reference",
+    provider_resource_id="regex-parser-multiple-choice-answer",
     params=RegexParserScoringFnParams(
         parsing_regexes=[
             MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
diff --git a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/subset_of.py b/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/subset_of.py
index 5a3e2e8fb..b2759f3ee 100644
--- a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/subset_of.py
+++ b/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/subset_of.py
@@ -5,12 +5,13 @@
 # the root directory of this source tree.
 
 from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import ScoringFnDef
+from llama_stack.apis.scoring_functions import ScoringFn
 
 
-subset_of = ScoringFnDef(
+subset_of = ScoringFn(
     identifier="meta-reference::subset_of",
     description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
-    parameters=[],
     return_type=NumberType(),
+    provider_id="meta-reference",
+    provider_resource_id="subset-of",
 )
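The `RegexParserScoringFnParams` attached to `regex_parser_multiple_choice_answer` drive answer extraction roughly as follows; this is a simplification of `RegexParserScoringFn.score_row` (whose diff appears below), and the sample row is made up:

```python
# Sketch: applying parsing_regexes to a generated answer. Each pattern has
# one capture group for the answer letter; the first match wins.
import re

params = regex_parser_multiple_choice_answer.params
generated = "The correct choice is Answer: B."

extracted = None
for pattern in params.parsing_regexes:
    match = re.search(pattern, generated)
    if match:
        extracted = match.group(1)  # the captured answer letter
        break

# The row scores 1.0 when the extracted letter equals expected_answer.
score = 1.0 if extracted == "B" else 0.0
```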
diff --git a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/regex_parser_scoring_fn.py b/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/regex_parser_scoring_fn.py
index 3cbc6cbe4..33773b7bb 100644
--- a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/regex_parser_scoring_fn.py
+++ b/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/regex_parser_scoring_fn.py
@@ -42,7 +42,7 @@ class RegexParserScoringFn(BaseScoringFn):
 
         assert (
             fn_def.params is not None
-            and fn_def.params.type == ScoringConfigType.regex_parser.value
+            and fn_def.params.type == ScoringFnParamsType.regex_parser.value
         ), f"RegexParserScoringFnParams not found for {fn_def}."
 
         expected_answer = input_row["expected_answer"]
diff --git a/llama_stack/providers/tests/scoring/fixtures.py b/llama_stack/providers/tests/scoring/fixtures.py
index 648d35859..20631f5cf 100644
--- a/llama_stack/providers/tests/scoring/fixtures.py
+++ b/llama_stack/providers/tests/scoring/fixtures.py
@@ -48,7 +48,7 @@ SCORING_FIXTURES = ["meta_reference", "remote", "braintrust"]
 
 
 @pytest_asyncio.fixture(scope="session")
-async def scoring_stack(request):
+async def scoring_stack(request, inference_model):
     fixture_dict = request.param
 
     providers = {}
@@ -65,4 +65,18 @@
         provider_data,
     )
 
+    provider_id = providers["inference"][0].provider_id
+    await impls[Api.models].register_model(
+        model_id=inference_model,
+        provider_id=provider_id,
+    )
+    await impls[Api.models].register_model(
+        model_id="Llama3.1-405B-Instruct",
+        provider_id=provider_id,
+    )
+    await impls[Api.models].register_model(
+        model_id="Llama3.1-8B-Instruct",
+        provider_id=provider_id,
+    )
+
     return impls
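A test consuming the extended fixture then sees the pre-registered models and the new `ScoringFn` type. A sketch; the test name and assertion are illustrative, and it assumes the usual llama_stack provider-test conventions (pytest-asyncio, impls keyed by `Api`):

```python
# Sketch: exercising the scoring_stack fixture from the diff above.
import pytest

from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.providers.datatypes import Api


@pytest.mark.asyncio
async def test_scoring_functions_list(scoring_stack):
    scoring_functions_impl = scoring_stack[Api.scoring_functions]
    scoring_fns = await scoring_functions_impl.list_scoring_functions()
    # Every provider-declared def should now come back as a ScoringFn resource.
    assert all(isinstance(fn, ScoringFn) for fn in scoring_fns)
```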