registry refactor

Xi Yan 2024-10-14 16:09:55 -07:00
parent c50686b6fe
commit 95fd53d292
8 changed files with 39 additions and 71 deletions


@@ -29,8 +29,7 @@ class GenerationOutput(BaseModel):
 @json_schema_type
 class PostprocessedGeneration(BaseModel):
     completion_message: str
-    # structured transformed output from raw_completion_message to compute scorer metrics
-    transformed_generation: Optional[Any] = None
+    logprobs: Optional[List[TokenLogProbs]] = None


 # A sample (row) from dataset
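For context on the change: the scorer-specific transformed_generation field is dropped, so PostprocessedGeneration now carries only raw generation data plus optional logprobs. A minimal sketch of the refactored model; TokenLogProbs is stubbed here since its real definition lives outside this diff:

from typing import Dict, List, Optional

from pydantic import BaseModel


class TokenLogProbs(BaseModel):  # stub for illustration; the real model is defined elsewhere
    logprobs_by_token: Dict[str, float] = {}


class PostprocessedGeneration(BaseModel):
    completion_message: str
    logprobs: Optional[List[TokenLogProbs]] = None


gen = PostprocessedGeneration(
    completion_message="The answer is (B).",
    logprobs=[TokenLogProbs(logprobs_by_token={"B": -0.02})],
)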
@@ -70,8 +69,15 @@ TScorerInputSample = TypeVar("TScorerInputSample")
 @json_schema_type
 class ScorerInputSample(DatasetSample):
-    generation_output: PostprocessedGeneration
-    expected_output: Union[str, List[str]]
+    """
+    A dataset is required to have the following columns to be used for scoring:
+    - generated_answer: str
+    - expected_answer: Union[str, List[str]]
+    """
+
+    generated_answer: str
+    expected_answer: Union[str, List[str]]
+    generation_output: Optional[PostprocessedGeneration] = None


 @json_schema_type
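To illustrate the new scoring contract, a self-contained sketch of building a sample; DatasetSample and PostprocessedGeneration are reduced to stubs because their full definitions are outside this diff:

from typing import List, Optional, Union

from pydantic import BaseModel


class DatasetSample(BaseModel):  # reduced stub for illustration
    pass


class PostprocessedGeneration(BaseModel):  # reduced stub for illustration
    completion_message: str


class ScorerInputSample(DatasetSample):
    generated_answer: str
    expected_answer: Union[str, List[str]]
    generation_output: Optional[PostprocessedGeneration] = None


# generation_output is now optional: scorers only need the two answer columns
sample = ScorerInputSample(generated_answer="B", expected_answer="B")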


@@ -217,18 +217,8 @@ class BaseScorer(ABC, Generic[TScorerInputSample]):
 class BaseTask(ABC):
-    def __init__(
-        self,
-        generator_processor: Optional[BaseGeneratorProcessor] = None,
-        generator: Optional[BaseGenerator] = None,
-        scorer: Optional[BaseScorer] = None,
-        *args,
-        **kwargs
-    ) -> None:
+    def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
-        self.generator_processor = generator_processor
-        self.generator = generator
-        self.scorer = scorer

     @abstractmethod
     async def run(self, *args, **kwargs) -> EvalResult:
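Because BaseTask no longer accepts or stores a processor, generator, and scorer, subclasses now own their component wiring. A hedged sketch of what a subclass might look like after the change; EvalResult is stubbed and MyTask is hypothetical:

from abc import ABC, abstractmethod


class EvalResult:  # stub for illustration; the real model lives in the evals API
    def __init__(self, metrics=None):
        self.metrics = metrics or {}


class BaseTask(ABC):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    @abstractmethod
    async def run(self, *args, **kwargs) -> EvalResult:
        ...


class MyTask(BaseTask):  # hypothetical subclass
    async def run(self, *args, **kwargs) -> EvalResult:
        # assemble whatever processor/generator/scorer the task needs here
        return EvalResult(metrics={"score": "1.0"})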


@@ -7,5 +7,4 @@ from llama_stack.apis.datasets import * # noqa: F403
 from ..registry import Registry

-class DatasetRegistry(Registry[BaseDataset]):
-    _REGISTRY: Dict[str, BaseDataset] = {}
+DatasetRegistry = Registry[BaseDataset]()
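With DatasetRegistry now a module-level Registry instance rather than a subclass carrying class-level state, call sites use plain instance methods. A usage sketch; BaseDataset is stubbed and the "mmlu" key is only an example:

class BaseDataset:  # stub for illustration
    pass


dataset = BaseDataset()
DatasetRegistry.register("mmlu", dataset)  # assumes DatasetRegistry from this module
assert "mmlu" in DatasetRegistry.names()
assert DatasetRegistry.get("mmlu") is dataset
DatasetRegistry.delete("mmlu")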


@@ -3,36 +3,34 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import AbstractSet, Dict, Generic, TypeVar
+from typing import AbstractSet, Generic, TypeVar

 TRegistry = TypeVar("TRegistry")


 class Registry(Generic[TRegistry]):
-    _REGISTRY: Dict[str, TRegistry] = {}
+    def __init__(self) -> None:
+        super().__init__()
+        self.registry = {}

-    @staticmethod
-    def names() -> AbstractSet[str]:
-        return Registry._REGISTRY.keys()
+    def names(self) -> AbstractSet[str]:
+        return self.registry.keys()

-    @staticmethod
-    def register(name: str, task: TRegistry) -> None:
-        if name in Registry._REGISTRY:
+    def register(self, name: str, task: TRegistry) -> None:
+        if name in self.registry:
             raise ValueError(f"Dataset {name} already exists.")
-        Registry._REGISTRY[name] = task
+        self.registry[name] = task

-    @staticmethod
-    def get(name: str) -> TRegistry:
-        if name not in Registry._REGISTRY:
+    def get(self, name: str) -> TRegistry:
+        if name not in self.registry:
             raise ValueError(f"Dataset {name} not found.")
-        return Registry._REGISTRY[name]
+        return self.registry[name]

-    @staticmethod
-    def delete(name: str) -> None:
-        if name not in Registry._REGISTRY:
+    def delete(self, name: str) -> None:
+        if name not in self.registry:
             raise ValueError(f"Dataset {name} not found.")
-        del Registry._REGISTRY[name]
+        del self.registry[name]

-    @staticmethod
-    def reset() -> None:
-        Registry._REGISTRY = {}
+    def reset(self) -> None:
+        self.registry = {}
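Worth noting why this matters: the old @staticmethod implementation always read and wrote Registry._REGISTRY, so the _REGISTRY dicts redeclared on DatasetRegistry and ScorerRegistry were ignored and every registry shared one mutable class-level store. Per-instance state fixes that. A self-contained sketch of the isolation the refactor buys:

from typing import Dict, Generic, TypeVar

T = TypeVar("T")


class Registry(Generic[T]):
    def __init__(self) -> None:
        self.registry: Dict[str, T] = {}

    def register(self, name: str, item: T) -> None:
        if name in self.registry:
            raise ValueError(f"{name} already exists.")
        self.registry[name] = item


datasets = Registry[str]()
scorers = Registry[str]()
datasets.register("mmlu", "dataset-placeholder")
# Per-instance state: registering into one registry no longer leaks into the other.
assert "mmlu" not in scorers.registry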


@@ -9,10 +9,7 @@ from llama_stack.providers.impls.meta_reference.evals.scorer.basic_scorers import
 from ..registry import Registry

-class ScorerRegistry(Registry[BaseScorer]):
-    _REGISTRY: Dict[str, BaseScorer] = {}
+ScorerRegistry = Registry[BaseScorer]()

 SCORER_REGISTRY = {
     "accuracy": AccuracyScorer,


@@ -71,6 +71,10 @@ class MetaReferenceEvalsImpl(Evals):
         dataset_config: EvaluateDatasetConfig,
         eval_scoring_config: EvaluateScoringConfig,
     ) -> EvaluateResponse:
+        cprint("run_scorer")
+
+        # main logic: we need to convert the dataset into List[ScorerInputSample]
+
         return EvaluateResponse(
             eval_result={},
         )
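The handler is still a stub; the comment names the missing step. One hedged reading of what that conversion could look like, assuming ScorerInputSample is imported from the datasets API; the row shape and column names below are assumptions, not taken from the diff:

from typing import List


def to_scorer_input_samples(rows: List[dict]) -> List["ScorerInputSample"]:
    # Assumed row shape: each row already carries the two scoring columns.
    return [
        ScorerInputSample(
            generated_answer=row["generated_answer"],
            expected_answer=row["expected_answer"],
        )
        for row in rows
    ]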


@@ -153,35 +153,9 @@ class MMLUProcessor(
                 break

         return ScorerInputSample(
+            generated_answer=extracted_answer,
+            expected_answer=dataset_sample.data["Answer"],
             generation_output=PostprocessedGeneration(
                 completion_message=response_text,
-                transformed_generation=extracted_answer,
             ),
-            expected_output=dataset_sample.data["Answer"],
         )
-
-    # def score_sample(self, sample: ProcessedDictSample) -> SingleEvalResult:
-    #     postprocessed_output = sample.postprocessed["postprocessed"]
-    #     expected_answer = sample.data["Answer"]
-    #     extracted_answer = None
-    #     for answer_regex in MULTILINGUAL_ANSWER_REGEXES:
-    #         regex = MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(answer_regex)
-    #         match = re.search(regex, postprocessed_output)
-    #         if match:
-    #             extracted_answer = normalize_extracted_answer(match.group(1))
-    #             break
-    #     score = 1.0 if extracted_answer and extracted_answer == expected_answer else 0.0
-    #     return SingleEvalResult(
-    #         score_data={
-    #             "score": score,
-    #         },
-    #     )
-
-    # def aggregate_results(self, eval_results: List[SingleEvalResult]) -> EvalResult:
-    #     print("aggregate_results", eval_results)
-    #     sum_score = sum([result.score_data["score"] for result in eval_results])
-    #     return EvalResult(metrics={"score": str(sum_score / len(eval_results))})
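The deleted comment block still documents the extraction approach the processor relies on: try each multilingual answer regex against the model output and normalize the first match. A condensed, self-contained sketch of that pattern, with one simplified regex standing in for MULTILINGUAL_ANSWER_REGEXES:

import re
from typing import Optional

# Simplified stand-in for MULTILINGUAL_ANSWER_PATTERN_TEMPLATE/REGEXES.
ANSWER_PATTERN = r"(?i)answer\s*(?:is)?\s*[:\-]?\s*\(?([A-D])\)?"


def extract_answer(response_text: str) -> Optional[str]:
    match = re.search(ANSWER_PATTERN, response_text)
    return match.group(1).upper() if match else None


assert extract_answer("The answer is (B).") == "B"
assert extract_answer("no choice given") is None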


@@ -28,8 +28,8 @@ class RandomScorer(BaseScorer[ScorerInputSample]):
 class AccuracyScorer(BaseScorer[ScorerInputSample]):
     def score_sample(self, scorer_input_sample: ScorerInputSample) -> SingleEvalResult:
-        extracted_answer = scorer_input_sample.generation_output.transformed_generation
-        expected_answer = scorer_input_sample.expected_output
+        extracted_answer = scorer_input_sample.generated_answer
+        expected_answer = scorer_input_sample.expected_answer
         accuracy = (
             1.0 if extracted_answer and extracted_answer == expected_answer else 0.0
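End to end, the scorer is now a plain comparison of the two renamed fields, and the mean-accuracy aggregation removed from MMLUProcessor above suggests how per-sample scores combine. A self-contained sketch stripped of the model wrappers:

from typing import List


def score_sample(generated_answer: str, expected_answer: str) -> float:
    # Mirrors the field comparison in AccuracyScorer, without the pydantic models.
    return 1.0 if generated_answer and generated_answer == expected_answer else 0.0


def aggregate(scores: List[float]) -> float:
    # Mean accuracy, as in the aggregate_results sketch deleted earlier.
    return sum(scores) / len(scores)


scores = [score_sample("B", "B"), score_sample("A", "B")]
assert aggregate(scores) == 0.5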