scoring functions + evals

Xi Yan 2024-10-22 08:53:46 -07:00
parent 1dc2962a33
commit 5836c09c57
7 changed files with 220 additions and 123 deletions

View file

@@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from .evals import *  # noqa: F401 F403
+from .eval import *  # noqa: F401 F403

View file

@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol, Union

from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated

from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.scoring import *  # noqa: F403
from llama_stack.apis.scoring_functions import *  # noqa: F403


@json_schema_type
class ModelCandidate(BaseModel):
    type: Literal["model"] = "model"
    model: str
    sampling_params: SamplingParams
    system_message: Optional[SystemMessage] = None


@json_schema_type
class AgentCandidate(BaseModel):
    type: Literal["agent"] = "agent"
    config: AgentConfig


EvalCandidate = Annotated[
    Union[ModelCandidate, AgentCandidate], Field(discriminator="type")
]


@json_schema_type
class Job(BaseModel):
    job_id: str


@json_schema_type
class EvaluateResponse(BaseModel):
    generations: List[Dict[str, Any]]
    # each key in the dict is a scoring function name
    scores: List[Dict[str, ScoringResult]]


class Eval(Protocol):
    @webmethod(route="/eval/evaluate_batch", method="POST")
    async def evaluate_batch(
        self,
        dataset_id: str,
        candidate: EvalCandidate,
        scoring_functions: List[str],
    ) -> Job: ...

    @webmethod(route="/eval/evaluate", method="POST")
    async def evaluate(
        self,
        input_rows: List[Dict[str, Any]],
        candidate: EvalCandidate,
        scoring_functions: List[str],
    ) -> EvaluateResponse: ...

    @webmethod(route="/eval/job/status", method="GET")
    async def job_status(self, job_id: str) -> None: ...

    @webmethod(route="/eval/job/cancel", method="POST")
    async def job_cancel(self, job_id: str) -> None: ...

    @webmethod(route="/eval/job/result", method="GET")
    async def job_result(self, job_id: str) -> None: ...
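
Aside: a minimal sketch of how a caller might drive this new Eval protocol. `eval_impl` stands for any object implementing Eval, and the model id, dataset id, and "equality" scoring function name are illustrative assumptions, not part of this commit.

from llama_models.llama3.api.datatypes import SamplingParams

async def run_eval(eval_impl: Eval) -> None:
    # Evaluate a plain model (no agent) with default sampling settings.
    candidate = ModelCandidate(
        model="Llama3.1-8B-Instruct",  # hypothetical model id
        sampling_params=SamplingParams(),
    )
    # Kick off an offline job over a pre-registered dataset.
    job = await eval_impl.evaluate_batch(
        dataset_id="my-qa-dataset",  # hypothetical dataset id
        candidate=candidate,
        scoring_functions=["equality"],  # hypothetical scoring function name
    )
    print(f"started eval job {job.job_id}")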

View file

@@ -1,122 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import List, Protocol

from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel

from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_stack.apis.dataset import *  # noqa: F403
from llama_stack.apis.common.training_types import *  # noqa: F403


class TextGenerationMetric(Enum):
    perplexity = "perplexity"
    rouge = "rouge"
    bleu = "bleu"


class QuestionAnsweringMetric(Enum):
    em = "em"
    f1 = "f1"


class SummarizationMetric(Enum):
    rouge = "rouge"
    bleu = "bleu"


class EvaluationJob(BaseModel):
    job_uuid: str


class EvaluationJobLogStream(BaseModel):
    job_uuid: str


class EvaluateTaskRequestCommon(BaseModel):
    job_uuid: str
    dataset: TrainEvalDataset
    checkpoint: Checkpoint

    # generation params
    sampling_params: SamplingParams = SamplingParams()


@json_schema_type
class EvaluateTextGenerationRequest(EvaluateTaskRequestCommon):
    """Request to evaluate text generation."""

    metrics: List[TextGenerationMetric]


@json_schema_type
class EvaluateQuestionAnsweringRequest(EvaluateTaskRequestCommon):
    """Request to evaluate question answering."""

    metrics: List[QuestionAnsweringMetric]


@json_schema_type
class EvaluateSummarizationRequest(EvaluateTaskRequestCommon):
    """Request to evaluate summarization."""

    metrics: List[SummarizationMetric]


class EvaluationJobStatusResponse(BaseModel):
    job_uuid: str


@json_schema_type
class EvaluationJobArtifactsResponse(BaseModel):
    """Artifacts of an evaluation job."""

    job_uuid: str


class Evaluations(Protocol):
    @webmethod(route="/evaluate/text_generation/")
    def evaluate_text_generation(
        self,
        metrics: List[TextGenerationMetric],
    ) -> EvaluationJob: ...

    @webmethod(route="/evaluate/question_answering/")
    def evaluate_question_answering(
        self,
        metrics: List[QuestionAnsweringMetric],
    ) -> EvaluationJob: ...

    @webmethod(route="/evaluate/summarization/")
    def evaluate_summarization(
        self,
        metrics: List[SummarizationMetric],
    ) -> EvaluationJob: ...

    @webmethod(route="/evaluate/jobs")
    def get_evaluation_jobs(self) -> List[EvaluationJob]: ...

    @webmethod(route="/evaluate/job/status")
    def get_evaluation_job_status(
        self, job_uuid: str
    ) -> EvaluationJobStatusResponse: ...

    # sends SSE stream of logs
    @webmethod(route="/evaluate/job/logs")
    def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream: ...

    @webmethod(route="/evaluate/job/cancel")
    def cancel_evaluation_job(self, job_uuid: str) -> None: ...

    @webmethod(route="/evaluate/job/artifacts")
    def get_evaluation_job_artifacts(
        self, job_uuid: str
    ) -> EvaluationJobArtifactsResponse: ...

View file

@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .scoring import * # noqa: F401 F403

View file

@@ -0,0 +1,46 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Protocol, runtime_checkable

from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel

from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_stack.apis.scoring_functions import *  # noqa: F403

# A scoring result is an arbitrary mapping produced by a scoring function.
ScoringResult = Dict[str, Any]


@json_schema_type
class ScoreBatchResponse(BaseModel):
    dataset_id: str


@json_schema_type
class ScoreResponse(BaseModel):
    # each key in the dict is a scoring function name
    results: List[Dict[str, ScoringResult]]


class ScoringFunctionStore(Protocol):
    def get_scoring_function(self, name: str) -> ScoringFunctionDefWithProvider: ...


@runtime_checkable
class Scoring(Protocol):
    scoring_function_store: ScoringFunctionStore

    @webmethod(route="/scoring/score_batch")
    async def score_batch(
        self, dataset_id: str, scoring_functions: List[str]
    ) -> ScoreBatchResponse: ...

    @webmethod(route="/scoring/score")
    async def score(
        self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
    ) -> ScoreResponse: ...
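
Aside: a rough sketch of the inline scoring path. `scoring_impl` is any object implementing Scoring, and the row column names and "equality" function are illustrative assumptions, not part of this commit.

async def score_rows(scoring_impl: Scoring) -> None:
    rows = [
        {"input_query": "What is 2 + 2?", "generated_answer": "4", "expected_answer": "4"},
    ]
    response = await scoring_impl.score(
        input_rows=rows,
        scoring_functions=["equality"],  # hypothetical registered function
    )
    # One dict per input row, keyed by scoring function name.
    for row_scores in response.results:
        print(row_scores)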

View file

@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .scoring_functions import * # noqa: F401 F403

View file

@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import (
    Any,
    Dict,
    List,
    Literal,
    Optional,
    Protocol,
    runtime_checkable,
    Union,
)

from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated

from llama_stack.apis.common.type_system import ParamType


class Parameter(BaseModel):
    name: str
    type: ParamType
    description: Optional[str] = None


# Perhaps more structure can be imposed on these functions. Maybe they could be associated
# with standard metrics so they can be rolled up?
class CommonDef(BaseModel):
    name: str
    description: Optional[str] = None
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="Any additional metadata for this definition",
    )


class DeterministicFunctionDef(CommonDef):
    type: Literal["deterministic"] = "deterministic"
    parameters: List[Parameter] = Field(
        description="List of parameters for the deterministic function",
    )
    return_type: ParamType = Field(
        description="The return type of the deterministic function",
    )
    # We can optionally add information here to support packaging of code, etc.


class LLMJudgeFunctionDef(CommonDef):
    type: Literal["judge"] = "judge"
    model: str = Field(
        description="The LLM model to use for the judge function",
    )


ScoringFunctionDef = Annotated[
    Union[DeterministicFunctionDef, LLMJudgeFunctionDef], Field(discriminator="type")
]


@json_schema_type
class ScoringFunctionDefWithProvider(ScoringFunctionDef):
    provider_id: str = Field(
        description="The provider ID for this scoring function",
    )


@runtime_checkable
class ScoringFunctions(Protocol):
    @webmethod(route="/scoring_functions/list", method="GET")
    async def list_scoring_functions(self) -> List[ScoringFunctionDefWithProvider]: ...

    @webmethod(route="/scoring_functions/get", method="GET")
    async def get_scoring_function(
        self, name: str
    ) -> Optional[ScoringFunctionDefWithProvider]: ...

    @webmethod(route="/scoring_functions/register", method="POST")
    async def register_scoring_function(
        self, function: ScoringFunctionDefWithProvider
    ) -> None: ...
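
Aside: a small sketch of how the ScoringFunctionDef discriminated union resolves a payload to the right variant by its `type` field. This assumes the definitions above and pydantic v2 (TypeAdapter); the function and model names are hypothetical.

from pydantic import TypeAdapter

payload = {
    "type": "judge",
    "name": "helpfulness-judge",       # hypothetical function name
    "model": "Llama3.1-70B-Instruct",  # hypothetical judge model
}
fn = TypeAdapter(ScoringFunctionDef).validate_python(payload)
assert isinstance(fn, LLMJudgeFunctionDef)  # chosen via the "type" discriminator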