From 08c0c5505e68750a48e0fad3149e56d6eadeb3e0 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Tue, 18 Mar 2025 21:09:49 -0700
Subject: [PATCH] feat(eval api): (2.1/n) fix resolver for benchmark routing table + fix precommit (#1691)

# What does this PR do?
- fixes routing table so that `llama stack run` works
- fixes pre-commit
- one of a series of fixes needed to land the new Eval API implementation

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan

```
llama stack run
```

[//]: # (## Documentation)
---
 .../distribution/routers/routing_tables.py    | 35 ++++++++++---------
 .../open-benchmark/open_benchmark.py          | 10 +++---
 llama_stack/templates/open-benchmark/run.yaml | 10 +++---
 pyproject.toml                                |  2 ++
 4 files changed, 31 insertions(+), 26 deletions(-)

diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index 5dea942f7..3e44d2926 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -466,35 +466,38 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
             raise ValueError(f"Benchmark '{benchmark_id}' not found")
         return benchmark
 
+    async def unregister_benchmark(self, benchmark_id: str) -> None:
+        benchmark = await self.get_benchmark(benchmark_id)
+        if benchmark is None:
+            raise ValueError(f"Benchmark {benchmark_id} not found")
+        await self.unregister_object(benchmark)
+
     async def register_benchmark(
         self,
-        benchmark_id: str,
         dataset_id: str,
-        scoring_functions: List[str],
+        grader_ids: List[str],
+        benchmark_id: Optional[str] = None,
         metadata: Optional[Dict[str, Any]] = None,
-        provider_benchmark_id: Optional[str] = None,
-        provider_id: Optional[str] = None,
-    ) -> None:
+    ) -> Benchmark:
         if metadata is None:
             metadata = {}
-        if provider_id is None:
-            if len(self.impls_by_provider_id) == 1:
-                provider_id = list(self.impls_by_provider_id.keys())[0]
-            else:
-                raise ValueError(
-                    "No provider specified and multiple providers available. Please specify a provider_id."
-                )
-        if provider_benchmark_id is None:
-            provider_benchmark_id = benchmark_id
+
+        # TODO (xiyan): we will need a way to infer provider_id for evaluation
+        # keep it as meta-reference for now
+        if len(self.impls_by_provider_id) == 0:
+            raise ValueError("No evaluation providers available. Please configure an evaluation provider.")
+        provider_id = list(self.impls_by_provider_id.keys())[0]
+
         benchmark = Benchmark(
             identifier=benchmark_id,
             dataset_id=dataset_id,
-            scoring_functions=scoring_functions,
+            grader_ids=grader_ids,
             metadata=metadata,
             provider_id=provider_id,
-            provider_resource_id=provider_benchmark_id,
+            provider_resource_id=benchmark_id,
         )
         await self.register_object(benchmark)
+        return benchmark
 
 
 class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
diff --git a/llama_stack/templates/open-benchmark/open_benchmark.py b/llama_stack/templates/open-benchmark/open_benchmark.py
index b339e8c80..03e524dae 100644
--- a/llama_stack/templates/open-benchmark/open_benchmark.py
+++ b/llama_stack/templates/open-benchmark/open_benchmark.py
@@ -214,27 +214,27 @@ def get_distribution_template() -> DistributionTemplate:
         BenchmarkInput(
             benchmark_id="meta-reference-simpleqa",
             dataset_id="simpleqa",
-            scoring_functions=["llm-as-judge::405b-simpleqa"],
+            grader_ids=["llm-as-judge::405b-simpleqa"],
         ),
         BenchmarkInput(
             benchmark_id="meta-reference-mmlu-cot",
             dataset_id="mmlu_cot",
-            scoring_functions=["basic::regex_parser_multiple_choice_answer"],
+            grader_ids=["basic::regex_parser_multiple_choice_answer"],
         ),
         BenchmarkInput(
             benchmark_id="meta-reference-gpqa-cot",
             dataset_id="gpqa_cot",
-            scoring_functions=["basic::regex_parser_multiple_choice_answer"],
+            grader_ids=["basic::regex_parser_multiple_choice_answer"],
         ),
         BenchmarkInput(
             benchmark_id="meta-reference-math-500",
             dataset_id="math_500",
-            scoring_functions=["basic::regex_parser_math_response"],
+            grader_ids=["basic::regex_parser_math_response"],
         ),
         BenchmarkInput(
             benchmark_id="meta-reference-bfcl",
             dataset_id="bfcl",
-            scoring_functions=["basic::bfcl"],
+            grader_ids=["basic::bfcl"],
         ),
     ]
     return DistributionTemplate(
diff --git a/llama_stack/templates/open-benchmark/run.yaml b/llama_stack/templates/open-benchmark/run.yaml
index 93f437273..a3c00af56 100644
--- a/llama_stack/templates/open-benchmark/run.yaml
+++ b/llama_stack/templates/open-benchmark/run.yaml
@@ -196,27 +196,27 @@ datasets:
 scoring_fns: []
 benchmarks:
 - dataset_id: simpleqa
-  scoring_functions:
+  grader_ids:
   - llm-as-judge::405b-simpleqa
   metadata: {}
   benchmark_id: meta-reference-simpleqa
 - dataset_id: mmlu_cot
-  scoring_functions:
+  grader_ids:
   - basic::regex_parser_multiple_choice_answer
   metadata: {}
   benchmark_id: meta-reference-mmlu-cot
 - dataset_id: gpqa_cot
-  scoring_functions:
+  grader_ids:
   - basic::regex_parser_multiple_choice_answer
   metadata: {}
   benchmark_id: meta-reference-gpqa-cot
 - dataset_id: math_500
-  scoring_functions:
+  grader_ids:
   - basic::regex_parser_math_response
   metadata: {}
   benchmark_id: meta-reference-math-500
 - dataset_id: bfcl
-  scoring_functions:
+  grader_ids:
   - basic::bfcl
   metadata: {}
   benchmark_id: meta-reference-bfcl
diff --git a/pyproject.toml b/pyproject.toml
index 107150cee..cf4e81ab8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -186,6 +186,8 @@ exclude = [
     "^llama_stack/apis/tools/tools\\.py$",
     "^llama_stack/apis/vector_dbs/vector_dbs\\.py$",
     "^llama_stack/apis/vector_io/vector_io\\.py$",
+    "^llama_stack/apis/graders/graders\\.py$",
+    "^llama_stack/apis/evaluation/evaluation\\.py$",
     "^llama_stack/cli/download\\.py$",
     "^llama_stack/cli/llama\\.py$",
     "^llama_stack/cli/stack/_build\\.py$",
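
Note for reviewers: as a quick orientation to the schema change in this diff, below is a minimal sketch of how a benchmark would be declared after this patch, with `grader_ids` replacing `scoring_functions`. The import path is assumed to match what the open-benchmark template uses, and the `my-org-mmlu` identifier is hypothetical; the dataset and grader ids are taken from the diff above.

```python
# Sketch only (not part of this patch): declaring a benchmark under the new
# schema, where graders replace scoring functions.
# Assumption: BenchmarkInput is importable as in the open-benchmark template.
from llama_stack.distribution.datatypes import BenchmarkInput

# Hypothetical benchmark reusing the mmlu_cot dataset and a grader id that
# appears in this diff. register_benchmark() now infers the provider_id
# (first available evaluation provider) instead of requiring it.
my_benchmark = BenchmarkInput(
    benchmark_id="my-org-mmlu",
    dataset_id="mmlu_cot",
    grader_ids=["basic::regex_parser_multiple_choice_answer"],
)
```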