[Evals API] [1/n] Initial API (#287)

* type system api

* datasets api

* fix

* datasetio api

* kill reward scoring

* scoring functions + evals

* move jobs, fix errors
This commit is contained in:
Xi Yan 2024-10-22 09:31:19 -07:00 committed by GitHub
parent b279d3bc58
commit e45f121c77
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 397 additions and 243 deletions

View file

@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
@json_schema_type
class Job(BaseModel):
    """Handle for a job; it carries nothing but the job's string identifier."""

    # Required string identifier of the job.
    job_id: str

View file

@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict, List, Literal, Union
from pydantic import BaseModel, Field
from typing_extensions import Annotated
class StringType(BaseModel):
    """Type descriptor whose ``type`` tag is fixed to the literal "string"."""

    # Literal tag; ParamType uses this field as its union discriminator.
    type: Literal["string"] = "string"
class NumberType(BaseModel):
    """Type descriptor whose ``type`` tag is fixed to the literal "number"."""

    # Literal tag; ParamType uses this field as its union discriminator.
    type: Literal["number"] = "number"
class BooleanType(BaseModel):
    """Type descriptor whose ``type`` tag is fixed to the literal "boolean"."""

    # Literal tag; ParamType uses this field as its union discriminator.
    type: Literal["boolean"] = "boolean"
class ArrayType(BaseModel):
    """Type descriptor tagged "array" that nests another ParamType in ``items``."""

    # Literal tag; ParamType uses this field as its union discriminator.
    type: Literal["array"] = "array"
    # Required. Spelled as a string forward reference because ParamType is only
    # defined further down the module; ArrayType.model_rebuild() is called
    # once ParamType exists.
    items: "ParamType"
class ObjectType(BaseModel):
    """Type descriptor tagged "object" holding a string-keyed map of ParamTypes."""

    # Literal tag; ParamType uses this field as its union discriminator.
    type: Literal["object"] = "object"
    # Defaults to a fresh empty dict per instance (default_factory). The value
    # type is a forward reference to ParamType, which is defined later in the
    # module; ObjectType.model_rebuild() is called once it exists.
    properties: Dict[str, "ParamType"] = Field(default_factory=dict)
class JsonType(BaseModel):
    """Type descriptor whose ``type`` tag is fixed to the literal "json"."""

    # Literal tag; ParamType uses this field as its union discriminator.
    type: Literal["json"] = "json"
class UnionType(BaseModel):
    """Type descriptor tagged "union" holding a list of alternative ParamTypes."""

    # Literal tag; ParamType uses this field as its union discriminator.
    type: Literal["union"] = "union"
    # Defaults to a fresh empty list per instance (default_factory). The element
    # type is a forward reference to ParamType, which is defined later in the
    # module; UnionType.model_rebuild() is called once it exists.
    options: List["ParamType"] = Field(default_factory=list)
class CustomType(BaseModel):
    """Type descriptor tagged "custom" that carries a ``validator_class`` string."""

    # Literal tag; ParamType uses this field as its union discriminator.
    type: Literal["custom"] = "custom"
    # Required string. NOTE(review): presumably the name of a class used to
    # validate values of this type; how it is resolved is not visible here.
    validator_class: str
# Discriminated union of every type descriptor defined above. The ``type``
# field is declared as the discriminator, and each member pins it to a
# distinct Literal value.
ParamType = Annotated[
    Union[
        StringType,
        NumberType,
        BooleanType,
        ArrayType,
        ObjectType,
        JsonType,
        UnionType,
        CustomType,
    ],
    Field(discriminator="type"),
]
# ArrayType, ObjectType and UnionType annotate fields with the string
# "ParamType" because that alias did not exist yet when they were declared.
# Rebuild each of them now that ParamType is defined, in the same order as
# before, and drop the loop variable so no extra module-level name remains.
for _recursive_model in (ArrayType, ObjectType, UnionType):
    _recursive_model.model_rebuild()
del _recursive_model

View file

@ -1,63 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, Optional, Protocol
from llama_models.llama3.api.datatypes import URL
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
@json_schema_type
class TrainEvalDatasetColumnType(Enum):
    """Column kinds that the ``columns`` mapping of a TrainEvalDataset may use."""

    # Each member's value is the same string as its name.
    dialog = "dialog"
    text = "text"
    media = "media"
    number = "number"
    json = "json"
@json_schema_type
class TrainEvalDataset(BaseModel):
    """Dataset to be used for training or evaluating language models."""

    # TODO(ashwin): figure out if we need to add an enum for a "dataset type"
    # Required mapping from a string key (presumably the column name; confirm)
    # to that column's kind.
    columns: Dict[str, TrainEvalDatasetColumnType]
    # Required URL. NOTE(review): presumably where the dataset content is
    # fetched from; nothing visible here dereferences it.
    content_url: URL
    # Optional free-form mapping; None when omitted.
    metadata: Optional[Dict[str, Any]] = None
@json_schema_type
class CreateDatasetRequest(BaseModel):
    """Request to create a dataset."""

    # The two fields have the same names and types as the parameters of
    # Datasets.create_dataset.
    uuid: str
    dataset: TrainEvalDataset
class Datasets(Protocol):
    """Protocol declaring the create / get / delete dataset endpoints.

    All three methods are synchronous stubs; none of the routes names an
    explicit HTTP method.
    """

    # Takes a uuid string and a TrainEvalDataset; annotated to return None.
    @webmethod(route="/datasets/create")
    def create_dataset(
        self,
        uuid: str,
        dataset: TrainEvalDataset,
    ) -> None: ...

    # Takes a uuid string; annotated to return a TrainEvalDataset.
    @webmethod(route="/datasets/get")
    def get_dataset(
        self,
        dataset_uuid: str,
    ) -> TrainEvalDataset: ...

    # Takes a uuid string; annotated to return None.
    @webmethod(route="/datasets/delete")
    def delete_dataset(
        self,
        dataset_uuid: str,
    ) -> None: ...

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .reward_scoring import * # noqa: F401 F403
from .datasetio import * # noqa: F401 F403

View file

@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_stack.apis.datasets import * # noqa: F403
@json_schema_type
class PaginatedRowsResult(BaseModel):
    """One page of dataset rows, as returned by DatasetIO.get_rows_paginated."""

    # the rows obey the DatasetSchema for the given dataset
    rows: List[Dict[str, Any]]
    # Required integer. NOTE(review): presumably the number of rows in the
    # whole dataset rather than in this page; confirm with a provider.
    total_count: int
    # Optional string, None by default. NOTE(review): presumably the value to
    # send back as ``page_token`` to fetch the following page; confirm.
    next_page_token: Optional[str] = None
class DatasetStore(Protocol):
    """Structural interface for looking up a dataset definition by identifier."""

    # Synchronous stub; annotated to return a DatasetDefWithProvider.
    def get_dataset(self, identifier: str) -> DatasetDefWithProvider: ...
@runtime_checkable
class DatasetIO(Protocol):
    """Runtime-checkable protocol declaring paginated row access to a dataset."""

    # keeping for aligning with inference/safety, but this is not used
    dataset_store: DatasetStore

    # Async stub. ``dataset_id`` and ``rows_in_page`` are required;
    # ``page_token`` and ``filter_condition`` default to None. Annotated to
    # return a PaginatedRowsResult.
    # NOTE(review): the route is spelled "/dataio/..." while this API is
    # exported from a module named ``datasetio``; confirm which spelling is
    # intended before clients depend on it.
    @webmethod(route="/dataio/get_rows_paginated")
    async def get_rows_paginated(
        self,
        dataset_id: str,
        rows_in_page: int,
        page_token: Optional[str] = None,
        filter_condition: Optional[str] = None,
    ) -> PaginatedRowsResult: ...

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datasets import * # noqa: F401 F403

View file

@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional, Protocol
from llama_models.llama3.api.datatypes import URL
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.common.type_system import ParamType
@json_schema_type
class DatasetDef(BaseModel):
    """Definition of a dataset: identifier, column schema, URL and metadata."""

    # Required; no default is supplied.
    identifier: str = Field(
        description="A unique name for the dataset",
    )
    # Required mapping from a string key to a ParamType type descriptor.
    columns_schema: Dict[str, ParamType] = Field(
        description="The schema definition for this dataset",
    )
    # Required URL; declared without a Field description.
    url: URL
    # Optional; each instance gets its own empty dict when omitted.
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="Any additional metadata for this dataset",
    )
