forked from phoenix-oss/llama-stack-mirror
migrate evals to resource (#421)
* migrate evals to resource
* remove listing of providers' evals
* change the order of params in register
* fix after rebase
* linter fix

Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
parent b95cb5308f
commit 3802edfc50
5 changed files with 63 additions and 56 deletions
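
In short: the `EvalTaskDef` / `EvalTaskDefWithProvider` pair becomes a single `EvalTask` that subclasses the shared `Resource` type, providers no longer list their own eval tasks, and `register_eval_task` now takes plain keyword parameters instead of a prebuilt definition object. The sketch below is a minimal, self-contained illustration of the call-site migration; `EvalTasksStub` is a hypothetical stand-in for the real routing-table implementation, not code from this commit.

    import asyncio
    from typing import Any, Dict, List, Optional


    class EvalTasksStub:
        """Hypothetical stand-in for the EvalTasks API after this change."""

        def __init__(self) -> None:
            self.registered: Dict[str, Dict[str, Any]] = {}

        async def register_eval_task(
            self,
            eval_task_id: str,
            dataset_id: str,
            scoring_functions: List[str],
            provider_eval_task_id: Optional[str] = None,
            provider_id: Optional[str] = None,
            metadata: Optional[Dict[str, Any]] = None,
        ) -> None:
            # Callers now pass plain parameters; previously they built an
            # EvalTaskDefWithProvider object and passed it whole.
            self.registered[eval_task_id] = {
                "dataset_id": dataset_id,
                "scoring_functions": scoring_functions,
                "provider_id": provider_id,
                "metadata": metadata or {},
            }


    async def main() -> None:
        eval_tasks_impl = EvalTasksStub()
        await eval_tasks_impl.register_eval_task(
            eval_task_id="meta-reference::app_eval",
            dataset_id="test_dataset_for_eval",
            scoring_functions=["meta-reference::equality"],
        )
        print(eval_tasks_impl.registered)


    asyncio.run(main())
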
@@ -7,12 +7,14 @@ from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkab
 
 from llama_models.schema_utils import json_schema_type, webmethod
 
-from pydantic import BaseModel, Field
+from pydantic import Field
+
+from llama_stack.apis.resource import Resource
 
 
 @json_schema_type
-class EvalTaskDef(BaseModel):
-    identifier: str
+class EvalTask(Resource):
+    type: Literal["eval_task"] = "eval_task"
     dataset_id: str
     scoring_functions: List[str]
     metadata: Dict[str, Any] = Field(
@@ -21,23 +23,21 @@ class EvalTaskDef(BaseModel):
     )
 
 
-@json_schema_type
-class EvalTaskDefWithProvider(EvalTaskDef):
-    type: Literal["eval_task"] = "eval_task"
-    provider_id: str = Field(
-        description="ID of the provider which serves this dataset",
-    )
-
-
 @runtime_checkable
 class EvalTasks(Protocol):
     @webmethod(route="/eval_tasks/list", method="GET")
-    async def list_eval_tasks(self) -> List[EvalTaskDefWithProvider]: ...
+    async def list_eval_tasks(self) -> List[EvalTask]: ...
 
     @webmethod(route="/eval_tasks/get", method="GET")
-    async def get_eval_task(self, name: str) -> Optional[EvalTaskDefWithProvider]: ...
+    async def get_eval_task(self, name: str) -> Optional[EvalTask]: ...
 
     @webmethod(route="/eval_tasks/register", method="POST")
     async def register_eval_task(
-        self, eval_task_def: EvalTaskDefWithProvider
+        self,
+        eval_task_id: str,
+        dataset_id: str,
+        scoring_functions: List[str],
+        provider_eval_task_id: Optional[str] = None,
+        provider_id: Optional[str] = None,
+        metadata: Optional[Dict[str, Any]] = None,
    ) -> None: ...
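
For reference, a minimal sketch of what the new `EvalTask` model looks like as a Pydantic resource. The `Resource` base here is a stand-in (the real one lives in `llama_stack.apis.resource` and may carry more fields), and the `Field(...)` arguments on `metadata` fall outside the hunk above, so `default_factory=dict` is an assumption.

    from typing import Any, Dict, List, Literal

    from pydantic import BaseModel, Field


    class Resource(BaseModel):
        # Stand-in for llama_stack.apis.resource.Resource.
        identifier: str
        provider_resource_id: str
        provider_id: str


    class EvalTask(Resource):
        type: Literal["eval_task"] = "eval_task"
        dataset_id: str
        scoring_functions: List[str]
        metadata: Dict[str, Any] = Field(default_factory=dict)


    task = EvalTask(
        identifier="meta-reference-mmlu",
        provider_resource_id="meta-reference-mmlu",
        provider_id="meta-reference",
        dataset_id="mmlu",
        scoring_functions=["meta-reference::regex_parser_multiple_choice_answer"],
    )
    print(task)
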
@@ -105,8 +105,6 @@ class CommonRoutingTableImpl(RoutingTable):
 
             elif api == Api.eval:
                 p.eval_task_store = self
-                eval_tasks = await p.list_eval_tasks()
-                await add_objects(eval_tasks, pid, EvalTaskDefWithProvider)
 
     async def shutdown(self) -> None:
         for p in self.impls_by_provider_id.values():
@@ -357,11 +355,38 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
 
 
 class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
-    async def list_eval_tasks(self) -> List[ScoringFnDefWithProvider]:
+    async def list_eval_tasks(self) -> List[EvalTask]:
         return await self.get_all_with_type("eval_task")
 
-    async def get_eval_task(self, name: str) -> Optional[EvalTaskDefWithProvider]:
+    async def get_eval_task(self, name: str) -> Optional[EvalTask]:
         return await self.get_object_by_identifier(name)
 
-    async def register_eval_task(self, eval_task_def: EvalTaskDefWithProvider) -> None:
-        await self.register_object(eval_task_def)
+    async def register_eval_task(
+        self,
+        eval_task_id: str,
+        dataset_id: str,
+        scoring_functions: List[str],
+        metadata: Optional[Dict[str, Any]] = None,
+        provider_eval_task_id: Optional[str] = None,
+        provider_id: Optional[str] = None,
+    ) -> None:
+        if metadata is None:
+            metadata = {}
+        if provider_id is None:
+            if len(self.impls_by_provider_id) == 1:
+                provider_id = list(self.impls_by_provider_id.keys())[0]
+            else:
+                raise ValueError(
+                    "No provider specified and multiple providers available. Please specify a provider_id."
+                )
+        if provider_eval_task_id is None:
+            provider_eval_task_id = eval_task_id
+        eval_task = EvalTask(
+            identifier=eval_task_id,
+            dataset_id=dataset_id,
+            scoring_functions=scoring_functions,
+            metadata=metadata,
+            provider_id=provider_id,
+            provider_resource_id=provider_eval_task_id,
+        )
+        await self.register_object(eval_task)
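
The interesting new behavior is in `register_eval_task`: the routing table now builds the `EvalTask` itself, defaults `provider_eval_task_id` to the task id, and falls back to the only configured provider when `provider_id` is omitted. A distilled, self-contained sketch of that fallback rule (not the actual routing-table code):

    from typing import Dict, Optional


    def resolve_provider_id(
        impls_by_provider_id: Dict[str, object],
        provider_id: Optional[str],
    ) -> str:
        """If no provider is given and exactly one is configured, use it;
        otherwise the caller must disambiguate."""
        if provider_id is not None:
            return provider_id
        if len(impls_by_provider_id) == 1:
            return next(iter(impls_by_provider_id))
        raise ValueError(
            "No provider specified and multiple providers available. "
            "Please specify a provider_id."
        )


    # One provider configured and none specified: fall back to it.
    assert resolve_provider_id({"meta-reference": object()}, None) == "meta-reference"

    # Several providers and none specified: the registration is rejected.
    try:
        resolve_provider_id({"meta-reference": object(), "other": object()}, None)
    except ValueError as err:
        print(err)
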
@@ -12,7 +12,7 @@ from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field
 
 from llama_stack.apis.datasets import Dataset
-from llama_stack.apis.eval_tasks import EvalTaskDef
+from llama_stack.apis.eval_tasks import EvalTask
 from llama_stack.apis.memory_banks.memory_banks import MemoryBank
 from llama_stack.apis.models import Model
 from llama_stack.apis.scoring_functions import ScoringFnDef
@@ -67,9 +67,7 @@ class ScoringFunctionsProtocolPrivate(Protocol):
 
 
 class EvalTasksProtocolPrivate(Protocol):
-    async def list_eval_tasks(self) -> List[EvalTaskDef]: ...
-
-    async def register_eval_task(self, eval_task_def: EvalTaskDef) -> None: ...
+    async def register_eval_task(self, eval_task: EvalTask) -> None: ...
 
 
 @json_schema_type
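
With listing removed, the private provider protocol only has to accept registrations; the routing table is the single registry. A minimal sketch of what satisfying such a protocol looks like, with a placeholder `EvalTask` (the real class is the resource defined above):

    from typing import Dict, Protocol, runtime_checkable


    class EvalTask:
        """Placeholder for the real EvalTask resource; only identifier matters here."""

        def __init__(self, identifier: str) -> None:
            self.identifier = identifier


    @runtime_checkable
    class EvalTasksProtocolPrivate(Protocol):
        # After this change the private protocol only requires register_eval_task;
        # list_eval_tasks is gone because the routing table owns the registry.
        async def register_eval_task(self, eval_task: EvalTask) -> None: ...


    class InMemoryEvalProvider:
        def __init__(self) -> None:
            self.eval_tasks: Dict[str, EvalTask] = {}

        async def register_eval_task(self, eval_task: EvalTask) -> None:
            self.eval_tasks[eval_task.identifier] = eval_task


    # Structural check: the provider satisfies the protocol without inheriting from it.
    assert isinstance(InMemoryEvalProvider(), EvalTasksProtocolPrivate)
    print("InMemoryEvalProvider satisfies EvalTasksProtocolPrivate")
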
@@ -11,7 +11,7 @@ from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatu
 from llama_stack.apis.common.type_system import * # noqa: F403
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.eval_tasks import EvalTaskDef
+from llama_stack.apis.eval_tasks import EvalTask
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.scoring import Scoring
 from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
@@ -53,15 +53,12 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
 
     async def shutdown(self) -> None: ...
 
-    async def register_eval_task(self, task_def: EvalTaskDef) -> None:
+    async def register_eval_task(self, task_def: EvalTask) -> None:
         self.eval_tasks[task_def.identifier] = task_def
 
-    async def list_eval_tasks(self) -> List[EvalTaskDef]:
-        return list(self.eval_tasks.values())
-
     async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
-        dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
-        if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
+        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
+        if not dataset_def.schema or len(dataset_def.schema) == 0:
             raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
 
         expected_schemas = [
@@ -77,7 +74,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
             },
         ]
 
-        if dataset_def.dataset_schema not in expected_schemas:
+        if dataset_def.schema not in expected_schemas:
             raise ValueError(
                 f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
             )
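
The provider-side change is mostly mechanical (`EvalTaskDef` to `EvalTask`, `dataset_schema` to `schema`, keyword `dataset_id=` on `get_dataset`), but the validation logic is worth spelling out: the dataset must declare a schema, and that schema must exactly match one of the accepted column layouts. A standalone sketch, with plain strings standing in for the real type objects (StringType, ChatCompletionInputType):

    from typing import Dict, List


    def validate_eval_input_schema(
        dataset_id: str,
        schema: Dict[str, str],
        expected_schemas: List[Dict[str, str]],
    ) -> None:
        """Distilled version of the check: the dataset must declare columns, and
        the column-type mapping must exactly match one accepted layout."""
        if not schema:
            raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
        if schema not in expected_schemas:
            raise ValueError(
                f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
            )


    # Plain strings stand in for the real type objects.
    expected = [
        {
            "input_query": "string",
            "expected_answer": "string",
            "chat_completion_input": "chat_completion_input",
        }
    ]

    validate_eval_input_schema("mmlu", expected[0], expected)
    print("schema ok")
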
@@ -11,12 +11,9 @@ from llama_models.llama3.api import SamplingParams, URL
 
 from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType
 
-from llama_stack.apis.datasetio.datasetio import DatasetDefWithProvider
-
 from llama_stack.apis.eval.eval import (
     AppEvalTaskConfig,
     BenchmarkEvalTaskConfig,
-    EvalTaskDefWithProvider,
     ModelCandidate,
 )
 from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams
@@ -70,13 +67,11 @@ class Testeval:
             "meta-reference::equality",
         ]
         task_id = "meta-reference::app_eval"
-        task_def = EvalTaskDefWithProvider(
-            identifier=task_id,
+        await eval_tasks_impl.register_eval_task(
+            eval_task_id=task_id,
             dataset_id="test_dataset_for_eval",
             scoring_functions=scoring_functions,
-            provider_id="meta-reference",
         )
-        await eval_tasks_impl.register_eval_task(task_def)
         response = await eval_impl.evaluate_rows(
             task_id=task_id,
             input_rows=rows.rows,
@@ -125,13 +120,11 @@ class Testeval:
         ]
 
         task_id = "meta-reference::app_eval-2"
-        task_def = EvalTaskDefWithProvider(
-            identifier=task_id,
+        await eval_tasks_impl.register_eval_task(
+            eval_task_id=task_id,
             dataset_id="test_dataset_for_eval",
             scoring_functions=scoring_functions,
-            provider_id="meta-reference",
         )
-        await eval_tasks_impl.register_eval_task(task_def)
         response = await eval_impl.run_eval(
             task_id=task_id,
             task_config=AppEvalTaskConfig(
@@ -169,35 +162,29 @@ class Testeval:
             pytest.skip(
                 "Only huggingface provider supports pre-registered remote datasets"
             )
-        # register dataset
-        mmlu = DatasetDefWithProvider(
-            identifier="mmlu",
-            url=URL(uri="https://huggingface.co/datasets/llamastack/evals"),
-            dataset_schema={
+
+        await datasets_impl.register_dataset(
+            dataset_id="mmlu",
+            schema={
                 "input_query": StringType(),
                 "expected_answer": StringType(),
                 "chat_completion_input": ChatCompletionInputType(),
             },
+            url=URL(uri="https://huggingface.co/datasets/llamastack/evals"),
             metadata={
                 "path": "llamastack/evals",
                 "name": "evals__mmlu__details",
                 "split": "train",
             },
-            provider_id="",
         )
 
-        await datasets_impl.register_dataset(mmlu)
-
         # register eval task
-        meta_reference_mmlu = EvalTaskDefWithProvider(
-            identifier="meta-reference-mmlu",
+        await eval_tasks_impl.register_eval_task(
+            eval_task_id="meta-reference-mmlu",
             dataset_id="mmlu",
             scoring_functions=["meta-reference::regex_parser_multiple_choice_answer"],
-            provider_id="",
         )
 
-        await eval_tasks_impl.register_eval_task(meta_reference_mmlu)
-
         # list benchmarks
         response = await eval_tasks_impl.list_eval_tasks()
         assert len(response) > 0
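
The tests follow the same pattern throughout: instead of constructing `*DefWithProvider` objects and handing them to `register_*`, they call the registration endpoints directly with keyword arguments and then read the registry back. A compact, self-contained version of that round trip against a hypothetical stub (the real fixtures wire up actual providers):

    import asyncio
    from typing import Dict, List, Optional


    class StubEvalTasksImpl:
        """Hypothetical stand-in for the eval_tasks fixture used in these tests."""

        def __init__(self) -> None:
            self._tasks: Dict[str, List[str]] = {}

        async def register_eval_task(
            self,
            eval_task_id: str,
            dataset_id: str,
            scoring_functions: List[str],
            provider_id: Optional[str] = None,
        ) -> None:
            self._tasks[eval_task_id] = scoring_functions

        async def list_eval_tasks(self) -> List[str]:
            return list(self._tasks)


    async def test_register_and_list() -> None:
        eval_tasks_impl = StubEvalTasksImpl()
        await eval_tasks_impl.register_eval_task(
            eval_task_id="meta-reference-mmlu",
            dataset_id="mmlu",
            scoring_functions=["meta-reference::regex_parser_multiple_choice_answer"],
        )
        response = await eval_tasks_impl.list_eval_tasks()
        assert len(response) > 0
        print("registered:", response)


    asyncio.run(test_register_and_list())
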