From e45f121c77d709c313910011d23e7ce4fd4c24b2 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Tue, 22 Oct 2024 09:31:19 -0700 Subject: [PATCH] [Evals API] [1/n] Initial API (#287) * type system api * datasets api * fix * datasetio api * kill reward scoring * scoring functions + evals * move jobs, fix errors --- llama_stack/apis/common/job_types.py | 12 ++ llama_stack/apis/common/type_system.py | 65 ++++++++++ llama_stack/apis/dataset/dataset.py | 63 --------- .../{reward_scoring => datasetio}/__init__.py | 2 +- llama_stack/apis/datasetio/datasetio.py | 39 ++++++ llama_stack/apis/datasets/__init__.py | 7 + llama_stack/apis/datasets/datasets.py | 60 +++++++++ llama_stack/apis/{evals => eval}/__init__.py | 2 +- llama_stack/apis/eval/eval.py | 70 ++++++++++ llama_stack/apis/evals/evals.py | 122 ------------------ .../apis/reward_scoring/reward_scoring.py | 55 -------- .../apis/{dataset => scoring}/__init__.py | 2 +- llama_stack/apis/scoring/scoring.py | 46 +++++++ .../apis/scoring_functions/__init__.py | 7 + .../scoring_functions/scoring_functions.py | 88 +++++++++++++ 15 files changed, 397 insertions(+), 243 deletions(-) create mode 100644 llama_stack/apis/common/job_types.py create mode 100644 llama_stack/apis/common/type_system.py delete mode 100644 llama_stack/apis/dataset/dataset.py rename llama_stack/apis/{reward_scoring => datasetio}/__init__.py (80%) create mode 100644 llama_stack/apis/datasetio/datasetio.py create mode 100644 llama_stack/apis/datasets/__init__.py create mode 100644 llama_stack/apis/datasets/datasets.py rename llama_stack/apis/{evals => eval}/__init__.py (83%) create mode 100644 llama_stack/apis/eval/eval.py delete mode 100644 llama_stack/apis/evals/evals.py delete mode 100644 llama_stack/apis/reward_scoring/reward_scoring.py rename llama_stack/apis/{dataset => scoring}/__init__.py (82%) create mode 100644 llama_stack/apis/scoring/scoring.py create mode 100644 llama_stack/apis/scoring_functions/__init__.py create mode 100644 llama_stack/apis/scoring_functions/scoring_functions.py diff --git a/llama_stack/apis/common/job_types.py b/llama_stack/apis/common/job_types.py new file mode 100644 index 000000000..ab203ebb8 --- /dev/null +++ b/llama_stack/apis/common/job_types.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from llama_models.schema_utils import json_schema_type +from pydantic import BaseModel + + +@json_schema_type +class Job(BaseModel): + job_id: str diff --git a/llama_stack/apis/common/type_system.py b/llama_stack/apis/common/type_system.py new file mode 100644 index 000000000..35a26e9ef --- /dev/null +++ b/llama_stack/apis/common/type_system.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Dict, List, Literal, Union + +from pydantic import BaseModel, Field +from typing_extensions import Annotated + + +class StringType(BaseModel): + type: Literal["string"] = "string" + + +class NumberType(BaseModel): + type: Literal["number"] = "number" + + +class BooleanType(BaseModel): + type: Literal["boolean"] = "boolean" + + +class ArrayType(BaseModel): + type: Literal["array"] = "array" + items: "ParamType" + + +class ObjectType(BaseModel): + type: Literal["object"] = "object" + properties: Dict[str, "ParamType"] = Field(default_factory=dict) + + +class JsonType(BaseModel): + type: Literal["json"] = "json" + + +class UnionType(BaseModel): + type: Literal["union"] = "union" + options: List["ParamType"] = Field(default_factory=list) + + +class CustomType(BaseModel): + type: Literal["custom"] = "custom" + validator_class: str + + +ParamType = Annotated[ + Union[ + StringType, + NumberType, + BooleanType, + ArrayType, + ObjectType, + JsonType, + UnionType, + CustomType, + ], + Field(discriminator="type"), +] + +ArrayType.model_rebuild() +ObjectType.model_rebuild() +UnionType.model_rebuild() diff --git a/llama_stack/apis/dataset/dataset.py b/llama_stack/apis/dataset/dataset.py deleted file mode 100644 index 2fa8bb4e5..000000000 --- a/llama_stack/apis/dataset/dataset.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from enum import Enum -from typing import Any, Dict, Optional, Protocol - -from llama_models.llama3.api.datatypes import URL - -from llama_models.schema_utils import json_schema_type, webmethod - -from pydantic import BaseModel - - -@json_schema_type -class TrainEvalDatasetColumnType(Enum): - dialog = "dialog" - text = "text" - media = "media" - number = "number" - json = "json" - - -@json_schema_type -class TrainEvalDataset(BaseModel): - """Dataset to be used for training or evaluating language models.""" - - # TODO(ashwin): figure out if we need to add an enum for a "dataset type" - - columns: Dict[str, TrainEvalDatasetColumnType] - content_url: URL - metadata: Optional[Dict[str, Any]] = None - - -@json_schema_type -class CreateDatasetRequest(BaseModel): - """Request to create a dataset.""" - - uuid: str - dataset: TrainEvalDataset - - -class Datasets(Protocol): - @webmethod(route="/datasets/create") - def create_dataset( - self, - uuid: str, - dataset: TrainEvalDataset, - ) -> None: ... - - @webmethod(route="/datasets/get") - def get_dataset( - self, - dataset_uuid: str, - ) -> TrainEvalDataset: ... - - @webmethod(route="/datasets/delete") - def delete_dataset( - self, - dataset_uuid: str, - ) -> None: ... diff --git a/llama_stack/apis/reward_scoring/__init__.py b/llama_stack/apis/datasetio/__init__.py similarity index 80% rename from llama_stack/apis/reward_scoring/__init__.py rename to llama_stack/apis/datasetio/__init__.py index 7ea62c241..378afbba8 100644 --- a/llama_stack/apis/reward_scoring/__init__.py +++ b/llama_stack/apis/datasetio/__init__.py @@ -4,4 +4,4 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .reward_scoring import * # noqa: F401 F403 +from .datasetio import * # noqa: F401 F403 diff --git a/llama_stack/apis/datasetio/datasetio.py b/llama_stack/apis/datasetio/datasetio.py new file mode 100644 index 000000000..e8811d233 --- /dev/null +++ b/llama_stack/apis/datasetio/datasetio.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict, List, Optional, Protocol, runtime_checkable + +from llama_models.schema_utils import json_schema_type, webmethod +from pydantic import BaseModel + +from llama_stack.apis.datasets import * # noqa: F403 + + +@json_schema_type +class PaginatedRowsResult(BaseModel): + # the rows obey the DatasetSchema for the given dataset + rows: List[Dict[str, Any]] + total_count: int + next_page_token: Optional[str] = None + + +class DatasetStore(Protocol): + def get_dataset(self, identifier: str) -> DatasetDefWithProvider: ... + + +@runtime_checkable +class DatasetIO(Protocol): + # keeping for aligning with inference/safety, but this is not used + dataset_store: DatasetStore + + @webmethod(route="/dataio/get_rows_paginated") + async def get_rows_paginated( + self, + dataset_id: str, + rows_in_page: int, + page_token: Optional[str] = None, + filter_condition: Optional[str] = None, + ) -> PaginatedRowsResult: ... diff --git a/llama_stack/apis/datasets/__init__.py b/llama_stack/apis/datasets/__init__.py new file mode 100644 index 000000000..102b9927f --- /dev/null +++ b/llama_stack/apis/datasets/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .datasets import * # noqa: F401 F403 diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py new file mode 100644 index 000000000..9160e1e13 --- /dev/null +++ b/llama_stack/apis/datasets/datasets.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict, List, Optional, Protocol + +from llama_models.llama3.api.datatypes import URL + +from llama_models.schema_utils import json_schema_type, webmethod + +from pydantic import BaseModel, Field + +from llama_stack.apis.common.type_system import ParamType + + +@json_schema_type +class DatasetDef(BaseModel): + identifier: str = Field( + description="A unique name for the dataset", + ) + columns_schema: Dict[str, ParamType] = Field( + description="The schema definition for this dataset", + ) + url: URL + metadata: Dict[str, Any] = Field( + default_factory=dict, + description="Any additional metadata for this dataset", + ) + + +@json_schema_type +class DatasetDefWithProvider(DatasetDef): + provider_id: str = Field( + description="ID of the provider which serves this dataset", + ) + + +class Datasets(Protocol): + @webmethod(route="/datasets/register", method="POST") + async def register_dataset( + self, + dataset_def: DatasetDefWithProvider, + ) -> None: ... + + @webmethod(route="/datasets/get", method="GET") + async def get_dataset( + self, + dataset_identifier: str, + ) -> Optional[DatasetDefWithProvider]: ... + + @webmethod(route="/datasets/delete") + async def delete_dataset( + self, + dataset_identifier: str, + ) -> None: ... + + @webmethod(route="/datasets/list", method="GET") + async def list_datasets(self) -> List[DatasetDefWithProvider]: ... diff --git a/llama_stack/apis/evals/__init__.py b/llama_stack/apis/eval/__init__.py similarity index 83% rename from llama_stack/apis/evals/__init__.py rename to llama_stack/apis/eval/__init__.py index d21b97d0a..5f91ad70d 100644 --- a/llama_stack/apis/evals/__init__.py +++ b/llama_stack/apis/eval/__init__.py @@ -4,4 +4,4 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .evals import * # noqa: F401 F403 +from .eval import * # noqa: F401 F403 diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py new file mode 100644 index 000000000..a97af1fc0 --- /dev/null +++ b/llama_stack/apis/eval/eval.py @@ -0,0 +1,70 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Literal, Optional, Protocol, Union + +from typing_extensions import Annotated + +from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_models.schema_utils import json_schema_type, webmethod +from llama_stack.apis.scoring_functions import * # noqa: F403 +from llama_stack.apis.agents import AgentConfig +from llama_stack.apis.common.job_types import Job +from llama_stack.apis.scoring import * # noqa: F403 + + +@json_schema_type +class ModelCandidate(BaseModel): + type: Literal["model"] = "model" + model: str + sampling_params: SamplingParams + system_message: Optional[SystemMessage] = None + + +@json_schema_type +class AgentCandidate(BaseModel): + type: Literal["agent"] = "agent" + config: AgentConfig + + +EvalCandidate = Annotated[ + Union[ModelCandidate, AgentCandidate], Field(discriminator="type") +] + + +@json_schema_type +class EvaluateResponse(BaseModel): + generations: List[Dict[str, Any]] + + # each key in the dict is a scoring function name + scores: List[Dict[str, ScoringResult]] + + +class Eval(Protocol): + @webmethod(route="/eval/evaluate_batch", method="POST") + async def evaluate_batch( + self, + dataset_id: str, + candidate: EvalCandidate, + scoring_functions: List[str], + ) -> Job: ... + + @webmethod(route="/eval/evaluate", method="POST") + async def evaluate( + self, + input_rows: List[Dict[str, Any]], + candidate: EvalCandidate, + scoring_functions: List[str], + ) -> EvaluateResponse: ... + + @webmethod(route="/eval/job/status", method="GET") + async def job_status(self, job_id: str) -> None: ... + + @webmethod(route="/eval/job/cancel", method="POST") + async def job_cancel(self, job_id: str) -> None: ... + + @webmethod(route="/eval/job/result", method="GET") + async def job_result(self, job_id: str) -> None: ... diff --git a/llama_stack/apis/evals/evals.py b/llama_stack/apis/evals/evals.py deleted file mode 100644 index 0be2243ab..000000000 --- a/llama_stack/apis/evals/evals.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from enum import Enum -from typing import List, Protocol - -from llama_models.schema_utils import webmethod - -from pydantic import BaseModel - -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.dataset import * # noqa: F403 -from llama_stack.apis.common.training_types import * # noqa: F403 - - -class TextGenerationMetric(Enum): - perplexity = "perplexity" - rouge = "rouge" - bleu = "bleu" - - -class QuestionAnsweringMetric(Enum): - em = "em" - f1 = "f1" - - -class SummarizationMetric(Enum): - rouge = "rouge" - bleu = "bleu" - - -class EvaluationJob(BaseModel): - job_uuid: str - - -class EvaluationJobLogStream(BaseModel): - job_uuid: str - - -class EvaluateTaskRequestCommon(BaseModel): - job_uuid: str - dataset: TrainEvalDataset - - checkpoint: Checkpoint - - # generation params - sampling_params: SamplingParams = SamplingParams() - - -@json_schema_type -class EvaluateTextGenerationRequest(EvaluateTaskRequestCommon): - """Request to evaluate text generation.""" - - metrics: List[TextGenerationMetric] - - -@json_schema_type -class EvaluateQuestionAnsweringRequest(EvaluateTaskRequestCommon): - """Request to evaluate question answering.""" - - metrics: List[QuestionAnsweringMetric] - - -@json_schema_type -class EvaluateSummarizationRequest(EvaluateTaskRequestCommon): - """Request to evaluate summarization.""" - - metrics: List[SummarizationMetric] - - -class EvaluationJobStatusResponse(BaseModel): - job_uuid: str - - -@json_schema_type -class EvaluationJobArtifactsResponse(BaseModel): - """Artifacts of a evaluation job.""" - - job_uuid: str - - -class Evaluations(Protocol): - @webmethod(route="/evaluate/text_generation/") - def evaluate_text_generation( - self, - metrics: List[TextGenerationMetric], - ) -> EvaluationJob: ... - - @webmethod(route="/evaluate/question_answering/") - def evaluate_question_answering( - self, - metrics: List[QuestionAnsweringMetric], - ) -> EvaluationJob: ... - - @webmethod(route="/evaluate/summarization/") - def evaluate_summarization( - self, - metrics: List[SummarizationMetric], - ) -> EvaluationJob: ... - - @webmethod(route="/evaluate/jobs") - def get_evaluation_jobs(self) -> List[EvaluationJob]: ... - - @webmethod(route="/evaluate/job/status") - def get_evaluation_job_status( - self, job_uuid: str - ) -> EvaluationJobStatusResponse: ... - - # sends SSE stream of logs - @webmethod(route="/evaluate/job/logs") - def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream: ... - - @webmethod(route="/evaluate/job/cancel") - def cancel_evaluation_job(self, job_uuid: str) -> None: ... - - @webmethod(route="/evaluate/job/artifacts") - def get_evaluation_job_artifacts( - self, job_uuid: str - ) -> EvaluationJobArtifactsResponse: ... diff --git a/llama_stack/apis/reward_scoring/reward_scoring.py b/llama_stack/apis/reward_scoring/reward_scoring.py deleted file mode 100644 index 9d689f232..000000000 --- a/llama_stack/apis/reward_scoring/reward_scoring.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import List, Protocol, Union - -from llama_models.schema_utils import json_schema_type, webmethod - -from pydantic import BaseModel - -from llama_models.llama3.api.datatypes import * # noqa: F403 - - -@json_schema_type -class ScoredMessage(BaseModel): - message: Message - score: float - - -@json_schema_type -class DialogGenerations(BaseModel): - dialog: List[Message] - sampled_generations: List[Message] - - -@json_schema_type -class ScoredDialogGenerations(BaseModel): - dialog: List[Message] - scored_generations: List[ScoredMessage] - - -@json_schema_type -class RewardScoringRequest(BaseModel): - """Request to score a reward function. A list of prompts and a list of responses per prompt.""" - - dialog_generations: List[DialogGenerations] - model: str - - -@json_schema_type -class RewardScoringResponse(BaseModel): - """Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold.""" - - scored_generations: List[ScoredDialogGenerations] - - -class RewardScoring(Protocol): - @webmethod(route="/reward_scoring/score") - def reward_score( - self, - dialog_generations: List[DialogGenerations], - model: str, - ) -> Union[RewardScoringResponse]: ... diff --git a/llama_stack/apis/dataset/__init__.py b/llama_stack/apis/scoring/__init__.py similarity index 82% rename from llama_stack/apis/dataset/__init__.py rename to llama_stack/apis/scoring/__init__.py index 33557a0ab..0739dfc80 100644 --- a/llama_stack/apis/dataset/__init__.py +++ b/llama_stack/apis/scoring/__init__.py @@ -4,4 +4,4 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .dataset import * # noqa: F401 F403 +from .scoring import * # noqa: F401 F403 diff --git a/llama_stack/apis/scoring/scoring.py b/llama_stack/apis/scoring/scoring.py new file mode 100644 index 000000000..ec50ecab1 --- /dev/null +++ b/llama_stack/apis/scoring/scoring.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict, List, Protocol, runtime_checkable + +from llama_models.schema_utils import json_schema_type, webmethod +from pydantic import BaseModel + +from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_stack.apis.scoring_functions import * # noqa: F403 + + +ScoringResult = Dict[str, Any] + + +@json_schema_type +class ScoreBatchResponse(BaseModel): + dataset_id: str + + +@json_schema_type +class ScoreResponse(BaseModel): + # each key in the dict is a scoring function name + results: List[Dict[str, ScoringResult]] + + +class ScoringFunctionStore(Protocol): + def get_scoring_function(self, name: str) -> ScoringFunctionDefWithProvider: ... + + +@runtime_checkable +class Scoring(Protocol): + scoring_function_store: ScoringFunctionStore + + @webmethod(route="/scoring/score_batch") + async def score_batch( + self, dataset_id: str, scoring_functions: List[str] + ) -> ScoreBatchResponse: ... + + @webmethod(route="/scoring/score") + async def score( + self, input_rows: List[Dict[str, Any]], scoring_functions: List[str] + ) -> ScoreResponse: ... diff --git a/llama_stack/apis/scoring_functions/__init__.py b/llama_stack/apis/scoring_functions/__init__.py new file mode 100644 index 000000000..b96acb45f --- /dev/null +++ b/llama_stack/apis/scoring_functions/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .scoring_functions import * # noqa: F401 F403 diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py new file mode 100644 index 000000000..1d71c51f3 --- /dev/null +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import ( + Any, + Dict, + List, + Literal, + Optional, + Protocol, + runtime_checkable, + Union, +) + +from llama_models.schema_utils import json_schema_type, webmethod +from pydantic import BaseModel, Field +from typing_extensions import Annotated + +from llama_stack.apis.common.type_system import ParamType + + +@json_schema_type +class Parameter(BaseModel): + name: str + type: ParamType + description: Optional[str] = None + + +# Perhaps more structure can be imposed on these functions. Maybe they could be associated +# with standard metrics so they can be rolled up? + + +@json_schema_type +class CommonDef(BaseModel): + name: str + description: Optional[str] = None + metadata: Dict[str, Any] = Field( + default_factory=dict, + description="Any additional metadata for this definition", + ) + # Hack: same with memory_banks for union defs + provider_id: str = "" + + +@json_schema_type +class DeterministicFunctionDef(CommonDef): + type: Literal["deterministic"] = "deterministic" + parameters: List[Parameter] = Field( + description="List of parameters for the deterministic function", + ) + return_type: ParamType = Field( + description="The return type of the deterministic function", + ) + # We can optionally add information here to support packaging of code, etc. + + +@json_schema_type +class LLMJudgeFunctionDef(CommonDef): + type: Literal["judge"] = "judge" + model: str = Field( + description="The LLM model to use for the judge function", + ) + + +ScoringFunctionDef = Annotated[ + Union[DeterministicFunctionDef, LLMJudgeFunctionDef], Field(discriminator="type") +] + +ScoringFunctionDefWithProvider = ScoringFunctionDef + + +@runtime_checkable +class ScoringFunctions(Protocol): + @webmethod(route="/scoring_functions/list", method="GET") + async def list_scoring_functions(self) -> List[ScoringFunctionDefWithProvider]: ... + + @webmethod(route="/scoring_functions/get", method="GET") + async def get_scoring_function( + self, name: str + ) -> Optional[ScoringFunctionDefWithProvider]: ... + + @webmethod(route="/scoring_functions/register", method="POST") + async def register_scoring_function( + self, function: ScoringFunctionDefWithProvider + ) -> None: ...