@json_schema_type
class DatasetDefWithProvider(DatasetDef):
    """A DatasetDef extended with the id of the provider that serves it."""

    # Required; no default is supplied.
    provider_id: str = Field(
        description="ID of the provider which serves this dataset",
    )
class Datasets(Protocol):
    """Protocol declaring the register / get / delete / list dataset endpoints.

    All four methods are async stubs.
    """

    # Takes a full DatasetDefWithProvider; annotated to return None.
    @webmethod(route="/datasets/register", method="POST")
    async def register_dataset(
        self,
        dataset_def: DatasetDefWithProvider,
    ) -> None: ...

    # Annotated as Optional, so None is an allowed result.
    @webmethod(route="/datasets/get", method="GET")
    async def get_dataset(
        self,
        dataset_identifier: str,
    ) -> Optional[DatasetDefWithProvider]: ...

    # NOTE(review): this is the only endpoint here without an explicit
    # ``method=``, so whatever webmethod uses by default applies; confirm
    # that is intended.
    @webmethod(route="/datasets/delete")
    async def delete_dataset(
        self,
        dataset_identifier: str,
    ) -> None: ...

    # Takes no arguments; annotated to return a list of definitions.
    @webmethod(route="/datasets/list", method="GET")
    async def list_datasets(self) -> List[DatasetDefWithProvider]: ...

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .evals import * # noqa: F401 F403
from .eval import * # noqa: F401 F403

