fix: fix jobs api literal return type (#1757)

# What does this PR do?

- The jobs API cannot directly return a literal type (the bare `JobStatus` enum), so `job_status` now returns a `Job` object that carries both `job_id` and `status` (see the sketch below).

> Note: this is not the final Jobs API change.
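
For context, this is the reshaped model at the heart of the diff: `JobStatus` stays a plain string enum, and `Job` now carries both the job ID and its status, so the endpoint returns an object rather than a bare literal. A minimal sketch of the resulting types (mirroring `job_types.py` below; the `@json_schema_type` decorators are omitted to keep it self-contained):

```python
from enum import Enum

from pydantic import BaseModel


class JobStatus(Enum):
    completed = "completed"
    in_progress = "in_progress"
    failed = "failed"
    scheduled = "scheduled"


class Job(BaseModel):
    job_id: str
    status: JobStatus
```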


## Test Plan
<img width="837" alt="image"
src="https://github.com/user-attachments/assets/18a17561-35f9-443d-987d-54afdd6ff40c"
/>


Author: Xi Yan, 2025-03-21 14:04:21 -07:00 (committed by GitHub)
Commit: baf68c665c (parent: d6887f46c6)
7 changed files with 79 additions and 69 deletions

#### OpenAPI spec (JSON)

@@ -2183,7 +2183,7 @@
           "content": {
             "application/json": {
               "schema": {
-                "$ref": "#/components/schemas/JobStatus"
+                "$ref": "#/components/schemas/Job"
               }
             }
           }
@@ -7648,7 +7648,13 @@
       "title": "PostTrainingJobArtifactsResponse",
       "description": "Artifacts of a finetuning job."
     },
-    "JobStatus": {
+    "PostTrainingJobStatusResponse": {
+      "type": "object",
+      "properties": {
+        "job_uuid": {
+          "type": "string"
+        },
+        "status": {
       "type": "string",
       "enum": [
         "completed",
@@ -7658,15 +7664,6 @@
       ],
       "title": "JobStatus"
     },
-    "PostTrainingJobStatusResponse": {
-      "type": "object",
-      "properties": {
-        "job_uuid": {
-          "type": "string"
-        },
-        "status": {
-          "$ref": "#/components/schemas/JobStatus"
-        },
       "scheduled_at": {
         "type": "string",
         "format": "date-time"
@@ -8115,6 +8112,30 @@
       "title": "IterrowsResponse",
       "description": "A paginated list of rows from a dataset."
     },
+    "Job": {
+      "type": "object",
+      "properties": {
+        "job_id": {
+          "type": "string"
+        },
+        "status": {
+          "type": "string",
+          "enum": [
+            "completed",
+            "in_progress",
+            "failed",
+            "scheduled"
+          ],
+          "title": "JobStatus"
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "job_id",
+        "status"
+      ],
+      "title": "Job"
+    },
     "ListAgentSessionsResponse": {
       "type": "object",
       "properties": {
@@ -9639,19 +9660,6 @@
       ],
       "title": "RunEvalRequest"
     },
-    "Job": {
-      "type": "object",
-      "properties": {
-        "job_id": {
-          "type": "string"
-        }
-      },
-      "additionalProperties": false,
-      "required": [
-        "job_id"
-      ],
-      "title": "Job"
-    },
     "RunShieldRequest": {
       "type": "object",
       "properties": {

#### OpenAPI spec (YAML)

@@ -1491,7 +1491,7 @@ paths:
           content:
             application/json:
               schema:
-                $ref: '#/components/schemas/JobStatus'
+                $ref: '#/components/schemas/Job'
         '400':
           $ref: '#/components/responses/BadRequest400'
         '429':
@@ -5277,7 +5277,12 @@ components:
         - checkpoints
       title: PostTrainingJobArtifactsResponse
       description: Artifacts of a finetuning job.
-    JobStatus:
+    PostTrainingJobStatusResponse:
+      type: object
+      properties:
+        job_uuid:
+          type: string
+        status:
       type: string
       enum:
         - completed
@@ -5285,13 +5290,6 @@ components:
         - failed
         - scheduled
       title: JobStatus
-    PostTrainingJobStatusResponse:
-      type: object
-      properties:
-        job_uuid:
-          type: string
-        status:
-          $ref: '#/components/schemas/JobStatus'
       scheduled_at:
         type: string
         format: date-time
@@ -5556,6 +5554,24 @@ components:
         - data
       title: IterrowsResponse
       description: A paginated list of rows from a dataset.
+    Job:
+      type: object
+      properties:
+        job_id:
+          type: string
+        status:
+          type: string
+          enum:
+            - completed
+            - in_progress
+            - failed
+            - scheduled
+          title: JobStatus
+      additionalProperties: false
+      required:
+        - job_id
+        - status
+      title: Job
     ListAgentSessionsResponse:
       type: object
       properties:
@@ -6550,15 +6566,6 @@ components:
       required:
         - benchmark_config
       title: RunEvalRequest
-    Job:
-      type: object
-      properties:
-        job_id:
-          type: string
-      additionalProperties: false
-      required:
-        - job_id
-      title: Job
     RunShieldRequest:
       type: object
       properties:

#### llama_stack/apis/common/job_types.py

@@ -10,14 +10,14 @@ from pydantic import BaseModel
 from llama_stack.schema_utils import json_schema_type
 
-@json_schema_type
-class Job(BaseModel):
-    job_id: str
-
-
 @json_schema_type
 class JobStatus(Enum):
     completed = "completed"
     in_progress = "in_progress"
     failed = "failed"
     scheduled = "scheduled"
+
+
+@json_schema_type
+class Job(BaseModel):
+    job_id: str
+    status: JobStatus
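
A quick illustration of the reshaped model in use (assuming pydantic v2 semantics; the field values are made up for the example):

```python
job = Job(job_id="0", status=JobStatus.completed)
print(job.model_dump_json())  # -> {"job_id":"0","status":"completed"}
```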

#### llama_stack/apis/eval/eval.py

@@ -10,7 +10,7 @@ from pydantic import BaseModel, Field
 from typing_extensions import Annotated
 
 from llama_stack.apis.agents import AgentConfig
-from llama_stack.apis.common.job_types import Job, JobStatus
+from llama_stack.apis.common.job_types import Job
 from llama_stack.apis.inference import SamplingParams, SystemMessage
 from llama_stack.apis.scoring import ScoringResult
 from llama_stack.apis.scoring_functions import ScoringFnParams
@@ -115,7 +115,7 @@ class Eval(Protocol):
         """
 
     @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
-    async def job_status(self, benchmark_id: str, job_id: str) -> JobStatus:
+    async def job_status(self, benchmark_id: str, job_id: str) -> Job:
         """Get the status of a job.
 
         :param benchmark_id: The ID of the benchmark to run the evaluation on.

#### Eval router (EvalRouter)

@@ -14,13 +14,7 @@ from llama_stack.apis.common.content_types import (
 )
 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
 from llama_stack.apis.datasets import DatasetPurpose, DataSource
-from llama_stack.apis.eval import (
-    BenchmarkConfig,
-    Eval,
-    EvaluateResponse,
-    Job,
-    JobStatus,
-)
+from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
 from llama_stack.apis.inference import (
     ChatCompletionResponse,
     ChatCompletionResponseEventType,
@@ -623,7 +617,7 @@ class EvalRouter(Eval):
         self,
         benchmark_id: str,
         job_id: str,
-    ) -> Optional[JobStatus]:
+    ) -> Job:
         logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
         return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)

#### Meta-reference eval provider (MetaReferenceEvalImpl)

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import json
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
 
 from tqdm import tqdm
@@ -21,8 +21,8 @@ from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
 from llama_stack.providers.utils.common.data_schema_validator import ColumnName
 from llama_stack.providers.utils.kvstore import kvstore_impl
 
-from .....apis.common.job_types import Job
-from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse, JobStatus
+from .....apis.common.job_types import Job, JobStatus
+from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
 from .config import MetaReferenceEvalConfig
 
 EVAL_TASKS_PREFIX = "benchmarks:"
@@ -102,7 +102,7 @@ class MetaReferenceEvalImpl(
         # need job scheduler queue (ray/celery) w/ jobs api
         job_id = str(len(self.jobs))
         self.jobs[job_id] = res
-        return Job(job_id=job_id)
+        return Job(job_id=job_id, status=JobStatus.completed)
 
     async def _run_agent_generation(
         self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
@@ -216,17 +216,18 @@ class MetaReferenceEvalImpl(
         return EvaluateResponse(generations=generations, scores=score_response.results)
 
-    async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]:
+    async def job_status(self, benchmark_id: str, job_id: str) -> Job:
         if job_id in self.jobs:
-            return JobStatus.completed
-        return None
+            return Job(job_id=job_id, status=JobStatus.completed)
+        raise ValueError(f"Job {job_id} not found")
 
     async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
         raise NotImplementedError("Job cancel is not implemented yet")
 
     async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
-        status = await self.job_status(benchmark_id, job_id)
+        job = await self.job_status(benchmark_id, job_id)
+        status = job.status
         if not status or status != JobStatus.completed:
             raise ValueError(f"Job is not completed, Status: {status.value}")

#### Eval integration test

@@ -94,7 +94,7 @@ def test_evaluate_benchmark(llama_stack_client, text_model_id, scoring_fn_id):
     )
     assert response.job_id == "0"
     job_status = llama_stack_client.eval.jobs.status(job_id=response.job_id, benchmark_id=benchmark_id)
-    assert job_status and job_status == "completed"
+    assert job_status and job_status.status == "completed"
 
     eval_response = llama_stack_client.eval.jobs.retrieve(job_id=response.job_id, benchmark_id=benchmark_id)
     assert eval_response is not None