Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-28 15:02:37 +00:00

registry refactor

commit 95fd53d292 (parent c50686b6fe)
8 changed files with 39 additions and 71 deletions

@@ -29,8 +29,7 @@ class GenerationOutput(BaseModel):
 @json_schema_type
 class PostprocessedGeneration(BaseModel):
     completion_message: str
-    # structured transformed output from raw_completion_message to compute scorer metrics
     transformed_generation: Optional[Any] = None
     logprobs: Optional[List[TokenLogProbs]] = None


 # A sample (row) from dataset

@@ -70,8 +69,15 @@ TScorerInputSample = TypeVar("TScorerInputSample")
 @json_schema_type
 class ScorerInputSample(DatasetSample):
-    generation_output: PostprocessedGeneration
-    expected_output: Union[str, List[str]]
+    """
+    A dataset is required to have the following columns to be used for scoring:
+    - generated_answer: str
+    - expected_answer: Union[str, List[str]]
+    """
+
+    generated_answer: str
+    expected_answer: Union[str, List[str]]
+    generation_output: Optional[PostprocessedGeneration] = None


 @json_schema_type
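
For orientation, a minimal hedged sketch of how a post-refactor sample could be put together. It assumes pydantic models, collapses DatasetSample to an empty stand-in, and simplifies TokenLogProbs to Any; only the field names come from the hunk above, and the example values are made up.

from typing import Any, List, Optional, Union

from pydantic import BaseModel


class DatasetSample(BaseModel):
    # simplified stand-in for the real DatasetSample base class
    pass


class PostprocessedGeneration(BaseModel):
    completion_message: str
    logprobs: Optional[List[Any]] = None  # TokenLogProbs simplified to Any


class ScorerInputSample(DatasetSample):
    generated_answer: str
    expected_answer: Union[str, List[str]]
    generation_output: Optional[PostprocessedGeneration] = None


# Scoring now reads the flat generated/expected columns; the raw generation is
# carried along only as optional context.
sample = ScorerInputSample(
    generated_answer="B",
    expected_answer="B",
    generation_output=PostprocessedGeneration(completion_message="The answer is (B)."),
)
print(sample.generated_answer == sample.expected_answer)  # True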

@@ -217,18 +217,8 @@ class BaseScorer(ABC, Generic[TScorerInputSample]):


 class BaseTask(ABC):
-    def __init__(
-        self,
-        generator_processor: Optional[BaseGeneratorProcessor] = None,
-        generator: Optional[BaseGenerator] = None,
-        scorer: Optional[BaseScorer] = None,
-        *args,
-        **kwargs
-    ) -> None:
+    def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
-        self.generator_processor = generator_processor
-        self.generator = generator
-        self.scorer = scorer

     @abstractmethod
     async def run(self, *args, **kwargs) -> EvalResult:
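
A hedged sketch of what a concrete task looks like against the slimmed-down constructor. ToyTask and the reduced EvalResult are illustrative stand-ins, not part of the commit; the point is only that the generator, processor and scorer are no longer constructor arguments.

import asyncio
from abc import ABC, abstractmethod
from typing import Any, Dict


class EvalResult:
    # simplified stand-in for the real EvalResult type
    def __init__(self, metrics: Dict[str, Any]) -> None:
        self.metrics = metrics


class BaseTask(ABC):
    # mirrors the refactored signature: no collaborator arguments
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    @abstractmethod
    async def run(self, *args, **kwargs) -> EvalResult:
        raise NotImplementedError()


class ToyTask(BaseTask):
    async def run(self, *args, **kwargs) -> EvalResult:
        return EvalResult(metrics={"accuracy": 1.0})


print(asyncio.run(ToyTask().run()).metrics)  # {'accuracy': 1.0}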

@@ -7,5 +7,4 @@ from llama_stack.apis.datasets import * # noqa: F403
 from ..registry import Registry


-class DatasetRegistry(Registry[BaseDataset]):
-    _REGISTRY: Dict[str, BaseDataset] = {}
+DatasetRegistry = Registry[BaseDataset]()

@@ -3,36 +3,34 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import AbstractSet, Dict, Generic, TypeVar
+from typing import AbstractSet, Generic, TypeVar

 TRegistry = TypeVar("TRegistry")


 class Registry(Generic[TRegistry]):
-    _REGISTRY: Dict[str, TRegistry] = {}
+    def __init__(self) -> None:
+        super().__init__()
+        self.registry = {}

-    @staticmethod
-    def names() -> AbstractSet[str]:
-        return Registry._REGISTRY.keys()
+    def names(self) -> AbstractSet[str]:
+        return self.registry.keys()

-    @staticmethod
-    def register(name: str, task: TRegistry) -> None:
-        if name in Registry._REGISTRY:
+    def register(self, name: str, task: TRegistry) -> None:
+        if name in self.registry:
             raise ValueError(f"Dataset {name} already exists.")
-        Registry._REGISTRY[name] = task
+        self.registry[name] = task

-    @staticmethod
-    def get(name: str) -> TRegistry:
-        if name not in Registry._REGISTRY:
+    def get(self, name: str) -> TRegistry:
+        if name not in self.registry:
             raise ValueError(f"Dataset {name} not found.")
-        return Registry._REGISTRY[name]
+        return self.registry[name]

-    @staticmethod
-    def delete(name: str) -> None:
-        if name not in Registry._REGISTRY:
+    def delete(self, name: str) -> None:
+        if name not in self.registry:
             raise ValueError(f"Dataset {name} not found.")
-        del Registry._REGISTRY[name]
+        del self.registry[name]

-    @staticmethod
-    def reset() -> None:
-        Registry._REGISTRY = {}
+    def reset(self) -> None:
+        self.registry = {}
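
To make the behavioral change concrete: registry state now lives on the instance instead of a class-level _REGISTRY dict, which is what lets DatasetRegistry and ScorerRegistry become independent module-level instances of the same generic class. A self-contained sketch of that difference; the class mirrors the hunk above, while the string payloads and the "mmlu" name are placeholders.

from typing import AbstractSet, Dict, Generic, TypeVar

TRegistry = TypeVar("TRegistry")


class Registry(Generic[TRegistry]):
    # same shape as the refactored class: state is per instance
    def __init__(self) -> None:
        super().__init__()
        self.registry: Dict[str, TRegistry] = {}

    def names(self) -> AbstractSet[str]:
        return self.registry.keys()

    def register(self, name: str, task: TRegistry) -> None:
        if name in self.registry:
            raise ValueError(f"Dataset {name} already exists.")
        self.registry[name] = task

    def get(self, name: str) -> TRegistry:
        if name not in self.registry:
            raise ValueError(f"Dataset {name} not found.")
        return self.registry[name]


# Two singletons, two independent namespaces; previously both would have
# written into the shared class-level dict.
DatasetRegistry = Registry[str]()
ScorerRegistry = Registry[str]()

DatasetRegistry.register("mmlu", "mmlu dataset placeholder")
print("mmlu" in DatasetRegistry.names())  # True
print("mmlu" in ScorerRegistry.names())   # False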

@@ -9,10 +9,7 @@ from llama_stack.providers.impls.meta_reference.evals.scorer.basic_scorers impor
 from ..registry import Registry


-class ScorerRegistry(Registry[BaseScorer]):
-    _REGISTRY: Dict[str, BaseScorer] = {}
-
+ScorerRegistry = Registry[BaseScorer]()

 SCORER_REGISTRY = {
     "accuracy": AccuracyScorer,

@@ -71,6 +71,10 @@ class MetaReferenceEvalsImpl(Evals):
         dataset_config: EvaluateDatasetConfig,
         eval_scoring_config: EvaluateScoringConfig,
     ) -> EvaluateResponse:
+        cprint("run_scorer")
+
+        # main logic, we need to convert the datset into List[ScorerInputSample]
+
         return EvaluateResponse(
             eval_result={},
         )
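
The TODO comment in the new run_scorer stub points at a straightforward conversion step. A hedged sketch of that step, assuming rows are plain dicts that already expose the two columns documented on ScorerInputSample; the dataclass and the rows_to_scorer_inputs helper are illustrative, not part of the commit.

from dataclasses import dataclass
from typing import Dict, List, Union


@dataclass
class ScorerInputSample:
    # flat stand-in carrying just the two documented scoring columns
    generated_answer: str
    expected_answer: Union[str, List[str]]


def rows_to_scorer_inputs(rows: List[Dict[str, str]]) -> List[ScorerInputSample]:
    # hypothetical helper: map dataset rows onto the scoring schema
    return [
        ScorerInputSample(
            generated_answer=row["generated_answer"],
            expected_answer=row["expected_answer"],
        )
        for row in rows
    ]


print(rows_to_scorer_inputs([{"generated_answer": "B", "expected_answer": "B"}]))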

@@ -153,35 +153,9 @@ class MMLUProcessor(
                 break

         return ScorerInputSample(
-            generation_output=PostprocessedGeneration(
-                completion_message=response_text,
-                transformed_generation=extracted_answer,
-            ),
-            expected_output=dataset_sample.data["Answer"],
+            generated_answer=extracted_answer,
+            expected_answer=dataset_sample.data["Answer"],
         )
-
-    # def score_sample(self, sample: ProcessedDictSample) -> SingleEvalResult:
-    #     postprocessed_output = sample.postprocessed["postprocessed"]
-    #     expected_answer = sample.data["Answer"]
-
-    #     extracted_answer = None
-    #     for answer_regex in MULTILINGUAL_ANSWER_REGEXES:
-    #         regex = MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(answer_regex)
-    #         match = re.search(regex, postprocessed_output)
-    #         if match:
-    #             extracted_answer = normalize_extracted_answer(match.group(1))
-    #             break
-
-    #     score = 1.0 if extracted_answer and extracted_answer == expected_answer else 0.0
-
-    #     return SingleEvalResult(
-    #         score_data={
-    #             "score": score,
-    #         },
-    #     )
-
-    # def aggregate_results(self, eval_results: List[SingleEvalResult]) -> EvalResult:
-    #     print("aggregate_results", eval_results)
-    #     sum_score = sum([result.score_data["score"] for result in eval_results])
-
-    #     return EvalResult(metrics={"score": str(sum_score / len(eval_results))})
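
The deleted commented-out block preserved the original extract-then-score flow. A hedged sketch of the extraction half with a single hypothetical answer pattern; the real code uses MULTILINGUAL_ANSWER_REGEXES, MULTILINGUAL_ANSWER_PATTERN_TEMPLATE and normalize_extracted_answer, none of which are shown in this diff.

import re
from typing import Optional

# hypothetical single-pattern stand-in for the multilingual answer regexes
ANSWER_PATTERN = r"(?i)answer\s*(?:is)?\s*[:\(]?\s*([A-D])"


def extract_answer(response_text: str) -> Optional[str]:
    # the processor pulls a structured answer out of the raw completion ...
    match = re.search(ANSWER_PATTERN, response_text)
    return match.group(1).upper() if match else None


# ... and hands it to the scorer as the flat generated_answer column.
print(extract_answer("The answer is (B) because ..."))  # B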

@@ -28,8 +28,8 @@ class RandomScorer(BaseScorer[ScorerInputSample]):

 class AccuracyScorer(BaseScorer[ScorerInputSample]):
     def score_sample(self, scorer_input_sample: ScorerInputSample) -> SingleEvalResult:
-        extracted_answer = scorer_input_sample.generation_output.transformed_generation
-        expected_answer = scorer_input_sample.expected_output
+        extracted_answer = scorer_input_sample.generated_answer
+        expected_answer = scorer_input_sample.expected_answer

         accuracy = (
             1.0 if extracted_answer and extracted_answer == expected_answer else 0.0
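
The accuracy expression itself is unchanged; restated as a standalone, hedged helper (plain floats instead of SingleEvalResult, and the helper name is illustrative):

from typing import List, Optional, Union


def accuracy_score(generated_answer: Optional[str], expected_answer: Union[str, List[str]]) -> float:
    # mirrors the comparison above: an empty or missing answer scores 0.0
    return 1.0 if generated_answer and generated_answer == expected_answer else 0.0


assert accuracy_score("B", "B") == 1.0
assert accuracy_score("C", "B") == 0.0
assert accuracy_score(None, "B") == 0.0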