Mirror of https://github.com/meta-llama/llama-stack.git (synced 2026-01-02 18:20:03 +00:00)
feat(api): define a more coherent jobs api across different flows
Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
parent 71ed47ea76
commit 0f50cfa561

15 changed files with 1864 additions and 1670 deletions
@@ -20,9 +20,10 @@ from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
 )
 from llama_stack.providers.utils.common.data_schema_validator import ColumnName
 from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.schema_utils import webmethod

-from .....apis.common.job_types import Job, JobStatus
-from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
+from .....apis.common.job_types import JobArtifact, JobStatus
+from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateJob, ListEvaluateJobsResponse
 from .config import MetaReferenceEvalConfig

 EVAL_TASKS_PREFIX = "benchmarks:"

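The import hunk above swaps Job and EvaluateResponse for the artifact-centric job types. Their real definitions live in .....apis.common.job_types and are not part of this excerpt; the sketch below is only an assumption about their shape, reconstructed from how JobArtifact and JobStatus are used in the hunks that follow.

# Hypothetical sketch, not from this commit: a minimal JobStatus/JobArtifact
# consistent with the usage below, i.e. JobArtifact(type=..., name=..., metadata=...)
# and job.update_status(JobStatus.completed).
from enum import Enum
from typing import Any, Dict

from pydantic import BaseModel, Field


class JobStatus(Enum):
    new = "new"
    running = "running"
    paused = "paused"
    completed = "completed"
    cancelled = "cancelled"


class JobArtifact(BaseModel):
    type: str                                                # e.g. "generation" or "scoring_results"
    name: str                                                # e.g. "generation-0"
    metadata: Dict[str, Any] = Field(default_factory=dict)   # payload attached to the artifact
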
@@ -75,11 +76,11 @@ class MetaReferenceEvalImpl(
         )
         self.benchmarks[task_def.identifier] = task_def

-    async def run_eval(
+    async def evaluate(
         self,
         benchmark_id: str,
         benchmark_config: BenchmarkConfig,
-    ) -> Job:
+    ) -> EvaluateJob:
         task_def = self.benchmarks[benchmark_id]
         dataset_id = task_def.dataset_id
         scoring_functions = task_def.scoring_functions

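evaluate() now returns an EvaluateJob rather than a bare Job. That class comes from .....apis.eval.eval and is not shown in this excerpt; reusing the JobStatus/JobArtifact sketch above, a minimal shape consistent with how it is constructed and mutated below could look like this.

# Hypothetical sketch, not from this commit: EvaluateJob and
# ListEvaluateJobsResponse as they would need to behave for the code below.
from typing import List

from pydantic import BaseModel, Field


class EvaluateJob(BaseModel):
    id: str
    artifacts: List[JobArtifact] = Field(default_factory=list)
    status: JobStatus = JobStatus.new

    def update_status(self, status: JobStatus) -> None:
        # The real type presumably validates transitions and records timestamps;
        # this sketch only stores the new value.
        self.status = status


class ListEvaluateJobsResponse(BaseModel):
    data: List[EvaluateJob] = Field(default_factory=list)
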
@@ -91,18 +92,35 @@ class MetaReferenceEvalImpl(
             dataset_id=dataset_id,
             limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
         )
-        res = await self.evaluate_rows(
+
+        generations, scoring_results = await self._evaluate_rows(
             benchmark_id=benchmark_id,
             input_rows=all_rows.data,
             scoring_functions=scoring_functions,
             benchmark_config=benchmark_config,
         )
+        artifacts = [
+            JobArtifact(
+                type="generation",
+                name=f"generation-{i}",
+                metadata=generation,
+            )
+            for i, generation in enumerate(generations)
+        ] + [
+            JobArtifact(
+                type="scoring_results",
+                name="scoring_results",
+                metadata=scoring_results,
+            )
+        ]

         # TODO: currently needs to wait for generation before returning
         # need job scheduler queue (ray/celery) w/ jobs api
         job_id = str(len(self.jobs))
-        self.jobs[job_id] = res
-        return Job(job_id=job_id, status=JobStatus.completed)
+        job = EvaluateJob(id=job_id, artifacts=artifacts)
+        job.update_status(JobStatus.completed)
+        self.jobs[job_id] = job
+        return job

     async def _run_agent_generation(
         self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig

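A caller-side view of the reworked flow, as a rough sketch: the benchmark id and config are placeholders, and only the job fields visible in this diff (id, status, artifacts) are assumed to exist.

# Hypothetical usage sketch; identifiers are placeholders, not from this commit.
async def run_one_eval(impl, benchmark_config) -> None:
    # impl: an initialized MetaReferenceEvalImpl with "my-benchmark" registered.
    job = await impl.evaluate(benchmark_id="my-benchmark", benchmark_config=benchmark_config)

    # Results now travel as job artifacts instead of a separate job_result() call:
    # one "generation" artifact per input row plus a single "scoring_results" artifact.
    generations = [a for a in job.artifacts if a.type == "generation"]
    scoring = next(a for a in job.artifacts if a.type == "scoring_results")
    print(job.id, job.status, len(generations), list(scoring.metadata))
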
@@ -182,13 +200,13 @@ class MetaReferenceEvalImpl(

         return generations

-    async def evaluate_rows(
+    async def _evaluate_rows(
         self,
         benchmark_id: str,
         input_rows: List[Dict[str, Any]],
         scoring_functions: List[str],
         benchmark_config: BenchmarkConfig,
-    ) -> EvaluateResponse:
+    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
         candidate = benchmark_config.eval_candidate
         if candidate.type == "agent":
             generations = await self._run_agent_generation(input_rows, benchmark_config)

@@ -214,21 +232,26 @@ class MetaReferenceEvalImpl(
             input_rows=score_input_rows, scoring_functions=scoring_functions_dict
         )

-        return EvaluateResponse(generations=generations, scores=score_response.results)
-
-    async def job_status(self, benchmark_id: str, job_id: str) -> Job:
-        if job_id in self.jobs:
-            return Job(job_id=job_id, status=JobStatus.completed)
-
-        raise ValueError(f"Job {job_id} not found")
-
-    async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
-        raise NotImplementedError("Job cancel is not implemented yet")
-
-    async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
-        job = await self.job_status(benchmark_id, job_id)
-        status = job.status
-        if not status or status != JobStatus.completed:
-            raise ValueError(f"Job is not completed, Status: {status.value}")
-
-        return self.jobs[job_id]
+        return generations, score_response.results
+
+    # CRUD operations on running jobs
+    @webmethod(route="/evaluate/jobs/{job_id:path}", method="GET")
+    async def get_evaluate_job(self, job_id: str) -> EvaluateJob:
+        return self.jobs[job_id]
+
+    @webmethod(route="/evaluate/jobs", method="GET")
+    async def list_evaluate_jobs(self) -> ListEvaluateJobsResponse:
+        return ListEvaluateJobsResponse(data=list(self.jobs.values()))
+
+    @webmethod(route="/evaluate/jobs/{job_id:path}", method="POST")
+    async def update_evaluate_job(self, job: EvaluateJob) -> EvaluateJob:
+        raise NotImplementedError
+
+    @webmethod(route="/evaluate/job/{job_id:path}", method="DELETE")
+    async def delete_evaluate_job(self, job_id: str) -> None:
+        raise NotImplementedError
+
+    # Note: pause/resume/cancel are achieved as follows:
+    # - POST with status=paused
+    # - POST with status=resuming
+    # - POST with status=cancelled
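Over HTTP, the new routes could be exercised roughly as below. The base URL is an assumption (8321 is the usual llama-stack default port) and the cancel payload simply follows the POST-with-status convention from the note above; in this commit the POST and DELETE handlers still raise NotImplementedError, so the last call is aspirational.

# Hypothetical client sketch; base URL and payload shape are assumptions.
import requests

BASE = "http://localhost:8321"

# List all evaluation jobs.
jobs = requests.get(f"{BASE}/evaluate/jobs").json()["data"]

# Fetch a single job, including its artifacts.
job = requests.get(f"{BASE}/evaluate/jobs/{jobs[0]['id']}").json()

# Cancel it by POSTing the job back with status=cancelled, per the note above.
job["status"] = "cancelled"
requests.post(f"{BASE}/evaluate/jobs/{job['id']}", json=job)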