mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 17:29:01 +00:00
refactor base scoring fn v.s. registerable scoring fn
This commit is contained in:
parent
03607a68c7
commit
0096c1a6fc
3 changed files with 45 additions and 6 deletions
|
@ -9,12 +9,12 @@ from typing import Any, Dict, Optional
|
||||||
from llama_stack.apis.scoring import ScoringResultRow
|
from llama_stack.apis.scoring import ScoringResultRow
|
||||||
|
|
||||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||||
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
|
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||||
|
|
||||||
from .fn_defs.equality import equality
|
from .fn_defs.equality import equality
|
||||||
|
|
||||||
|
|
||||||
class EqualityScoringFn(BaseScoringFn):
|
class EqualityScoringFn(RegisteredBaseScoringFn):
|
||||||
"""
|
"""
|
||||||
A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
|
A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -12,14 +12,14 @@ from llama_stack.apis.inference.inference import Inference
|
||||||
from llama_stack.apis.scoring import ScoringResultRow
|
from llama_stack.apis.scoring import ScoringResultRow
|
||||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||||
|
|
||||||
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
|
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||||
|
|
||||||
from .fn_defs.llm_as_judge_405b_simpleqa import llm_as_judge_405b_simpleqa
|
from .fn_defs.llm_as_judge_405b_simpleqa import llm_as_judge_405b_simpleqa
|
||||||
|
|
||||||
from .fn_defs.llm_as_judge_base import llm_as_judge_base
|
from .fn_defs.llm_as_judge_base import llm_as_judge_base
|
||||||
|
|
||||||
|
|
||||||
class LlmAsJudgeScoringFn(BaseScoringFn):
|
class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
|
||||||
"""
|
"""
|
||||||
A scoring_fn that assigns
|
A scoring_fn that assigns
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -13,12 +13,51 @@ from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metr
|
||||||
|
|
||||||
class BaseScoringFn(ABC):
|
class BaseScoringFn(ABC):
|
||||||
"""
|
"""
|
||||||
Base interface class for all native scoring_fns.
|
Base interface class for Scoring Functions.
|
||||||
Each scoring_fn needs to implement the following methods:
|
Each scoring function needs to implement the following methods:
|
||||||
- score_row(self, row)
|
- score_row(self, row)
|
||||||
- aggregate(self, scoring_fn_results)
|
- aggregate(self, scoring_fn_results)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return self.__class__.__name__
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def score_row(
|
||||||
|
self,
|
||||||
|
input_row: Dict[str, Any],
|
||||||
|
scoring_fn_identifier: Optional[str] = None,
|
||||||
|
scoring_params: Optional[ScoringFnParams] = None,
|
||||||
|
) -> ScoringResultRow:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def aggregate(
|
||||||
|
self,
|
||||||
|
scoring_results: List[ScoringResultRow],
|
||||||
|
scoring_fn_identifier: Optional[str] = None,
|
||||||
|
scoring_params: Optional[ScoringFnParams] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def score(
|
||||||
|
self,
|
||||||
|
input_rows: List[Dict[str, Any]],
|
||||||
|
scoring_fn_identifier: Optional[str] = None,
|
||||||
|
scoring_params: Optional[ScoringFnParams] = None,
|
||||||
|
) -> List[ScoringResultRow]:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class RegisteredBaseScoringFn(BaseScoringFn):
|
||||||
|
"""
|
||||||
|
Interface for native scoring functions that are registered in LlamaStack.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs) -> None:
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.supported_fn_defs_registry = {}
|
self.supported_fn_defs_registry = {}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue