fix: fix jobs api literal return type (#1757)

# What does this PR do?

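- An API endpoint cannot directly return a literal (enum) type, so `job_status` now returns a `Job` object with a `status` field instead of a bare `JobStatus`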

> Note: this is not the final jobs API change

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan
<img width="837" alt="image"
src="https://github.com/user-attachments/assets/18a17561-35f9-443d-987d-54afdd6ff40c"
/>


[//]: # (## Documentation)
This commit is contained in:
Xi Yan 2025-03-21 14:04:21 -07:00 committed by GitHub
parent d6887f46c6
commit baf68c665c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 79 additions and 69 deletions

View file

@ -2183,7 +2183,7 @@
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/JobStatus"
"$ref": "#/components/schemas/Job"
}
}
}
@ -7648,16 +7648,6 @@
"title": "PostTrainingJobArtifactsResponse",
"description": "Artifacts of a finetuning job."
},
"JobStatus": {
"type": "string",
"enum": [
"completed",
"in_progress",
"failed",
"scheduled"
],
"title": "JobStatus"
},
"PostTrainingJobStatusResponse": {
"type": "object",
"properties": {
@ -7665,7 +7655,14 @@
"type": "string"
},
"status": {
"$ref": "#/components/schemas/JobStatus"
"type": "string",
"enum": [
"completed",
"in_progress",
"failed",
"scheduled"
],
"title": "JobStatus"
},
"scheduled_at": {
"type": "string",
@ -8115,6 +8112,30 @@
"title": "IterrowsResponse",
"description": "A paginated list of rows from a dataset."
},
"Job": {
"type": "object",
"properties": {
"job_id": {
"type": "string"
},
"status": {
"type": "string",
"enum": [
"completed",
"in_progress",
"failed",
"scheduled"
],
"title": "JobStatus"
}
},
"additionalProperties": false,
"required": [
"job_id",
"status"
],
"title": "Job"
},
"ListAgentSessionsResponse": {
"type": "object",
"properties": {
@ -9639,19 +9660,6 @@
],
"title": "RunEvalRequest"
},
"Job": {
"type": "object",
"properties": {
"job_id": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"job_id"
],
"title": "Job"
},
"RunShieldRequest": {
"type": "object",
"properties": {

View file

@ -1491,7 +1491,7 @@ paths:
content:
application/json:
schema:
$ref: '#/components/schemas/JobStatus'
$ref: '#/components/schemas/Job'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@ -5277,21 +5277,19 @@ components:
- checkpoints
title: PostTrainingJobArtifactsResponse
description: Artifacts of a finetuning job.
JobStatus:
type: string
enum:
- completed
- in_progress
- failed
- scheduled
title: JobStatus
PostTrainingJobStatusResponse:
type: object
properties:
job_uuid:
type: string
status:
$ref: '#/components/schemas/JobStatus'
type: string
enum:
- completed
- in_progress
- failed
- scheduled
title: JobStatus
scheduled_at:
type: string
format: date-time
@ -5556,6 +5554,24 @@ components:
- data
title: IterrowsResponse
description: A paginated list of rows from a dataset.
Job:
type: object
properties:
job_id:
type: string
status:
type: string
enum:
- completed
- in_progress
- failed
- scheduled
title: JobStatus
additionalProperties: false
required:
- job_id
- status
title: Job
ListAgentSessionsResponse:
type: object
properties:
@ -6550,15 +6566,6 @@ components:
required:
- benchmark_config
title: RunEvalRequest
Job:
type: object
properties:
job_id:
type: string
additionalProperties: false
required:
- job_id
title: Job
RunShieldRequest:
type: object
properties:

View file

@ -10,14 +10,14 @@ from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class Job(BaseModel):
job_id: str
@json_schema_type
class JobStatus(Enum):
completed = "completed"
in_progress = "in_progress"
failed = "failed"
scheduled = "scheduled"
@json_schema_type
class Job(BaseModel):
job_id: str
status: JobStatus

View file

@ -10,7 +10,7 @@ from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.common.job_types import Job, JobStatus
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams
@ -115,7 +115,7 @@ class Eval(Protocol):
"""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
async def job_status(self, benchmark_id: str, job_id: str) -> JobStatus:
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
"""Get the status of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.

View file

@ -14,13 +14,7 @@ from llama_stack.apis.common.content_types import (
)
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.apis.eval import (
BenchmarkConfig,
Eval,
EvaluateResponse,
Job,
JobStatus,
)
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEventType,
@ -623,7 +617,7 @@ class EvalRouter(Eval):
self,
benchmark_id: str,
job_id: str,
) -> Optional[JobStatus]:
) -> Job:
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List
from tqdm import tqdm
@ -21,8 +21,8 @@ from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
from llama_stack.providers.utils.kvstore import kvstore_impl
from .....apis.common.job_types import Job
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse, JobStatus
from .....apis.common.job_types import Job, JobStatus
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
from .config import MetaReferenceEvalConfig
EVAL_TASKS_PREFIX = "benchmarks:"
@ -102,7 +102,7 @@ class MetaReferenceEvalImpl(
# need job scheduler queue (ray/celery) w/ jobs api
job_id = str(len(self.jobs))
self.jobs[job_id] = res
return Job(job_id=job_id)
return Job(job_id=job_id, status=JobStatus.completed)
async def _run_agent_generation(
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
@ -216,17 +216,18 @@ class MetaReferenceEvalImpl(
return EvaluateResponse(generations=generations, scores=score_response.results)
async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]:
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
if job_id in self.jobs:
return JobStatus.completed
return Job(job_id=job_id, status=JobStatus.completed)
return None
raise ValueError(f"Job {job_id} not found")
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet")
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
status = await self.job_status(benchmark_id, job_id)
job = await self.job_status(benchmark_id, job_id)
status = job.status
if not status or status != JobStatus.completed:
raise ValueError(f"Job is not completed, Status: {status.value}")

View file

@ -94,7 +94,7 @@ def test_evaluate_benchmark(llama_stack_client, text_model_id, scoring_fn_id):
)
assert response.job_id == "0"
job_status = llama_stack_client.eval.jobs.status(job_id=response.job_id, benchmark_id=benchmark_id)
assert job_status and job_status == "completed"
assert job_status and job_status.status == "completed"
eval_response = llama_stack_client.eval.jobs.retrieve(job_id=response.job_id, benchmark_id=benchmark_id)
assert eval_response is not None