From baf68c665c5f20312396084a16ac4e260d4e13d9 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Fri, 21 Mar 2025 14:04:21 -0700 Subject: [PATCH] fix: fix jobs api literal return type (#1757) # What does this PR do? - We cannot directly return a literal type > Note: this is not the final jobs API change [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan image [//]: # (## Documentation) --- docs/_static/llama-stack-spec.html | 58 +++++++++++-------- docs/_static/llama-stack-spec.yaml | 45 ++++++++------ llama_stack/apis/common/job_types.py | 12 ++-- llama_stack/apis/eval/eval.py | 4 +- llama_stack/distribution/routers/routers.py | 10 +--- .../inline/eval/meta_reference/eval.py | 17 +++--- tests/integration/eval/test_eval.py | 2 +- 7 files changed, 79 insertions(+), 69 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index c81f9b33d..8a46a89ad 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -2183,7 +2183,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/JobStatus" + "$ref": "#/components/schemas/Job" } } } @@ -7648,16 +7648,6 @@ "title": "PostTrainingJobArtifactsResponse", "description": "Artifacts of a finetuning job." }, - "JobStatus": { - "type": "string", - "enum": [ - "completed", - "in_progress", - "failed", - "scheduled" - ], - "title": "JobStatus" - }, "PostTrainingJobStatusResponse": { "type": "object", "properties": { @@ -7665,7 +7655,14 @@ "type": "string" }, "status": { - "$ref": "#/components/schemas/JobStatus" + "type": "string", + "enum": [ + "completed", + "in_progress", + "failed", + "scheduled" + ], + "title": "JobStatus" }, "scheduled_at": { "type": "string", @@ -8115,6 +8112,30 @@ "title": "IterrowsResponse", "description": "A paginated list of rows from a dataset." 
}, + "Job": { + "type": "object", + "properties": { + "job_id": { + "type": "string" + }, + "status": { + "type": "string", + "enum": [ + "completed", + "in_progress", + "failed", + "scheduled" + ], + "title": "JobStatus" + } + }, + "additionalProperties": false, + "required": [ + "job_id", + "status" + ], + "title": "Job" + }, "ListAgentSessionsResponse": { "type": "object", "properties": { @@ -9639,19 +9660,6 @@ ], "title": "RunEvalRequest" }, - "Job": { - "type": "object", - "properties": { - "job_id": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "job_id" - ], - "title": "Job" - }, "RunShieldRequest": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 8ea0e1b9c..0b8f90490 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -1491,7 +1491,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/JobStatus' + $ref: '#/components/schemas/Job' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -5277,21 +5277,19 @@ components: - checkpoints title: PostTrainingJobArtifactsResponse description: Artifacts of a finetuning job. - JobStatus: - type: string - enum: - - completed - - in_progress - - failed - - scheduled - title: JobStatus PostTrainingJobStatusResponse: type: object properties: job_uuid: type: string status: - $ref: '#/components/schemas/JobStatus' + type: string + enum: + - completed + - in_progress + - failed + - scheduled + title: JobStatus scheduled_at: type: string format: date-time @@ -5556,6 +5554,24 @@ components: - data title: IterrowsResponse description: A paginated list of rows from a dataset. 
+ Job: + type: object + properties: + job_id: + type: string + status: + type: string + enum: + - completed + - in_progress + - failed + - scheduled + title: JobStatus + additionalProperties: false + required: + - job_id + - status + title: Job ListAgentSessionsResponse: type: object properties: @@ -6550,15 +6566,6 @@ components: required: - benchmark_config title: RunEvalRequest - Job: - type: object - properties: - job_id: - type: string - additionalProperties: false - required: - - job_id - title: Job RunShieldRequest: type: object properties: diff --git a/llama_stack/apis/common/job_types.py b/llama_stack/apis/common/job_types.py index bc070017b..9acecc154 100644 --- a/llama_stack/apis/common/job_types.py +++ b/llama_stack/apis/common/job_types.py @@ -10,14 +10,14 @@ from pydantic import BaseModel from llama_stack.schema_utils import json_schema_type -@json_schema_type -class Job(BaseModel): - job_id: str - - -@json_schema_type class JobStatus(Enum): completed = "completed" in_progress = "in_progress" failed = "failed" scheduled = "scheduled" + + +@json_schema_type +class Job(BaseModel): + job_id: str + status: JobStatus diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py index d05786321..0e5959c37 100644 --- a/llama_stack/apis/eval/eval.py +++ b/llama_stack/apis/eval/eval.py @@ -10,7 +10,7 @@ from pydantic import BaseModel, Field from typing_extensions import Annotated from llama_stack.apis.agents import AgentConfig -from llama_stack.apis.common.job_types import Job, JobStatus +from llama_stack.apis.common.job_types import Job from llama_stack.apis.inference import SamplingParams, SystemMessage from llama_stack.apis.scoring import ScoringResult from llama_stack.apis.scoring_functions import ScoringFnParams @@ -115,7 +115,7 @@ class Eval(Protocol): """ @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET") - async def job_status(self, benchmark_id: str, job_id: str) -> JobStatus: + async def job_status(self, 
benchmark_id: str, job_id: str) -> Job: """Get the status of a job. :param benchmark_id: The ID of the benchmark to run the evaluation on. diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 2cf38f544..6ff36a65c 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -14,13 +14,7 @@ from llama_stack.apis.common.content_types import ( ) from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse from llama_stack.apis.datasets import DatasetPurpose, DataSource -from llama_stack.apis.eval import ( - BenchmarkConfig, - Eval, - EvaluateResponse, - Job, - JobStatus, -) +from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job from llama_stack.apis.inference import ( ChatCompletionResponse, ChatCompletionResponseEventType, @@ -623,7 +617,7 @@ class EvalRouter(Eval): self, benchmark_id: str, job_id: str, - ) -> Optional[JobStatus]: + ) -> Job: logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}") return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id) diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py index 3630d4c03..7c28f1bb7 100644 --- a/llama_stack/providers/inline/eval/meta_reference/eval.py +++ b/llama_stack/providers/inline/eval/meta_reference/eval.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
import json -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List from tqdm import tqdm @@ -21,8 +21,8 @@ from llama_stack.providers.inline.agents.meta_reference.agent_instance import ( from llama_stack.providers.utils.common.data_schema_validator import ColumnName from llama_stack.providers.utils.kvstore import kvstore_impl -from .....apis.common.job_types import Job -from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse, JobStatus +from .....apis.common.job_types import Job, JobStatus +from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse from .config import MetaReferenceEvalConfig EVAL_TASKS_PREFIX = "benchmarks:" @@ -102,7 +102,7 @@ class MetaReferenceEvalImpl( # need job scheduler queue (ray/celery) w/ jobs api job_id = str(len(self.jobs)) self.jobs[job_id] = res - return Job(job_id=job_id) + return Job(job_id=job_id, status=JobStatus.completed) async def _run_agent_generation( self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig @@ -216,17 +216,18 @@ class MetaReferenceEvalImpl( return EvaluateResponse(generations=generations, scores=score_response.results) - async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]: + async def job_status(self, benchmark_id: str, job_id: str) -> Job: if job_id in self.jobs: - return JobStatus.completed + return Job(job_id=job_id, status=JobStatus.completed) - return None + raise ValueError(f"Job {job_id} not found") async def job_cancel(self, benchmark_id: str, job_id: str) -> None: raise NotImplementedError("Job cancel is not implemented yet") async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse: - status = await self.job_status(benchmark_id, job_id) + job = await self.job_status(benchmark_id, job_id) + status = job.status if not status or status != JobStatus.completed: raise ValueError(f"Job is not completed, Status: {status.value}") diff --git a/tests/integration/eval/test_eval.py 
b/tests/integration/eval/test_eval.py index c4aa0fa1b..d1c3de519 100644 --- a/tests/integration/eval/test_eval.py +++ b/tests/integration/eval/test_eval.py @@ -94,7 +94,7 @@ def test_evaluate_benchmark(llama_stack_client, text_model_id, scoring_fn_id): ) assert response.job_id == "0" job_status = llama_stack_client.eval.jobs.status(job_id=response.job_id, benchmark_id=benchmark_id) - assert job_status and job_status == "completed" + assert job_status and job_status.status == "completed" eval_response = llama_stack_client.eval.jobs.retrieve(job_id=response.job_id, benchmark_id=benchmark_id) assert eval_response is not None