View file

@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Literal, Optional, Protocol, Union
from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.schema_utils import json_schema_type, webmethod
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.scoring import * # noqa: F403
@json_schema_type
class ModelCandidate(BaseModel):
    """Evaluation candidate tagged "model": a model string plus sampling params."""

    # Literal tag; EvalCandidate uses this field as its union discriminator.
    type: Literal["model"] = "model"
    # Required string. NOTE(review): presumably a model identifier; how it is
    # resolved is not visible here.
    model: str
    # Required; no default is supplied.
    sampling_params: SamplingParams
    # Optional; None when omitted.
    system_message: Optional[SystemMessage] = None
@json_schema_type
class AgentCandidate(BaseModel):
    """Evaluation candidate tagged "agent", described by an AgentConfig."""

    # Literal tag; EvalCandidate uses this field as its union discriminator.
    type: Literal["agent"] = "agent"
    # Required; no default is supplied.
    config: AgentConfig
# Discriminated union of the two candidate models; the ``type`` field
# ("model" or "agent") is declared as the discriminator.
EvalCandidate = Annotated[
    Union[ModelCandidate, AgentCandidate], Field(discriminator="type")
]
@json_schema_type
class EvaluateResponse(BaseModel):
    """Result of Eval.evaluate: a list of generations and a list of scores."""

    # Required list of free-form string-keyed dicts.
    generations: List[Dict[str, Any]]
    # each key in the dict is a scoring function name
    scores: List[Dict[str, ScoringResult]]
