From 5836c09c572c9ca30e766b1968906ba4da14d96a Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Tue, 22 Oct 2024 08:53:46 -0700
Subject: [PATCH] scoring functions + evals

---
 llama_stack/apis/{evals => eval}/__init__.py   |   2 +-
 llama_stack/apis/eval/eval.py                  |  72 +++++++++++
 llama_stack/apis/evals/evals.py                | 122 ------------------
 llama_stack/apis/scoring/__init__.py           |   7 +
 llama_stack/apis/scoring/scoring.py            |  46 +++++++
 .../apis/scoring_functions/__init__.py         |   7 +
 .../scoring_functions/scoring_functions.py     |  87 +++++++++++++
 7 files changed, 220 insertions(+), 123 deletions(-)
 rename llama_stack/apis/{evals => eval}/__init__.py (83%)
 create mode 100644 llama_stack/apis/eval/eval.py
 delete mode 100644 llama_stack/apis/evals/evals.py
 create mode 100644 llama_stack/apis/scoring/__init__.py
 create mode 100644 llama_stack/apis/scoring/scoring.py
 create mode 100644 llama_stack/apis/scoring_functions/__init__.py
 create mode 100644 llama_stack/apis/scoring_functions/scoring_functions.py

diff --git a/llama_stack/apis/evals/__init__.py b/llama_stack/apis/eval/__init__.py
similarity index 83%
rename from llama_stack/apis/evals/__init__.py
rename to llama_stack/apis/eval/__init__.py
index d21b97d0a..5f91ad70d 100644
--- a/llama_stack/apis/evals/__init__.py
+++ b/llama_stack/apis/eval/__init__.py
@@ -4,4 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .evals import *  # noqa: F401 F403
+from .eval import *  # noqa: F401 F403
diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py
new file mode 100644
index 000000000..5fcd267d9
--- /dev/null
+++ b/llama_stack/apis/eval/eval.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Literal, Optional, Protocol, Union
+
+from typing_extensions import Annotated
+
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_models.schema_utils import json_schema_type, webmethod
+from llama_stack.apis.scoring_functions import *  # noqa: F403
+
+
+@json_schema_type
+class ModelCandidate(BaseModel):
+    type: Literal["model"] = "model"
+    model: str
+    sampling_params: SamplingParams
+    system_message: Optional[SystemMessage] = None
+
+
+@json_schema_type
+class AgentCandidate(BaseModel):
+    type: Literal["agent"] = "agent"
+    config: AgentConfig
+
+
+EvalCandidate = Annotated[
+    Union[ModelCandidate, AgentCandidate], Field(discriminator="type")
+]
+
+
+@json_schema_type
+class Job(BaseModel):
+    job_id: str
+
+
+@json_schema_type
+class EvaluateResponse(BaseModel):
+    generations: List[Dict[str, Any]]
+
+    # each key in the dict is a scoring function name
+    scores: List[Dict[str, ScoringResult]]
+
+
+class Eval(Protocol):
+    @webmethod(route="/eval/evaluate_batch", method="POST")
+    async def evaluate_batch(
+        self,
+        dataset_id: str,
+        candidate: EvalCandidate,
+        scoring_functions: List[str],
+    ) -> Job: ...
+
+    @webmethod(route="/eval/evaluate", method="POST")
+    async def evaluate(
+        self,
+        input_rows: List[Dict[str, Any]],
+        candidate: EvalCandidate,
+        scoring_functions: List[str],
+    ) -> EvaluateResponse: ...
+
+    @webmethod(route="/eval/job/status", method="GET")
+    async def job_status(self, job_id: str) -> None: ...
+
+    @webmethod(route="/eval/job/cancel", method="POST")
+    async def job_cancel(self, job_id: str) -> None: ...
+
+    @webmethod(route="/eval/job/result", method="GET")
+    async def job_result(self, job_id: str) -> None: ...
diff --git a/llama_stack/apis/evals/evals.py b/llama_stack/apis/evals/evals.py
deleted file mode 100644
index 0be2243ab..000000000
--- a/llama_stack/apis/evals/evals.py
+++ /dev/null
@@ -1,122 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from enum import Enum
-from typing import List, Protocol
-
-from llama_models.schema_utils import webmethod
-
-from pydantic import BaseModel
-
-from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_stack.apis.dataset import *  # noqa: F403
-from llama_stack.apis.common.training_types import *  # noqa: F403
-
-
-class TextGenerationMetric(Enum):
-    perplexity = "perplexity"
-    rouge = "rouge"
-    bleu = "bleu"
-
-
-class QuestionAnsweringMetric(Enum):
-    em = "em"
-    f1 = "f1"
-
-
-class SummarizationMetric(Enum):
-    rouge = "rouge"
-    bleu = "bleu"
-
-
-class EvaluationJob(BaseModel):
-    job_uuid: str
-
-
-class EvaluationJobLogStream(BaseModel):
-    job_uuid: str
-
-
-class EvaluateTaskRequestCommon(BaseModel):
-    job_uuid: str
-    dataset: TrainEvalDataset
-
-    checkpoint: Checkpoint
-
-    # generation params
-    sampling_params: SamplingParams = SamplingParams()
-
-
-@json_schema_type
-class EvaluateTextGenerationRequest(EvaluateTaskRequestCommon):
-    """Request to evaluate text generation."""
-
-    metrics: List[TextGenerationMetric]
-
-
-@json_schema_type
-class EvaluateQuestionAnsweringRequest(EvaluateTaskRequestCommon):
-    """Request to evaluate question answering."""
-
-    metrics: List[QuestionAnsweringMetric]
-
-
-@json_schema_type
-class EvaluateSummarizationRequest(EvaluateTaskRequestCommon):
-    """Request to evaluate summarization."""
-
-    metrics: List[SummarizationMetric]
-
-
-class EvaluationJobStatusResponse(BaseModel):
-    job_uuid: str
-
-
-@json_schema_type
-class EvaluationJobArtifactsResponse(BaseModel):
-    """Artifacts of a evaluation job."""
-
-    job_uuid: str
-
-
-class Evaluations(Protocol):
-    @webmethod(route="/evaluate/text_generation/")
-    def evaluate_text_generation(
-        self,
-        metrics: List[TextGenerationMetric],
-    ) -> EvaluationJob: ...
-
-    @webmethod(route="/evaluate/question_answering/")
-    def evaluate_question_answering(
-        self,
-        metrics: List[QuestionAnsweringMetric],
-    ) -> EvaluationJob: ...
-
-    @webmethod(route="/evaluate/summarization/")
-    def evaluate_summarization(
-        self,
-        metrics: List[SummarizationMetric],
-    ) -> EvaluationJob: ...
-
-    @webmethod(route="/evaluate/jobs")
-    def get_evaluation_jobs(self) -> List[EvaluationJob]: ...
-
-    @webmethod(route="/evaluate/job/status")
-    def get_evaluation_job_status(
-        self, job_uuid: str
-    ) -> EvaluationJobStatusResponse: ...
-
-    # sends SSE stream of logs
-    @webmethod(route="/evaluate/job/logs")
-    def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream: ...
-
-    @webmethod(route="/evaluate/job/cancel")
-    def cancel_evaluation_job(self, job_uuid: str) -> None: ...
-
-    @webmethod(route="/evaluate/job/artifacts")
-    def get_evaluation_job_artifacts(
-        self, job_uuid: str
-    ) -> EvaluationJobArtifactsResponse: ...
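The task-specific `Evaluations` protocol above is deleted in favor of a single `Eval` protocol parameterized by a candidate (model or agent) plus a list of scoring function names. For orientation, a minimal usage sketch follows; the `eval_impl` object, model name, dataset ID, and "exact_match" scoring function are hypothetical placeholders, not part of this patch:

    # Sketch only: assumes `eval_impl` is some implementation of the Eval
    # protocol and that this runs inside an async context. All identifiers
    # below are placeholders.
    candidate = ModelCandidate(
        model="Llama3.1-8B-Instruct",
        sampling_params=SamplingParams(),
    )

    # Asynchronous evaluation over a pre-registered dataset returns a Job handle.
    job = await eval_impl.evaluate_batch(
        dataset_id="my_eval_dataset",
        candidate=candidate,
        scoring_functions=["exact_match"],
    )
    await eval_impl.job_status(job.job_id)

    # Inline evaluation returns generations plus per-scoring-function results.
    response = await eval_impl.evaluate(
        input_rows=[{"input_query": "What is 2 + 2?", "expected_answer": "4"}],
        candidate=candidate,
        scoring_functions=["exact_match"],
    )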
diff --git a/llama_stack/apis/scoring/__init__.py b/llama_stack/apis/scoring/__init__.py
new file mode 100644
index 000000000..0739dfc80
--- /dev/null
+++ b/llama_stack/apis/scoring/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .scoring import *  # noqa: F401 F403
diff --git a/llama_stack/apis/scoring/scoring.py b/llama_stack/apis/scoring/scoring.py
new file mode 100644
index 000000000..ec50ecab1
--- /dev/null
+++ b/llama_stack/apis/scoring/scoring.py
@@ -0,0 +1,46 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict, List, Protocol, runtime_checkable
+
+from llama_models.schema_utils import json_schema_type, webmethod
+from pydantic import BaseModel
+
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.apis.scoring_functions import *  # noqa: F403
+
+
+ScoringResult = Dict[str, Any]
+
+
+@json_schema_type
+class ScoreBatchResponse(BaseModel):
+    dataset_id: str
+
+
+@json_schema_type
+class ScoreResponse(BaseModel):
+    # each key in the dict is a scoring function name
+    results: List[Dict[str, ScoringResult]]
+
+
+class ScoringFunctionStore(Protocol):
+    def get_scoring_function(self, name: str) -> ScoringFunctionDefWithProvider: ...
+
+
+@runtime_checkable
+class Scoring(Protocol):
+    scoring_function_store: ScoringFunctionStore
+
+    @webmethod(route="/scoring/score_batch")
+    async def score_batch(
+        self, dataset_id: str, scoring_functions: List[str]
+    ) -> ScoreBatchResponse: ...
+
+    @webmethod(route="/scoring/score")
+    async def score(
+        self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
+    ) -> ScoreResponse: ...
diff --git a/llama_stack/apis/scoring_functions/__init__.py b/llama_stack/apis/scoring_functions/__init__.py
new file mode 100644
index 000000000..b96acb45f
--- /dev/null
+++ b/llama_stack/apis/scoring_functions/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .scoring_functions import *  # noqa: F401 F403
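The `Scoring` protocol mirrors `Eval`: `score` works on in-memory rows and returns results keyed by scoring function name, while `score_batch` works on a registered dataset and returns a `ScoreBatchResponse` that, in this draft, carries only the dataset ID. A hedged sketch of the intended call pattern, again with a hypothetical `scoring_impl` and made-up identifiers:

    # Sketch only: `scoring_impl` is a hypothetical Scoring implementation,
    # used inside an async context; dataset and function names are placeholders.
    rows = [
        {"input_query": "What is 2 + 2?", "generated_answer": "4", "expected_answer": "4"},
    ]

    # Each entry of `score_response.results` is keyed by scoring function name.
    score_response = await scoring_impl.score(
        input_rows=rows,
        scoring_functions=["exact_match"],
    )

    # Batch scoring over a registered dataset; per the response model above,
    # only the dataset ID comes back from this call.
    batch_response = await scoring_impl.score_batch(
        dataset_id="my_eval_dataset",
        scoring_functions=["exact_match"],
    )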
diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py
new file mode 100644
index 000000000..a5aca34fe
--- /dev/null
+++ b/llama_stack/apis/scoring_functions/scoring_functions.py
@@ -0,0 +1,87 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import (
+    Any,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Protocol,
+    runtime_checkable,
+    Union,
+)
+
+from llama_models.schema_utils import json_schema_type, webmethod
+from pydantic import BaseModel, Field
+from typing_extensions import Annotated
+
+from llama_stack.apis.common.type_system import ParamType
+
+
+class Parameter(BaseModel):
+    name: str
+    type: ParamType
+    description: Optional[str] = None
+
+
+# Perhaps more structure can be imposed on these functions. Maybe they could be associated
+# with standard metrics so they can be rolled up?
+
+
+class CommonDef(BaseModel):
+    name: str
+    description: Optional[str] = None
+    metadata: Dict[str, Any] = Field(
+        default_factory=dict,
+        description="Any additional metadata for this definition",
+    )
+
+
+class DeterministicFunctionDef(CommonDef):
+    type: Literal["deterministic"] = "deterministic"
+    parameters: List[Parameter] = Field(
+        description="List of parameters for the deterministic function",
+    )
+    return_type: ParamType = Field(
+        description="The return type of the deterministic function",
+    )
+    # We can optionally add information here to support packaging of code, etc.
+
+
+class LLMJudgeFunctionDef(CommonDef):
+    type: Literal["judge"] = "judge"
+    model: str = Field(
+        description="The LLM model to use for the judge function",
+    )
+
+
+ScoringFunctionDef = Annotated[
+    Union[DeterministicFunctionDef, LLMJudgeFunctionDef], Field(discriminator="type")
+]
+
+
+@json_schema_type
+class ScoringFunctionDefWithProvider(ScoringFunctionDef):
+    provider_id: str = Field(
+        description="The provider ID for this scoring function",
+    )
+
+
+@runtime_checkable
+class ScoringFunctions(Protocol):
+    @webmethod(route="/scoring_functions/list", method="GET")
+    async def list_scoring_functions(self) -> List[ScoringFunctionDefWithProvider]: ...
+
+    @webmethod(route="/scoring_functions/get", method="GET")
+    async def get_scoring_function(
+        self, name: str
+    ) -> Optional[ScoringFunctionDefWithProvider]: ...
+
+    @webmethod(route="/scoring_functions/register", method="POST")
+    async def register_scoring_function(
+        self, function: ScoringFunctionDefWithProvider
+    ) -> None: ...
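To make the two definition variants concrete, here is a hedged sketch of constructing one of each. The function names and judge model are placeholders, and `NumberType` is assumed to be one of the `ParamType` variants in `llama_stack.apis.common.type_system`; none of this is part of the patch itself:

    # Sketch only: all names below are illustrative. NumberType is assumed
    # to exist as a ParamType in llama_stack.apis.common.type_system.
    exact_match = DeterministicFunctionDef(
        name="exact_match",
        description="Returns 1.0 when the generation equals the expected answer",
        parameters=[],
        return_type=NumberType(),
    )

    helpfulness_judge = LLMJudgeFunctionDef(
        name="helpfulness_judge",
        description="Uses an LLM to grade the helpfulness of a generation",
        model="Llama3.1-70B-Instruct",
    )

    # Registration and discovery then go through the ScoringFunctions protocol,
    # e.g. register_scoring_function(...) and list_scoring_functions() on a
    # hypothetical registry implementation.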