This commit is contained in:
Xi Yan 2024-10-03 13:47:15 -07:00
parent 7143ecfc0d
commit 8339b2cef3
10 changed files with 174 additions and 51 deletions

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import List, Protocol
from llama_models.schema_utils import webmethod
@ -16,22 +15,6 @@ from llama_stack.apis.dataset import * # noqa: F403
from llama_stack.apis.common.training_types import * # noqa: F403
class TextGenerationMetric(Enum):
perplexity = "perplexity"
rouge = "rouge"
bleu = "bleu"
class QuestionAnsweringMetric(Enum):
em = "em"
f1 = "f1"
class SummarizationMetric(Enum):
rouge = "rouge"
bleu = "bleu"
class EvaluationJob(BaseModel):
job_uuid: str
@ -54,28 +37,7 @@ class EvaluateTaskRequestCommon(BaseModel):
class EvaluateResponse(BaseModel):
"""Scores for evaluation."""
scores = Dict[str, str]
@json_schema_type
class EvaluateTextGenerationRequest(EvaluateTaskRequestCommon):
"""Request to evaluate text generation."""
metrics: List[TextGenerationMetric]
@json_schema_type
class EvaluateQuestionAnsweringRequest(EvaluateTaskRequestCommon):
"""Request to evaluate question answering."""
metrics: List[QuestionAnsweringMetric]
@json_schema_type
class EvaluateSummarizationRequest(EvaluateTaskRequestCommon):
"""Request to evaluate summarization."""
metrics: List[SummarizationMetric]
metrics: Dict[str, float]
@json_schema_type
@ -97,33 +59,36 @@ class EvaluationJobCreateResponse(BaseModel):
job_uuid: str
class Evaluations(Protocol):
@webmethod(route="/evaluate")
async def evaluate(
self, model: str, dataset: str, task: str
class Evals(Protocol):
@webmethod(route="/evals/run")
async def run_evals(
self,
model: str,
dataset: str,
task: str,
) -> EvaluateResponse: ...
@webmethod(route="/evaluate/jobs")
@webmethod(route="/evals/jobs")
def get_evaluation_jobs(self) -> List[EvaluationJob]: ...
@webmethod(route="/evaluate/job/create")
@webmethod(route="/evals/job/create")
async def create_evaluation_job(
self, model: str, dataset: str, task: str
) -> EvaluationJob: ...
@webmethod(route="/evaluate/job/status")
@webmethod(route="/evals/job/status")
def get_evaluation_job_status(
self, job_uuid: str
) -> EvaluationJobStatusResponse: ...
# sends SSE stream of logs
@webmethod(route="/evaluate/job/logs")
@webmethod(route="/evals/job/logs")
def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream: ...
@webmethod(route="/evaluate/job/cancel")
@webmethod(route="/evals/job/cancel")
def cancel_evaluation_job(self, job_uuid: str) -> None: ...
@webmethod(route="/evaluate/job/artifacts")
@webmethod(route="/evals/job/artifacts")
def get_evaluation_job_artifacts(
self, job_uuid: str
) -> EvaluationJobArtifactsResponse: ...