class Eval(Protocol):
    """Protocol declaring the evaluation endpoints and their job endpoints.

    All methods are async stubs.
    """

    # Takes a dataset id, a candidate and a list of scoring-function names;
    # annotated to return a Job.
    @webmethod(route="/eval/evaluate_batch", method="POST")
    async def evaluate_batch(
        self,
        dataset_id: str,
        candidate: EvalCandidate,
        scoring_functions: List[str],
    ) -> Job: ...

    # Same as evaluate_batch except that rows are passed inline instead of a
    # dataset id, and the annotation is EvaluateResponse rather than Job.
    @webmethod(route="/eval/evaluate", method="POST")
    async def evaluate(
        self,
        input_rows: List[Dict[str, Any]],
        candidate: EvalCandidate,
        scoring_functions: List[str],
    ) -> EvaluateResponse: ...

    # NOTE(review): all three job_* stubs below are annotated ``-> None``,
    # including the two GET endpoints (status and result); confirm whether
    # real return types are still to be added.
    @webmethod(route="/eval/job/status", method="GET")
    async def job_status(self, job_id: str) -> None: ...

    @webmethod(route="/eval/job/cancel", method="POST")
    async def job_cancel(self, job_id: str) -> None: ...

    @webmethod(route="/eval/job/result", method="GET")
    async def job_result(self, job_id: str) -> None: ...

View file

@ -1,122 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import List, Protocol
from llama_models.schema_utils import webmethod
from pydantic import BaseModel
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.dataset import * # noqa: F403
from llama_stack.apis.common.training_types import * # noqa: F403
class TextGenerationMetric(Enum):
    """Metric names accepted by Evaluations.evaluate_text_generation."""

    # Each member's value is the same string as its name.
    perplexity = "perplexity"
    rouge = "rouge"
    bleu = "bleu"
class QuestionAnsweringMetric(Enum):
    """Metric names accepted by Evaluations.evaluate_question_answering."""

    # Each member's value is the same string as its name.
    em = "em"
    f1 = "f1"
class SummarizationMetric(Enum):
    """Metric names accepted by Evaluations.evaluate_summarization."""

    # Each member's value is the same string as its name.
    rouge = "rouge"
    bleu = "bleu"
class EvaluationJob(BaseModel):
    """Job handle returned by the evaluate_* methods of Evaluations."""

    # Required string identifier of the job.
    job_uuid: str
class EvaluationJobLogStream(BaseModel):
    """Return type of Evaluations.get_evaluation_job_logstream."""

    # Required string identifier of the job.
    job_uuid: str
class EvaluateTaskRequestCommon(BaseModel):
    """Fields shared by the three Evaluate*Request models that subclass this."""

    # All three of the following are required.
    job_uuid: str
    dataset: TrainEvalDataset
    checkpoint: Checkpoint
    # generation params
    # Optional; the default is a SamplingParams built with no arguments.
    sampling_params: SamplingParams = SamplingParams()
# NOTE(review): ``json_schema_type`` is not among the names this file imports
# explicitly (only ``webmethod`` is); presumably one of the star imports
# provides it. Confirm.
@json_schema_type
class EvaluateTextGenerationRequest(EvaluateTaskRequestCommon):
    """Request to evaluate text generation."""

    # Required list of text-generation metrics.
    metrics: List[TextGenerationMetric]
@json_schema_type
class EvaluateQuestionAnsweringRequest(EvaluateTaskRequestCommon):
    """Request to evaluate question answering."""

    # Required list of question-answering metrics.
    metrics: List[QuestionAnsweringMetric]
@json_schema_type
class EvaluateSummarizationRequest(EvaluateTaskRequestCommon):
    """Request to evaluate summarization."""

    # Required list of summarization metrics.
    metrics: List[SummarizationMetric]
class EvaluationJobStatusResponse(BaseModel):
    """Return type of Evaluations.get_evaluation_job_status."""

    # Required string identifier of the job; no status field is declared.
    job_uuid: str
@json_schema_type
class EvaluationJobArtifactsResponse(BaseModel):
    """Artifacts of an evaluation job."""

    # Required string identifier of the job; no artifact fields are declared.
    job_uuid: str
