From bca96b5b35b4abd436eeacdf5ea9532f8c95aa5d Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Tue, 5 Nov 2024 14:55:59 -0800
Subject: [PATCH] eval api

---
 llama_stack/apis/eval/eval.py             | 35 ++++++++++++---
 llama_stack/apis/eval_tasks/eval_tasks.py | 55 +++++++++++++++++++++++
 2 files changed, 85 insertions(+), 5 deletions(-)
 create mode 100644 llama_stack/apis/eval_tasks/eval_tasks.py

diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py
index 53fe49a8d..60d2567b9 100644
--- a/llama_stack/apis/eval/eval.py
+++ b/llama_stack/apis/eval/eval.py
@@ -14,6 +14,8 @@ from llama_stack.apis.scoring_functions import *  # noqa: F403
 from llama_stack.apis.agents import AgentConfig
 from llama_stack.apis.common.job_types import Job, JobStatus
 from llama_stack.apis.scoring import *  # noqa: F403
+from llama_stack.apis.scoring_functions import *  # noqa: F403
+from llama_stack.apis.eval_tasks import *  # noqa: F403
 
 
 @json_schema_type
@@ -34,11 +36,27 @@ EvalCandidate = Annotated[
     Union[ModelCandidate, AgentCandidate], Field(discriminator="type")
 ]
 
-# @json_schema_type
-# class EvalTaskDef(BaseModel):
-#     dataset_id: str
-#     candidate: EvalCandidate
-#     scoring_functions: List[str]
+
+@json_schema_type
+class BenchmarkEvalTaskConfig(BaseModel):
+    type: Literal["benchmark"] = "benchmark"
+    eval_candidate: EvalCandidate  # type: ignore
+
+
+@json_schema_type
+class AppEvalTaskConfig(BaseModel):
+    type: Literal["app"] = "app"
+    eval_candidate: EvalCandidate  # type: ignore
+    scoring_functions_config: Dict[str, ScoringFnConfig] = Field(
+        description="Map between scoring function id and parameters",
+        default_factory=dict,
+    )
+    # we could optionally add any GenEval specific dataset config here
+
+
+EvalTaskConfig = Annotated[
+    Union[BenchmarkEvalTaskConfig, AppEvalTaskConfig], Field(discriminator="type")
+]
 
 
 @json_schema_type
@@ -50,6 +68,13 @@ class EvaluateResponse(BaseModel):
 
 
 class Eval(Protocol):
+    @webmethod(route="/eval/evaluate_task", method="POST")
+    async def evaluate_task(
+        self,
+        eval_task_def: Union[str, EvalTaskDef],  # type: ignore
+        eval_task_config: EvalTaskConfig,  # type: ignore
+    ) -> Job: ...
+
     @webmethod(route="/eval/evaluate_batch", method="POST")
     async def evaluate_batch(
         self,
diff --git a/llama_stack/apis/eval_tasks/eval_tasks.py b/llama_stack/apis/eval_tasks/eval_tasks.py
new file mode 100644
index 000000000..62d0f7ef1
--- /dev/null
+++ b/llama_stack/apis/eval_tasks/eval_tasks.py
@@ -0,0 +1,55 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from enum import Enum
+from typing import (
+    Any,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Protocol,
+    runtime_checkable,
+    Union,
+)
+
+from llama_models.schema_utils import json_schema_type, webmethod
+
+from llama_stack.apis.common.type_system import ParamType
+from pydantic import BaseModel, Field
+from typing_extensions import Annotated
+
+
+@json_schema_type
+class EvalTaskDef(BaseModel):
+    identifier: str
+    dataset_id: str
+    scoring_functions: List[str]
+    metadata: Dict[str, Any] = Field(
+        default_factory=dict,
+        description="Metadata for this evaluation task (e.g. from GECO)",
+    )
+
+
+@json_schema_type
+class EvalTaskDefWithProvider(EvalTaskDef):
+    type: Literal["eval_task"] = "eval_task"
+    provider_id: str = Field(
+        description="ID of the provider which serves this eval task",
+    )
+
+
+@runtime_checkable
+class EvalTasks(Protocol):
+    @webmethod(route="/eval_tasks/list", method="GET")
+    async def list_eval_tasks(self) -> List[EvalTaskDefWithProvider]: ...
+
+    @webmethod(route="/eval_tasks/get", method="GET")
+    async def get_eval_tasks(self, name: str) -> Optional[EvalTaskDefWithProvider]: ...
+
+    @webmethod(route="/eval_tasks/register", method="POST")
+    async def register_eval_tasks(
+        self, function_def: EvalTaskDefWithProvider
+    ) -> None: ...
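
Usage sketch (hypothetical, not applied by this patch). It assumes concrete providers implementing the EvalTasks and Eval protocols are available as eval_tasks_impl and eval_impl, that the llama_stack.apis.eval and llama_stack.apis.eval_tasks packages re-export the types defined above, and that ModelCandidate takes model and sampling_params (with SamplingParams coming from llama_models) as elsewhere in the eval API; every id and model name below is illustrative only.

# Hypothetical sketch: `eval_tasks_impl` and `eval_impl` stand in for
# providers that implement the EvalTasks and Eval protocols added above.
from llama_models.llama3.api.datatypes import SamplingParams

from llama_stack.apis.eval import AppEvalTaskConfig, ModelCandidate
from llama_stack.apis.eval_tasks import EvalTaskDefWithProvider


async def run_app_eval(eval_tasks_impl, eval_impl):
    # Register an app-defined eval task over an existing dataset and a set of
    # scoring function identifiers (all ids here are made up for illustration).
    task_def = EvalTaskDefWithProvider(
        identifier="my_qa_eval_task",
        dataset_id="my_qa_dataset",
        scoring_functions=["meta-reference::equality"],
        provider_id="meta-reference",
    )
    await eval_tasks_impl.register_eval_tasks(function_def=task_def)

    # Evaluate a model candidate against the registered task; evaluate_task
    # returns a Job handle whose status can be polled separately.
    job = await eval_impl.evaluate_task(
        eval_task_def="my_qa_eval_task",
        eval_task_config=AppEvalTaskConfig(
            eval_candidate=ModelCandidate(
                model="Llama3.1-8B-Instruct",
                sampling_params=SamplingParams(),
            ),
        ),
    )
    return job

The type discriminator on EvalTaskConfig is what lets the single /eval/evaluate_task route accept either a benchmark task (candidate only) or an app task that also carries per-scoring-function parameters.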