diff --git a/llama_stack/apis/scoring/scoring.py b/llama_stack/apis/scoring/scoring.py index adac34d55..1fd523dcb 100644 --- a/llama_stack/apis/scoring/scoring.py +++ b/llama_stack/apis/scoring/scoring.py @@ -37,7 +37,7 @@ class ScoreResponse(BaseModel): class ScoringFunctionStore(Protocol): - def get_scoring_function(self, name: str) -> ScoringFunctionDefWithProvider: ... + def get_scoring_function(self, name: str) -> ScoringFnDefWithProvider: ... @runtime_checkable diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index a242215c6..fc3584f90 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -29,7 +29,7 @@ class LLMAsJudgeContext(BaseModel): @json_schema_type -class ScoringFunctionDef(BaseModel): +class ScoringFnDef(BaseModel): identifier: str description: Optional[str] = None metadata: Dict[str, Any] = Field( @@ -48,7 +48,7 @@ class ScoringFunctionDef(BaseModel): @json_schema_type -class ScoringFunctionDefWithProvider(ScoringFunctionDef): +class ScoringFnDefWithProvider(ScoringFnDef): provider_id: str = Field( description="ID of the provider which serves this dataset", ) @@ -57,14 +57,14 @@ class ScoringFunctionDefWithProvider(ScoringFunctionDef): @runtime_checkable class ScoringFunctions(Protocol): @webmethod(route="/scoring_functions/list", method="GET") - async def list_scoring_functions(self) -> List[ScoringFunctionDefWithProvider]: ... + async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]: ... @webmethod(route="/scoring_functions/get", method="GET") async def get_scoring_function( self, name: str - ) -> Optional[ScoringFunctionDefWithProvider]: ... + ) -> Optional[ScoringFnDefWithProvider]: ... @webmethod(route="/scoring_functions/register", method="POST") async def register_scoring_function( - self, function_def: ScoringFunctionDefWithProvider + self, function_def: ScoringFnDefWithProvider ) -> None: ... diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 318809baf..9ad82cd79 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -34,7 +34,7 @@ RoutableObject = Union[ ShieldDef, MemoryBankDef, DatasetDef, - ScoringFunctionDef, + ScoringFnDef, ] RoutableObjectWithProvider = Union[ @@ -42,7 +42,7 @@ RoutableObjectWithProvider = Union[ ShieldDefWithProvider, MemoryBankDefWithProvider, DatasetDefWithProvider, - ScoringFunctionDefWithProvider, + ScoringFnDefWithProvider, ] RoutedProtocol = Union[ diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index dcd588a9e..3e07b9162 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -100,7 +100,7 @@ class CommonRoutingTableImpl(RoutingTable): scoring_functions = await p.list_scoring_functions() add_objects( [ - ScoringFunctionDefWithProvider(**s.dict(), provider_id=pid) + ScoringFnDefWithProvider(**s.dict(), provider_id=pid) for s in scoring_functions ] ) @@ -239,7 +239,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring): - async def list_scoring_functions(self) -> List[ScoringFunctionDefWithProvider]: + async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]: objects = [] for objs in self.registry.values(): objects.extend(objs) @@ -247,10 +247,10 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring): async def get_scoring_function( self, name: str - ) -> Optional[ScoringFunctionDefWithProvider]: + ) -> Optional[ScoringFnDefWithProvider]: return self.get_object_by_identifier(name) async def register_scoring_function( - self, function_def: ScoringFunctionDefWithProvider + self, function_def: ScoringFnDefWithProvider ) -> None: await self.register_object(function_def) diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 8d476a509..eace0ea1a 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -13,7 +13,7 @@ from pydantic import BaseModel, Field from llama_stack.apis.datasets import DatasetDef from llama_stack.apis.memory_banks import MemoryBankDef from llama_stack.apis.models import ModelDef -from llama_stack.apis.scoring_functions import ScoringFunctionDef +from llama_stack.apis.scoring_functions import ScoringFnDef from llama_stack.apis.shields import ShieldDef @@ -64,11 +64,9 @@ class DatasetsProtocolPrivate(Protocol): class ScoringFunctionsProtocolPrivate(Protocol): - async def list_scoring_functions(self) -> List[ScoringFunctionDef]: ... + async def list_scoring_functions(self) -> List[ScoringFnDef]: ... - async def register_scoring_function( - self, function_def: ScoringFunctionDef - ) -> None: ... + async def register_scoring_function(self, function_def: ScoringFnDef) -> None: ... @json_schema_type diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring.py b/llama_stack/providers/impls/meta_reference/scoring/scoring.py index 05ace33b4..b1d561533 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring.py @@ -13,22 +13,22 @@ from llama_stack.apis.datasetio import * # noqa: F403 from llama_stack.apis.datasets import * # noqa: F403 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate -from llama_stack.providers.impls.meta_reference.scoring.scorer.equality_scorer import ( - EqualityScorer, +from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.equality_scoring_fn import ( + EqualityScoringFn, ) -from llama_stack.providers.impls.meta_reference.scoring.scorer.subset_of_scorer import ( - SubsetOfScorer, +from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import ( + SubsetOfScoringFn, ) from .config import MetaReferenceScoringConfig -SUPPORTED_SCORERS = [ - EqualityScorer, - SubsetOfScorer, +SUPPORTED_SCORING_FNS = [ + EqualityScoringFn, + SubsetOfScoringFn, ] -SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORERS} +SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORING_FNS} class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): @@ -46,10 +46,10 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): async def shutdown(self) -> None: ... - async def list_scoring_functions(self) -> List[ScoringFunctionDef]: - return [x.scoring_function_def for x in SUPPORTED_SCORERS] + async def list_scoring_functions(self) -> List[ScoringFnDef]: + return [x.scoring_function_def for x in SUPPORTED_SCORING_FNS] - async def register_scoring_function(self, function_def: ScoringFunctionDef) -> None: + async def register_scoring_function(self, function_def: ScoringFnDef) -> None: raise NotImplementedError( "Dynamically registering scoring functions is not supported" ) @@ -101,9 +101,9 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): for scoring_fn_id in scoring_functions: if scoring_fn_id not in SCORER_REGISTRY: raise ValueError(f"Scoring function {scoring_fn_id} is not supported.") - scorer = SCORER_REGISTRY[scoring_fn_id]() - score_results = scorer.score(input_rows) - agg_results = scorer.aggregate(score_results) + scoring_fn = SCORER_REGISTRY[scoring_fn_id]() + score_results = scoring_fn.score(input_rows) + agg_results = scoring_fn.aggregate(score_results) res[scoring_fn_id] = ScoringResult( score_rows=score_results, aggregated_results=agg_results, diff --git a/llama_stack/providers/impls/meta_reference/scoring/scorer/__init__.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/__init__.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/scoring/scorer/__init__.py rename to llama_stack/providers/impls/meta_reference/scoring/scoring_fn/__init__.py diff --git a/llama_stack/providers/impls/meta_reference/scoring/scorer/base_scorer.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/base_scoring_fn.py similarity index 81% rename from llama_stack/providers/impls/meta_reference/scoring/scorer/base_scorer.py rename to llama_stack/providers/impls/meta_reference/scoring/scoring_fn/base_scoring_fn.py index ea8a3f063..952d46bb2 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scorer/base_scorer.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/base_scoring_fn.py @@ -9,15 +9,15 @@ from llama_stack.apis.scoring_functions import * # noqa: F401, F403 from llama_stack.apis.scoring import * # noqa: F401, F403 -class BaseScorer(ABC): +class BaseScoringFn(ABC): """ - Base interface class for all meta-reference scorers. - Each scorer needs to implement the following methods: + Base interface class for all meta-reference scoring_fns. + Each scoring_fn needs to implement the following methods: - score_row(self, row) - - aggregate(self, scorer_results) + - aggregate(self, scoring_fn_results) """ - scoring_function_def: ScoringFunctionDef + scoring_function_def: ScoringFnDef def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) diff --git a/llama_stack/providers/impls/meta_reference/scoring/scorer/common.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/common.py similarity index 100% rename from llama_stack/providers/impls/meta_reference/scoring/scorer/common.py rename to llama_stack/providers/impls/meta_reference/scoring/scoring_fn/common.py diff --git a/llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py similarity index 76% rename from llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py rename to llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py index 0c7751f35..cce0f948a 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py @@ -4,23 +4,23 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.providers.impls.meta_reference.scoring.scorer.base_scorer import ( - BaseScorer, +from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import ( + BaseScoringFn, ) from llama_stack.apis.scoring_functions import * # noqa: F401, F403 from llama_stack.apis.scoring import * # noqa: F401, F403 from llama_stack.apis.common.type_system import * # noqa: F403 -from llama_stack.providers.impls.meta_reference.scoring.scorer.common import ( +from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import ( aggregate_accuracy, ) -class EqualityScorer(BaseScorer): +class EqualityScoringFn(BaseScoringFn): """ - A scorer that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise. + A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise. """ - scoring_function_def = ScoringFunctionDef( + scoring_function_def = ScoringFnDef( identifier="equality", description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.", parameters=[], diff --git a/llama_stack/providers/impls/meta_reference/scoring/scorer/subset_of_scorer.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py similarity index 76% rename from llama_stack/providers/impls/meta_reference/scoring/scorer/subset_of_scorer.py rename to llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py index e72b5ed0f..c7ee68e26 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scorer/subset_of_scorer.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py @@ -4,23 +4,23 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.providers.impls.meta_reference.scoring.scorer.base_scorer import ( - BaseScorer, +from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import ( + BaseScoringFn, ) from llama_stack.apis.scoring_functions import * # noqa: F401, F403 from llama_stack.apis.scoring import * # noqa: F401, F403 from llama_stack.apis.common.type_system import * # noqa: F403 -from llama_stack.providers.impls.meta_reference.scoring.scorer.common import ( +from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import ( aggregate_accuracy, ) -class SubsetOfScorer(BaseScorer): +class SubsetOfScoringFn(BaseScoringFn): """ - A scorer that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise. + A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise. """ - scoring_function_def = ScoringFunctionDef( + scoring_function_def = ScoringFnDef( identifier="subset_of", description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.", parameters=[],