class Evaluations(Protocol):
    """Protocol declaring the evaluation endpoints and their job endpoints.

    All methods are synchronous stubs; none of the routes names an explicit
    HTTP method.
    """

    # The three evaluate_* stubs each take only a list of metrics and are
    # annotated to return an EvaluationJob.
    @webmethod(route="/evaluate/text_generation/")
    def evaluate_text_generation(
        self,
        metrics: List[TextGenerationMetric],
    ) -> EvaluationJob: ...

    @webmethod(route="/evaluate/question_answering/")
    def evaluate_question_answering(
        self,
        metrics: List[QuestionAnsweringMetric],
    ) -> EvaluationJob: ...

    @webmethod(route="/evaluate/summarization/")
    def evaluate_summarization(
        self,
        metrics: List[SummarizationMetric],
    ) -> EvaluationJob: ...

    # Takes no arguments; annotated to return a list of EvaluationJob.
    @webmethod(route="/evaluate/jobs")
    def get_evaluation_jobs(self) -> List[EvaluationJob]: ...

    # The remaining stubs each take a single job_uuid string.
    @webmethod(route="/evaluate/job/status")
    def get_evaluation_job_status(
        self, job_uuid: str
    ) -> EvaluationJobStatusResponse: ...

    # sends SSE stream of logs
    @webmethod(route="/evaluate/job/logs")
    def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream: ...

    @webmethod(route="/evaluate/job/cancel")
    def cancel_evaluation_job(self, job_uuid: str) -> None: ...

    @webmethod(route="/evaluate/job/artifacts")
    def get_evaluation_job_artifacts(
        self, job_uuid: str
    ) -> EvaluationJobArtifactsResponse: ...

View file

@ -1,55 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_models.llama3.api.datatypes import * # noqa: F403
@json_schema_type
class ScoredMessage(BaseModel):
    """A Message paired with a float score."""

    # Both fields are required.
    message: Message
    score: float
@json_schema_type
class DialogGenerations(BaseModel):
    """A dialog (list of Messages) together with a list of sampled generations."""

    # Both fields are required lists of Message.
    dialog: List[Message]
    sampled_generations: List[Message]
@json_schema_type
class ScoredDialogGenerations(BaseModel):
    """A dialog (list of Messages) together with a list of scored generations."""

    # Both fields are required.
    dialog: List[Message]
    scored_generations: List[ScoredMessage]
@json_schema_type
class RewardScoringRequest(BaseModel):
    """Request to score a reward function. A list of prompts and a list of responses per prompt."""

    # The two fields have the same names and types as the parameters of
    # RewardScoring.reward_score.
    dialog_generations: List[DialogGenerations]
    model: str
@json_schema_type
class RewardScoringResponse(BaseModel):
    """Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold."""

    # Required list; no threshold field is declared on this model.
    scored_generations: List[ScoredDialogGenerations]
class RewardScoring(Protocol):
    """Protocol declaring the single reward-scoring endpoint."""

    # Synchronous stub. Its two parameters have the same names and types as
    # the fields of RewardScoringRequest, and it is annotated to return a
    # RewardScoringResponse. The annotation used to be spelled
    # Union[RewardScoringResponse]; a one-member Union evaluates to that
    # member, so it is written directly.
    @webmethod(route="/reward_scoring/score")
    def reward_score(
        self,
        dialog_generations: List[DialogGenerations],
        model: str,
    ) -> RewardScoringResponse: ...

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .dataset import * # noqa: F401 F403
from .scoring import * # noqa: F401 F403

View file

@ -0,0 +1,46 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
# Plain alias for a free-form string-keyed dict; no structure is imposed here.
ScoringResult = Dict[str, Any]
@json_schema_type
class ScoreBatchResponse(BaseModel):
    """Return type of Scoring.score_batch; it carries only a dataset id."""

    # Required string. NOTE(review): whether this names the input dataset or
    # a newly written one is not visible here.
    dataset_id: str
@json_schema_type
class ScoreResponse(BaseModel):
    """Return type of Scoring.score: a list of per-function result dicts."""

    # each key in the dict is a scoring function name
    results: List[Dict[str, ScoringResult]]
class ScoringFunctionStore(Protocol):
    """Structural interface for looking up a scoring function definition by name."""

    # Synchronous stub; annotated to return a ScoringFunctionDefWithProvider.
    def get_scoring_function(self, name: str) -> ScoringFunctionDefWithProvider: ...
