mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-08 19:10:56 +00:00
wip api
This commit is contained in:
parent
7143ecfc0d
commit
8339b2cef3
10 changed files with 174 additions and 51 deletions
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import List, Protocol
|
||||
|
||||
from llama_models.schema_utils import webmethod
|
||||
|
|
@ -16,22 +15,6 @@ from llama_stack.apis.dataset import * # noqa: F403
|
|||
from llama_stack.apis.common.training_types import * # noqa: F403
|
||||
|
||||
|
||||
class TextGenerationMetric(Enum):
|
||||
perplexity = "perplexity"
|
||||
rouge = "rouge"
|
||||
bleu = "bleu"
|
||||
|
||||
|
||||
class QuestionAnsweringMetric(Enum):
|
||||
em = "em"
|
||||
f1 = "f1"
|
||||
|
||||
|
||||
class SummarizationMetric(Enum):
|
||||
rouge = "rouge"
|
||||
bleu = "bleu"
|
||||
|
||||
|
||||
class EvaluationJob(BaseModel):
|
||||
job_uuid: str
|
||||
|
||||
|
|
@ -54,28 +37,7 @@ class EvaluateTaskRequestCommon(BaseModel):
|
|||
class EvaluateResponse(BaseModel):
|
||||
"""Scores for evaluation."""
|
||||
|
||||
scores = Dict[str, str]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class EvaluateTextGenerationRequest(EvaluateTaskRequestCommon):
|
||||
"""Request to evaluate text generation."""
|
||||
|
||||
metrics: List[TextGenerationMetric]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class EvaluateQuestionAnsweringRequest(EvaluateTaskRequestCommon):
|
||||
"""Request to evaluate question answering."""
|
||||
|
||||
metrics: List[QuestionAnsweringMetric]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class EvaluateSummarizationRequest(EvaluateTaskRequestCommon):
|
||||
"""Request to evaluate summarization."""
|
||||
|
||||
metrics: List[SummarizationMetric]
|
||||
metrics: Dict[str, float]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -97,33 +59,36 @@ class EvaluationJobCreateResponse(BaseModel):
|
|||
job_uuid: str
|
||||
|
||||
|
||||
class Evaluations(Protocol):
|
||||
@webmethod(route="/evaluate")
|
||||
async def evaluate(
|
||||
self, model: str, dataset: str, task: str
|
||||
class Evals(Protocol):
|
||||
@webmethod(route="/evals/run")
|
||||
async def run_evals(
|
||||
self,
|
||||
model: str,
|
||||
dataset: str,
|
||||
task: str,
|
||||
) -> EvaluateResponse: ...
|
||||
|
||||
@webmethod(route="/evaluate/jobs")
|
||||
@webmethod(route="/evals/jobs")
|
||||
def get_evaluation_jobs(self) -> List[EvaluationJob]: ...
|
||||
|
||||
@webmethod(route="/evaluate/job/create")
|
||||
@webmethod(route="/evals/job/create")
|
||||
async def create_evaluation_job(
|
||||
self, model: str, dataset: str, task: str
|
||||
) -> EvaluationJob: ...
|
||||
|
||||
@webmethod(route="/evaluate/job/status")
|
||||
@webmethod(route="/evals/job/status")
|
||||
def get_evaluation_job_status(
|
||||
self, job_uuid: str
|
||||
) -> EvaluationJobStatusResponse: ...
|
||||
|
||||
# sends SSE stream of logs
|
||||
@webmethod(route="/evaluate/job/logs")
|
||||
@webmethod(route="/evals/job/logs")
|
||||
def get_evaluation_job_logstream(self, job_uuid: str) -> EvaluationJobLogStream: ...
|
||||
|
||||
@webmethod(route="/evaluate/job/cancel")
|
||||
@webmethod(route="/evals/job/cancel")
|
||||
def cancel_evaluation_job(self, job_uuid: str) -> None: ...
|
||||
|
||||
@webmethod(route="/evaluate/job/artifacts")
|
||||
@webmethod(route="/evals/job/artifacts")
|
||||
def get_evaluation_job_artifacts(
|
||||
self, job_uuid: str
|
||||
) -> EvaluationJobArtifactsResponse: ...
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue