From d1633dc412468056c49ec784b5f9066a02c67af8 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Thu, 7 Nov 2024 15:20:22 -0800
Subject: [PATCH] huggingface provider

---
 llama_stack/apis/eval/eval.py                 |  8 +++
 llama_stack/distribution/routers/routers.py   |  5 +-
 .../inline/meta_reference/eval/eval.py        | 62 ++++++++++++++-----
 .../scoring_fn/regex_parser_scoring_fn.py     |  6 +-
 llama_stack/providers/tests/eval/test_eval.py | 52 +++++++++++-----
 5 files changed, 99 insertions(+), 34 deletions(-)

diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py
index 6aa4cae34..549c1123e 100644
--- a/llama_stack/apis/eval/eval.py
+++ b/llama_stack/apis/eval/eval.py
@@ -40,6 +40,10 @@ EvalCandidate = Annotated[
 class BenchmarkEvalTaskConfig(BaseModel):
     type: Literal["benchmark"] = "benchmark"
     eval_candidate: EvalCandidate
+    num_examples: Optional[int] = Field(
+        description="Number of examples to evaluate (useful for quick debugging), if not provided, all examples in the dataset will be evaluated",
+        default=None,
+    )
 
 
 @json_schema_type
@@ -50,6 +54,10 @@ class AppEvalTaskConfig(BaseModel):
         description="Map between scoring function id and parameters for each scoring function you want to run",
         default_factory=dict,
     )
+    num_examples: Optional[int] = Field(
+        description="Number of examples to evaluate (useful for quick debugging), if not provided, all examples in the dataset will be evaluated",
+        default=None,
+    )
     # we could optinally add any specific dataset config here
 
 
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 06d50bd65..331f9225e 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -273,7 +273,10 @@ class EvalRouter(Eval):
         benchmark_id: str,
         benchmark_config: BenchmarkEvalTaskConfig,
     ) -> Job:
-        pass
+        return await self.routing_table.get_provider_impl(benchmark_id).run_benchmark(
+            benchmark_id=benchmark_id,
+            benchmark_config=benchmark_config,
+        )
 
     async def run_eval(
         self,
diff --git a/llama_stack/providers/inline/meta_reference/eval/eval.py b/llama_stack/providers/inline/meta_reference/eval/eval.py
index a9a1978e9..e40f329f9 100644
--- a/llama_stack/providers/inline/meta_reference/eval/eval.py
+++ b/llama_stack/providers/inline/meta_reference/eval/eval.py
@@ -16,6 +16,8 @@ from .....apis.eval.eval import (
     JobStatus,
 )
 from llama_stack.apis.common.type_system import *  # noqa: F403
+from tqdm import tqdm
+
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.eval_tasks import EvalTaskDef
@@ -58,22 +60,32 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         # TODO: assume sync job, will need jobs API for async scheduling
         self.jobs = {}
 
-    async def initialize(self) -> None: ...
+        # Keep track of benchmark eval tasks that are supported by this provider
+        self.eval_tasks = {}
+
+    async def initialize(self) -> None:
+        self.eval_tasks = {
+            # NOTE: In order to be routed to this provider, the eval task def must have
+            # a EvalTaskDef with identifier defined as DEFAULT_EVAL_TASK_IDENTIFIER
+            # for app eval where eval task benchmark_id is not pre-registered
+            DEFAULT_EVAL_TASK_IDENTIFIER: EvalTaskDef(
+                identifier=DEFAULT_EVAL_TASK_IDENTIFIER,
+                dataset_id="",
+                scoring_functions=[],
+            ),
+            "meta-reference-mmlu": EvalTaskDef(
+                identifier="meta-reference-mmlu",
+                dataset_id="llamastack_mmlu",
+                scoring_functions=[
+                    "meta-reference::regex_parser_multiple_choice_answer"
+                ],
+            ),
+        }
 
     async def shutdown(self) -> None: ...
 
     async def list_eval_tasks(self) -> List[EvalTaskDef]:
-        # NOTE: In order to be routed to this provider, the eval task def must have
-        # a EvalTaskDef with identifier defined as DEFAULT_EVAL_TASK_IDENTIFIER
-        # for app eval where eval task benchmark_id is not pre-registered
-        eval_tasks = [
-            EvalTaskDef(
-                identifier=DEFAULT_EVAL_TASK_IDENTIFIER,
-                dataset_id="",
-                scoring_functions=[],
-            )
-        ]
-        return eval_tasks
+        return list(self.eval_tasks.values())
 
     async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
         dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
@@ -103,7 +115,25 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         benchmark_id: str,
         benchmark_config: BenchmarkEvalTaskConfig,
     ) -> Job:
-        raise NotImplementedError("Benchmark eval is not implemented yet")
+        eval_task_def = self.eval_tasks[benchmark_id]
+        all_rows = await self.datasetio_api.get_rows_paginated(
+            dataset_id=eval_task_def.dataset_id,
+            rows_in_page=(
+                -1
+                if benchmark_config.num_examples is None
+                else benchmark_config.num_examples
+            ),
+        )
+        res = await self.evaluate_rows(
+            input_rows=all_rows.rows,
+            scoring_functions=eval_task_def.scoring_functions,
+            task_config=benchmark_config,
+        )
+        # TODO: currently needs to wait for generation before returning
+        # need job scheduler queue (celery) w/ jobs api
+        job_id = str(len(self.jobs))
+        self.jobs[job_id] = res
+        return Job(job_id=job_id)
 
     async def run_eval(
         self,
@@ -117,7 +147,9 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
         all_rows = await self.datasetio_api.get_rows_paginated(
             dataset_id=dataset_id,
-            rows_in_page=-1,
+            rows_in_page=(
+                -1 if task_config.num_examples is None else task_config.num_examples
+            ),
         )
         res = await self.evaluate_rows(
             input_rows=all_rows.rows,
@@ -148,7 +180,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         ), "SamplingParams.max_tokens must be provided"
 
         generations = []
-        for x in input_rows:
+        for x in tqdm(input_rows):
             if ColumnName.completion_input.value in x:
                 input_content = eval(str(x[ColumnName.completion_input.value]))
                 response = await self.inference_api.completion(
diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/regex_parser_scoring_fn.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/regex_parser_scoring_fn.py
index 70113cf48..0aff2f535 100644
--- a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/regex_parser_scoring_fn.py
+++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/regex_parser_scoring_fn.py
@@ -31,11 +31,15 @@ class RegexParserScoringFn(BaseScoringFn):
         self,
         input_row: Dict[str, Any],
         scoring_fn_identifier: Optional[str] = None,
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> ScoringResultRow:
         assert (
             scoring_fn_identifier is not None
         ), "Scoring function identifier not found."
         fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
+        if scoring_params is not None:
+            fn_def.params = scoring_params
+
         assert (
             fn_def.params is not None
             and fn_def.params.type == ScoringConfigType.regex_parser.value
@@ -46,7 +50,7 @@ class RegexParserScoringFn(BaseScoringFn):
 
         # parse answer according to regex
         parsed_answer = None
-        for regex in fn_def.params.parsing_regex:
+        for regex in fn_def.params.parsing_regexes:
             match = re.search(regex, generated_answer)
             if match:
                 parsed_answer = match.group(1)
diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py
index 1a41bb44a..fe09832bb 100644
--- a/llama_stack/providers/tests/eval/test_eval.py
+++ b/llama_stack/providers/tests/eval/test_eval.py
@@ -9,7 +9,12 @@ import pytest
 
 from llama_models.llama3.api import SamplingParams
 
-from llama_stack.apis.eval.eval import AppEvalTaskConfig, EvalTaskDef, ModelCandidate
+from llama_stack.apis.eval.eval import (
+    AppEvalTaskConfig,
+    BenchmarkEvalTaskConfig,
+    EvalTaskDef,
+    ModelCandidate,
+)
 from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
 
 
@@ -36,6 +41,12 @@ class Testeval:
         await register_dataset(
             datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval"
         )
+        provider = datasetio_impl.routing_table.get_provider_impl(
+            "test_dataset_for_eval"
+        )
+        if provider.__provider_spec__.provider_type != "meta-reference":
+            pytest.skip("Only meta-reference provider supports registering datasets")
+
         response = await datasets_impl.list_datasets()
         assert len(response) == 1
         rows = await datasetio_impl.get_rows_paginated(
@@ -69,6 +80,11 @@ class Testeval:
         await register_dataset(
             datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval"
         )
+        provider = datasetio_impl.routing_table.get_provider_impl(
+            "test_dataset_for_eval"
+        )
+        if provider.__provider_spec__.provider_type != "meta-reference":
+            pytest.skip("Only meta-reference provider supports registering datasets")
 
         scoring_functions = [
             "meta-reference::llm_as_judge_8b_correctness",
@@ -107,27 +123,29 @@ class Testeval:
     async def test_eval_run_benchmark_eval(self, eval_stack):
         eval_impl, eval_tasks_impl, _, _, datasetio_impl, datasets_impl = eval_stack
         response = await datasets_impl.list_datasets()
-        assert len(response) == 1
+        assert len(response) > 0
+        if response[0].provider_id != "huggingface":
+            pytest.skip(
+                "Only huggingface provider supports pre-registered benchmarks datasets"
+            )
 
-        rows = await datasetio_impl.get_rows_paginated(
-            dataset_id="llamastack_mmlu",
-            rows_in_page=3,
-        )
-        assert len(rows.rows) == 3
+        # list benchmarks
+        response = await eval_tasks_impl.list_eval_tasks()
+        assert len(response) > 0
 
-        scoring_functions = [
-            "meta-reference::regex_parser_multiple_choice_answer",
-        ]
-
-        response = await eval_impl.evaluate_rows(
-            input_rows=rows.rows,
-            scoring_functions=scoring_functions,
-            eval_task_config=AppEvalTaskConfig(
+        benchmark_id = "meta-reference-mmlu"
+        response = await eval_impl.run_benchmark(
+            benchmark_id=benchmark_id,
+            benchmark_config=BenchmarkEvalTaskConfig(
                 eval_candidate=ModelCandidate(
                     model="Llama3.2-3B-Instruct",
                     sampling_params=SamplingParams(),
                 ),
+                num_examples=3,
            ),
         )
-
-        print(response)
+        job_status = await eval_impl.job_status(response.job_id, benchmark_id)
+        assert job_status and job_status.value == "completed"
+        eval_response = await eval_impl.job_result(response.job_id, benchmark_id)
+        assert eval_response is not None
+        assert len(eval_response.generations) == 3
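
Usage sketch (not part of the patch): the benchmark flow introduced above could be driven roughly as follows, assuming an already-initialized eval provider handle (`eval_impl`, as produced by the test fixture) and the pre-registered "meta-reference-mmlu" task; the helper name `run_mmlu_smoke_check` is illustrative, everything else mirrors the test added in this patch.

# Minimal sketch, mirroring test_eval_run_benchmark_eval above; assumes eval_impl
# comes from an initialized provider stack with the huggingface datasetio provider.
from llama_models.llama3.api import SamplingParams
from llama_stack.apis.eval.eval import BenchmarkEvalTaskConfig, ModelCandidate


async def run_mmlu_smoke_check(eval_impl) -> None:
    # Kick off a small benchmark run; num_examples keeps the run cheap for debugging.
    job = await eval_impl.run_benchmark(
        benchmark_id="meta-reference-mmlu",
        benchmark_config=BenchmarkEvalTaskConfig(
            eval_candidate=ModelCandidate(
                model="Llama3.2-3B-Instruct",
                sampling_params=SamplingParams(),
            ),
            num_examples=3,
        ),
    )
    # The meta-reference provider evaluates synchronously, so the job is already complete.
    status = await eval_impl.job_status(job.job_id, "meta-reference-mmlu")
    assert status and status.value == "completed"
    result = await eval_impl.job_result(job.job_id, "meta-reference-mmlu")
    print(len(result.generations))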