This commit is contained in:
Xi Yan 2024-11-11 23:02:56 -05:00
parent 87dc116782
commit fd424e7900
6 changed files with 8 additions and 68 deletions

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .base_scoring_fn import BaseScoringFn
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import re
from .base_scoring_fn import BaseScoringFn
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .base_scoring_fn import BaseScoringFn
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403

View file

@ -1,61 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
class BaseScoringFn(ABC):
"""
Base interface class for all meta-reference scoring_fns.
Each scoring_fn needs to implement the following methods:
- score_row(self, row)
- aggregate(self, scoring_fn_results)
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {}
def __str__(self) -> str:
return self.__class__.__name__
def get_supported_scoring_fn_defs(self) -> List[ScoringFn]:
return [x for x in self.supported_fn_defs_registry.values()]
def register_scoring_fn_def(self, scoring_fn: ScoringFn) -> None:
if scoring_fn.identifier in self.supported_fn_defs_registry:
raise ValueError(
f"Scoring function def with identifier {scoring_fn.identifier} already exists."
)
self.supported_fn_defs_registry[scoring_fn.identifier] = scoring_fn
@abstractmethod
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
raise NotImplementedError()
@abstractmethod
async def aggregate(
self, scoring_results: List[ScoringResultRow]
) -> Dict[str, Any]:
raise NotImplementedError()
async def score(
self,
input_rows: List[Dict[str, Any]],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> List[ScoringResultRow]:
return [
await self.score_row(input_row, scoring_fn_identifier, scoring_params)
for input_row in input_rows
]

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from llama_stack.apis.inference.inference import Inference
from .base_scoring_fn import BaseScoringFn
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403

View file

@ -4,9 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from typing import Any, Dict, List, Optional
from llama_stack.apis.scoring import ScoringFnParams, ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFn
class BaseScoringFn(ABC):