mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00

registry refactor

This commit is contained in:
  parent c50686b6fe
  commit 95fd53d292

8 changed files with 39 additions and 71 deletions
@@ -29,8 +29,7 @@ class GenerationOutput(BaseModel):
 @json_schema_type
 class PostprocessedGeneration(BaseModel):
     completion_message: str
-    # structured transformed output from raw_completion_message to compute scorer metrics
-    transformed_generation: Optional[Any] = None
+    logprobs: Optional[List[TokenLogProbs]] = None


 # A sample (row) from dataset
@@ -70,8 +69,15 @@ TScorerInputSample = TypeVar("TScorerInputSample")

 @json_schema_type
 class ScorerInputSample(DatasetSample):
-    generation_output: PostprocessedGeneration
-    expected_output: Union[str, List[str]]
+    """
+    A dataset is required to have the following columns to be used for scoring:
+    - generated_answer: str
+    - expected_answer: Union[str, List[str]]
+    """
+
+    generated_answer: str
+    expected_answer: Union[str, List[str]]
+    generation_output: Optional[PostprocessedGeneration] = None


 @json_schema_type
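A minimal sketch of how a row might be populated under the new ScorerInputSample schema; the field values, and the assumption that the DatasetSample base class needs no extra fields here, are illustrative rather than taken from this commit:

sample = ScorerInputSample(
    generated_answer="B",                        # structured answer consumed by scorers
    expected_answer=["B", "(B)"],                # a single string is also accepted
    generation_output=PostprocessedGeneration(   # optional after this change
        completion_message="The answer is (B).",
    ),
)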
@@ -217,18 +217,8 @@ class BaseScorer(ABC, Generic[TScorerInputSample]):


 class BaseTask(ABC):
-    def __init__(
-        self,
-        generator_processor: Optional[BaseGeneratorProcessor] = None,
-        generator: Optional[BaseGenerator] = None,
-        scorer: Optional[BaseScorer] = None,
-        *args,
-        **kwargs
-    ) -> None:
+    def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
-        self.generator_processor = generator_processor
-        self.generator = generator
-        self.scorer = scorer

     @abstractmethod
     async def run(self, *args, **kwargs) -> EvalResult:
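With the slimmed-down constructor, a task subclass would set up its own components; the subclass name and the choice of scorer below are hypothetical, not part of this commit:

class MyEvalTask(BaseTask):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # assumption: tasks now construct their own generator/processor/scorer,
        # and AccuracyScorer takes no constructor arguments
        self.scorer = AccuracyScorer()

    async def run(self, *args, **kwargs) -> EvalResult:
        ...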
@@ -7,5 +7,4 @@ from llama_stack.apis.datasets import *  # noqa: F403
 from ..registry import Registry


-class DatasetRegistry(Registry[BaseDataset]):
-    _REGISTRY: Dict[str, BaseDataset] = {}
+DatasetRegistry = Registry[BaseDataset]()
@@ -3,36 +3,34 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import AbstractSet, Dict, Generic, TypeVar
+from typing import AbstractSet, Generic, TypeVar

 TRegistry = TypeVar("TRegistry")


 class Registry(Generic[TRegistry]):
-    _REGISTRY: Dict[str, TRegistry] = {}
-
-    @staticmethod
-    def names() -> AbstractSet[str]:
-        return Registry._REGISTRY.keys()
-
-    @staticmethod
-    def register(name: str, task: TRegistry) -> None:
-        if name in Registry._REGISTRY:
-            raise ValueError(f"Dataset {name} already exists.")
-        Registry._REGISTRY[name] = task
-
-    @staticmethod
-    def get(name: str) -> TRegistry:
-        if name not in Registry._REGISTRY:
-            raise ValueError(f"Dataset {name} not found.")
-        return Registry._REGISTRY[name]
-
-    @staticmethod
-    def delete(name: str) -> None:
-        if name not in Registry._REGISTRY:
-            raise ValueError(f"Dataset {name} not found.")
-        del Registry._REGISTRY[name]
-
-    @staticmethod
-    def reset() -> None:
-        Registry._REGISTRY = {}
+    def __init__(self) -> None:
+        super().__init__()
+        self.registry = {}
+
+    def names(self) -> AbstractSet[str]:
+        return self.registry.keys()
+
+    def register(self, name: str, task: TRegistry) -> None:
+        if name in self.registry:
+            raise ValueError(f"Dataset {name} already exists.")
+        self.registry[name] = task
+
+    def get(self, name: str) -> TRegistry:
+        if name not in self.registry:
+            raise ValueError(f"Dataset {name} not found.")
+        return self.registry[name]
+
+    def delete(self, name: str) -> None:
+        if name not in self.registry:
+            raise ValueError(f"Dataset {name} not found.")
+        del self.registry[name]
+
+    def reset(self) -> None:
+        self.registry = {}
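A short usage sketch of the refactored registry: each instantiation now owns its own dict, so separate registries no longer share the old class-level _REGISTRY state. The registered objects below are placeholders:

dataset_registry = Registry[BaseDataset]()
scorer_registry = Registry[BaseScorer]()

dataset_registry.register("mmlu", my_dataset)        # my_dataset: a placeholder BaseDataset
scorer_registry.register("accuracy", AccuracyScorer)

assert "mmlu" in dataset_registry.names()
assert "mmlu" not in scorer_registry.names()         # state is per instance, not shared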
@@ -9,10 +9,7 @@ from llama_stack.providers.impls.meta_reference.evals.scorer.basic_scorers impor

 from ..registry import Registry

-
-class ScorerRegistry(Registry[BaseScorer]):
-    _REGISTRY: Dict[str, BaseScorer] = {}
-
+ScorerRegistry = Registry[BaseScorer]()

 SCORER_REGISTRY = {
     "accuracy": AccuracyScorer,
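The hunk does not show how the new module-level ScorerRegistry instance gets populated; one plausible wiring from the existing SCORER_REGISTRY mapping, purely as an assumption, would be:

for name, scorer_cls in SCORER_REGISTRY.items():
    ScorerRegistry.register(name, scorer_cls)

accuracy_cls = ScorerRegistry.get("accuracy")   # -> AccuracyScorer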
@@ -71,6 +71,10 @@ class MetaReferenceEvalsImpl(Evals):
         dataset_config: EvaluateDatasetConfig,
         eval_scoring_config: EvaluateScoringConfig,
     ) -> EvaluateResponse:
+        cprint("run_scorer")
+
+        # main logic, we need to convert the datset into List[ScorerInputSample]
+
         return EvaluateResponse(
             eval_result={},
         )
@@ -153,35 +153,9 @@ class MMLUProcessor(
             break

         return ScorerInputSample(
+            generated_answer=extracted_answer,
+            expected_answer=dataset_sample.data["Answer"],
             generation_output=PostprocessedGeneration(
                 completion_message=response_text,
-                transformed_generation=extracted_answer,
             ),
-            expected_output=dataset_sample.data["Answer"],
         )
-
-    # def score_sample(self, sample: ProcessedDictSample) -> SingleEvalResult:
-    #     postprocessed_output = sample.postprocessed["postprocessed"]
-    #     expected_answer = sample.data["Answer"]
-
-    #     extracted_answer = None
-    #     for answer_regex in MULTILINGUAL_ANSWER_REGEXES:
-    #         regex = MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(answer_regex)
-    #         match = re.search(regex, postprocessed_output)
-    #         if match:
-    #             extracted_answer = normalize_extracted_answer(match.group(1))
-    #             break
-
-    #     score = 1.0 if extracted_answer and extracted_answer == expected_answer else 0.0
-
-    #     return SingleEvalResult(
-    #         score_data={
-    #             "score": score,
-    #         },
-    #     )
-
-    # def aggregate_results(self, eval_results: List[SingleEvalResult]) -> EvalResult:
-    #     print("aggregate_results", eval_results)
-    #     sum_score = sum([result.score_data["score"] for result in eval_results])
-
-    #     return EvalResult(metrics={"score": str(sum_score / len(eval_results))})
@@ -28,8 +28,8 @@ class RandomScorer(BaseScorer[ScorerInputSample]):


 class AccuracyScorer(BaseScorer[ScorerInputSample]):
     def score_sample(self, scorer_input_sample: ScorerInputSample) -> SingleEvalResult:
-        extracted_answer = scorer_input_sample.generation_output.transformed_generation
-        expected_answer = scorer_input_sample.expected_output
+        extracted_answer = scorer_input_sample.generated_answer
+        expected_answer = scorer_input_sample.expected_answer

         accuracy = (
             1.0 if extracted_answer and extracted_answer == expected_answer else 0.0
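With the renamed fields, scoring a single sample might look like the following; the values are made up and AccuracyScorer is assumed to take no constructor arguments:

sample = ScorerInputSample(generated_answer="B", expected_answer="B")
result = AccuracyScorer().score_sample(sample)
# generated_answer matches expected_answer, so the accuracy for this sample is 1.0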