migrate evals to resource (#421)

* migrate evals to resource

* remove listing of providers' evals

* change the order of params in register

* fix after rebase

* linter fix

---------

Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
Author: Dinesh Yeduguru, 2024-11-11 17:24:03 -08:00 (committed by GitHub)
Commit: 3802edfc50, parent: b95cb5308f
5 changed files with 63 additions and 56 deletions

View file

@@ -7,12 +7,14 @@ from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
 from llama_models.schema_utils import json_schema_type, webmethod
-from pydantic import BaseModel, Field
+from pydantic import Field
+
+from llama_stack.apis.resource import Resource


 @json_schema_type
-class EvalTaskDef(BaseModel):
-    identifier: str
+class EvalTask(Resource):
+    type: Literal["eval_task"] = "eval_task"
     dataset_id: str
     scoring_functions: List[str]
     metadata: Dict[str, Any] = Field(
@@ -21,23 +23,21 @@ class EvalTaskDef(BaseModel):
     )


-@json_schema_type
-class EvalTaskDefWithProvider(EvalTaskDef):
-    type: Literal["eval_task"] = "eval_task"
-    provider_id: str = Field(
-        description="ID of the provider which serves this dataset",
-    )
-
-
 @runtime_checkable
 class EvalTasks(Protocol):
     @webmethod(route="/eval_tasks/list", method="GET")
-    async def list_eval_tasks(self) -> List[EvalTaskDefWithProvider]: ...
+    async def list_eval_tasks(self) -> List[EvalTask]: ...

     @webmethod(route="/eval_tasks/get", method="GET")
-    async def get_eval_task(self, name: str) -> Optional[EvalTaskDefWithProvider]: ...
+    async def get_eval_task(self, name: str) -> Optional[EvalTask]: ...

     @webmethod(route="/eval_tasks/register", method="POST")
     async def register_eval_task(
-        self, eval_task_def: EvalTaskDefWithProvider
+        self,
+        eval_task_id: str,
+        dataset_id: str,
+        scoring_functions: List[str],
+        provider_eval_task_id: Optional[str] = None,
+        provider_id: Optional[str] = None,
+        metadata: Optional[Dict[str, Any]] = None,
     ) -> None: ...
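
For orientation, a minimal sketch of how a caller exercises the reworked protocol; `eval_tasks_impl` is an assumed placeholder for any object implementing `EvalTasks`, and the task/dataset identifiers are illustrative:

async def register_example(eval_tasks_impl) -> None:
    # Assumed: eval_tasks_impl implements the EvalTasks protocol above.
    await eval_tasks_impl.register_eval_task(
        eval_task_id="my-app-eval",              # becomes EvalTask.identifier
        dataset_id="my-eval-dataset",
        scoring_functions=["meta-reference::equality"],
        # provider_id, provider_eval_task_id and metadata stay optional;
        # the routing table fills in defaults (see the next file).
    )
    task = await eval_tasks_impl.get_eval_task(name="my-app-eval")
    assert task is not None and task.dataset_id == "my-eval-dataset"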

View file

@@ -105,8 +105,6 @@ class CommonRoutingTableImpl(RoutingTable):
             elif api == Api.eval:
                 p.eval_task_store = self
-                eval_tasks = await p.list_eval_tasks()
-                await add_objects(eval_tasks, pid, EvalTaskDefWithProvider)

     async def shutdown(self) -> None:
         for p in self.impls_by_provider_id.values():
@@ -357,11 +355,38 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
 class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
-    async def list_eval_tasks(self) -> List[ScoringFnDefWithProvider]:
+    async def list_eval_tasks(self) -> List[EvalTask]:
         return await self.get_all_with_type("eval_task")

-    async def get_eval_task(self, name: str) -> Optional[EvalTaskDefWithProvider]:
+    async def get_eval_task(self, name: str) -> Optional[EvalTask]:
         return await self.get_object_by_identifier(name)

-    async def register_eval_task(self, eval_task_def: EvalTaskDefWithProvider) -> None:
-        await self.register_object(eval_task_def)
+    async def register_eval_task(
+        self,
+        eval_task_id: str,
+        dataset_id: str,
+        scoring_functions: List[str],
+        metadata: Optional[Dict[str, Any]] = None,
+        provider_eval_task_id: Optional[str] = None,
+        provider_id: Optional[str] = None,
+    ) -> None:
+        if metadata is None:
+            metadata = {}
+        if provider_id is None:
+            if len(self.impls_by_provider_id) == 1:
+                provider_id = list(self.impls_by_provider_id.keys())[0]
+            else:
+                raise ValueError(
+                    "No provider specified and multiple providers available. Please specify a provider_id."
+                )
+        if provider_eval_task_id is None:
+            provider_eval_task_id = eval_task_id
+        eval_task = EvalTask(
+            identifier=eval_task_id,
+            dataset_id=dataset_id,
+            scoring_functions=scoring_functions,
+            metadata=metadata,
+            provider_id=provider_id,
+            provider_resource_id=provider_eval_task_id,
+        )
+        await self.register_object(eval_task)
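
The provider-resolution rule introduced here is small but easy to miss in the diff: with no explicit `provider_id`, the routing table only falls back to a default when exactly one provider is registered, and `provider_eval_task_id` defaults to the caller-supplied `eval_task_id`. A standalone sketch of that rule, where `resolve_provider_id` is a hypothetical helper rather than code from this commit:

from typing import Dict, Optional


def resolve_provider_id(
    impls_by_provider_id: Dict[str, object],
    provider_id: Optional[str] = None,
) -> str:
    # Explicit provider wins.
    if provider_id is not None:
        return provider_id
    # Exactly one registered provider: use it implicitly.
    if len(impls_by_provider_id) == 1:
        return next(iter(impls_by_provider_id))
    raise ValueError(
        "No provider specified and multiple providers available. "
        "Please specify a provider_id."
    )


# A single registered provider is picked up automatically.
assert resolve_provider_id({"meta-reference": object()}) == "meta-reference"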

View file

@@ -12,7 +12,7 @@ from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field

 from llama_stack.apis.datasets import Dataset
-from llama_stack.apis.eval_tasks import EvalTaskDef
+from llama_stack.apis.eval_tasks import EvalTask
 from llama_stack.apis.memory_banks.memory_banks import MemoryBank
 from llama_stack.apis.models import Model
 from llama_stack.apis.scoring_functions import ScoringFnDef
@@ -67,9 +67,7 @@ class ScoringFunctionsProtocolPrivate(Protocol):
 class EvalTasksProtocolPrivate(Protocol):
-    async def list_eval_tasks(self) -> List[EvalTaskDef]: ...
-
-    async def register_eval_task(self, eval_task_def: EvalTaskDef) -> None: ...
+    async def register_eval_task(self, eval_task: EvalTask) -> None: ...


 @json_schema_type
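
With `list_eval_tasks` dropped from the private protocol, a provider only has to accept registrations pushed down from the routing table. A minimal in-memory sketch satisfying the slimmed-down protocol (the class name is made up for illustration):

from typing import Dict

from llama_stack.apis.eval_tasks import EvalTask


class InMemoryEvalTaskStore:
    """Hypothetical provider-side store satisfying EvalTasksProtocolPrivate."""

    def __init__(self) -> None:
        self.eval_tasks: Dict[str, EvalTask] = {}

    async def register_eval_task(self, eval_task: EvalTask) -> None:
        # Keyed by the resource identifier assigned at registration time.
        self.eval_tasks[eval_task.identifier] = eval_task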

View file

@@ -11,7 +11,7 @@ from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
 from llama_stack.apis.common.type_system import *  # noqa: F403
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.eval_tasks import EvalTaskDef
+from llama_stack.apis.eval_tasks import EvalTask
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.scoring import Scoring
 from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
@@ -53,15 +53,12 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):

     async def shutdown(self) -> None: ...

-    async def register_eval_task(self, task_def: EvalTaskDef) -> None:
+    async def register_eval_task(self, task_def: EvalTask) -> None:
         self.eval_tasks[task_def.identifier] = task_def

-    async def list_eval_tasks(self) -> List[EvalTaskDef]:
-        return list(self.eval_tasks.values())
-
     async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
-        dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
-        if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
+        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
+        if not dataset_def.schema or len(dataset_def.schema) == 0:
             raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")

         expected_schemas = [
@@ -77,7 +74,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
             },
         ]

-        if dataset_def.dataset_schema not in expected_schemas:
+        if dataset_def.schema not in expected_schemas:
             raise ValueError(
                 f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
             )
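
Alongside the type rename, the dataset accessors change as well: `dataset_identifier=` becomes `dataset_id=`, and `dataset_def.dataset_schema` becomes `dataset_def.schema`. For reference, a representative schema that the validation above is meant to accept, mirrored from the test file below (the exact accepted set lives in `expected_schemas`):

from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType

# One of the accepted eval-input layouts: a query, a reference answer, and a
# pre-built chat-completion input column.
eval_input_schema = {
    "input_query": StringType(),
    "expected_answer": StringType(),
    "chat_completion_input": ChatCompletionInputType(),
}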

View file

@@ -11,12 +11,9 @@ from llama_models.llama3.api import SamplingParams, URL

 from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType
-from llama_stack.apis.datasetio.datasetio import DatasetDefWithProvider
 from llama_stack.apis.eval.eval import (
     AppEvalTaskConfig,
     BenchmarkEvalTaskConfig,
-    EvalTaskDefWithProvider,
     ModelCandidate,
 )
 from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams
@@ -70,13 +67,11 @@ class Testeval:
             "meta-reference::equality",
         ]
         task_id = "meta-reference::app_eval"
-        task_def = EvalTaskDefWithProvider(
-            identifier=task_id,
+        await eval_tasks_impl.register_eval_task(
+            eval_task_id=task_id,
             dataset_id="test_dataset_for_eval",
             scoring_functions=scoring_functions,
-            provider_id="meta-reference",
         )
-        await eval_tasks_impl.register_eval_task(task_def)
         response = await eval_impl.evaluate_rows(
             task_id=task_id,
             input_rows=rows.rows,
@@ -125,13 +120,11 @@ class Testeval:
         ]
         task_id = "meta-reference::app_eval-2"
-        task_def = EvalTaskDefWithProvider(
-            identifier=task_id,
+        await eval_tasks_impl.register_eval_task(
+            eval_task_id=task_id,
            dataset_id="test_dataset_for_eval",
             scoring_functions=scoring_functions,
-            provider_id="meta-reference",
         )
-        await eval_tasks_impl.register_eval_task(task_def)
         response = await eval_impl.run_eval(
             task_id=task_id,
             task_config=AppEvalTaskConfig(
@@ -169,35 +162,29 @@ class Testeval:
             pytest.skip(
                 "Only huggingface provider supports pre-registered remote datasets"
             )
-        # register dataset
-        mmlu = DatasetDefWithProvider(
-            identifier="mmlu",
-            url=URL(uri="https://huggingface.co/datasets/llamastack/evals"),
-            dataset_schema={
+        await datasets_impl.register_dataset(
+            dataset_id="mmlu",
+            schema={
                 "input_query": StringType(),
                 "expected_answer": StringType(),
                 "chat_completion_input": ChatCompletionInputType(),
             },
+            url=URL(uri="https://huggingface.co/datasets/llamastack/evals"),
             metadata={
                 "path": "llamastack/evals",
                 "name": "evals__mmlu__details",
                 "split": "train",
             },
-            provider_id="",
         )
-        await datasets_impl.register_dataset(mmlu)

         # register eval task
-        meta_reference_mmlu = EvalTaskDefWithProvider(
-            identifier="meta-reference-mmlu",
+        await eval_tasks_impl.register_eval_task(
+            eval_task_id="meta-reference-mmlu",
             dataset_id="mmlu",
             scoring_functions=["meta-reference::regex_parser_multiple_choice_answer"],
-            provider_id="",
         )
-        await eval_tasks_impl.register_eval_task(meta_reference_mmlu)

         # list benchmarks
         response = await eval_tasks_impl.list_eval_tasks()
         assert len(response) > 0
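
If one wanted to check the resource fields as well, an assertion block along these lines could be appended inside the same async test (a sketch, relying on the routing-table defaults shown earlier):

        task = await eval_tasks_impl.get_eval_task(name="meta-reference-mmlu")
        assert task is not None
        assert task.type == "eval_task"
        assert task.dataset_id == "mmlu"
        # provider_eval_task_id was not supplied, so it defaults to eval_task_id.
        assert task.provider_resource_id == "meta-reference-mmlu"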