@runtime_checkable
class Scoring(Protocol):
    """Runtime-checkable protocol declaring the two scoring endpoints.

    Both methods are async stubs, and neither route names an explicit HTTP
    method.
    """

    # Attribute declared by the protocol; nothing in this block reads it.
    scoring_function_store: ScoringFunctionStore

    # Takes a dataset id and a list of scoring-function names; annotated to
    # return a ScoreBatchResponse.
    @webmethod(route="/scoring/score_batch")
    async def score_batch(
        self, dataset_id: str, scoring_functions: List[str]
    ) -> ScoreBatchResponse: ...

    # Same as score_batch except that rows are passed inline instead of a
    # dataset id; annotated to return a ScoreResponse.
    @webmethod(route="/scoring/score")
    async def score(
        self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
    ) -> ScoreResponse: ...

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .scoring_functions import * # noqa: F401 F403

View file

@ -0,0 +1,88 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Protocol,
runtime_checkable,
Union,
)
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.common.type_system import ParamType
@json_schema_type
class Parameter(BaseModel):
    """A named, typed parameter, as listed by DeterministicFunctionDef.parameters."""

    # Required string name.
    name: str
    # Required type descriptor drawn from the ParamType union.
    type: ParamType
    # Optional; None when omitted.
    description: Optional[str] = None
# Perhaps more structure can be imposed on these functions. Maybe they could be associated
# with standard metrics so they can be rolled up?
@json_schema_type
class CommonDef(BaseModel):
    """Fields shared by both scoring-function definition models."""

    # Required string name.
    name: str
    # Optional; None when omitted.
    description: Optional[str] = None
    # Optional; each instance gets its own empty dict when omitted.
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="Any additional metadata for this definition",
    )
    # Hack: same with memory_banks for union defs
    # Defaults to the empty string rather than being required.
    provider_id: str = ""
@json_schema_type
class DeterministicFunctionDef(CommonDef):
    """Scoring-function definition tagged "deterministic"."""

    # Literal tag; ScoringFunctionDef uses this field as its discriminator.
    type: Literal["deterministic"] = "deterministic"
    # Required; no default is supplied.
    parameters: List[Parameter] = Field(
        description="List of parameters for the deterministic function",
    )
    # Required; no default is supplied.
    return_type: ParamType = Field(
        description="The return type of the deterministic function",
    )
# We can optionally add information here to support packaging of code, etc.
@json_schema_type
class LLMJudgeFunctionDef(CommonDef):
    """Scoring-function definition tagged "judge"."""

    # Literal tag; ScoringFunctionDef uses this field as its discriminator.
    type: Literal["judge"] = "judge"
    # Required; no default is supplied.
    model: str = Field(
        description="The LLM model to use for the judge function",
    )
# Discriminated union of the two definition models; the ``type`` field
# ("deterministic" or "judge") is declared as the discriminator.
ScoringFunctionDef = Annotated[
    Union[DeterministicFunctionDef, LLMJudgeFunctionDef], Field(discriminator="type")
]
# Plain alias, not a subclass: CommonDef already declares ``provider_id``,
# so no separate "with provider" model is defined.
ScoringFunctionDefWithProvider = ScoringFunctionDef
@runtime_checkable
class ScoringFunctions(Protocol):
    """Runtime-checkable protocol declaring list / get / register endpoints.

    All three methods are async stubs.
    """

    # Takes no arguments; annotated to return a list of definitions.
    @webmethod(route="/scoring_functions/list", method="GET")
    async def list_scoring_functions(self) -> List[ScoringFunctionDefWithProvider]: ...

    # Annotated as Optional, so None is an allowed result.
    @webmethod(route="/scoring_functions/get", method="GET")
    async def get_scoring_function(
        self, name: str
    ) -> Optional[ScoringFunctionDefWithProvider]: ...

    # Takes a full definition; annotated to return None.
    @webmethod(route="/scoring_functions/register", method="POST")
    async def register_scoring_function(
        self, function: ScoringFunctionDefWithProvider
    ) -> None: ...