diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index c81f9b33d..8a46a89ad 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -2183,7 +2183,7 @@
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/JobStatus"
+ "$ref": "#/components/schemas/Job"
}
}
}
@@ -7648,16 +7648,6 @@
"title": "PostTrainingJobArtifactsResponse",
"description": "Artifacts of a finetuning job."
},
- "JobStatus": {
- "type": "string",
- "enum": [
- "completed",
- "in_progress",
- "failed",
- "scheduled"
- ],
- "title": "JobStatus"
- },
"PostTrainingJobStatusResponse": {
"type": "object",
"properties": {
@@ -7665,7 +7655,14 @@
"type": "string"
},
"status": {
- "$ref": "#/components/schemas/JobStatus"
+ "type": "string",
+ "enum": [
+ "completed",
+ "in_progress",
+ "failed",
+ "scheduled"
+ ],
+ "title": "JobStatus"
},
"scheduled_at": {
"type": "string",
@@ -8115,6 +8112,30 @@
"title": "IterrowsResponse",
"description": "A paginated list of rows from a dataset."
},
+ "Job": {
+ "type": "object",
+ "properties": {
+ "job_id": {
+ "type": "string"
+ },
+ "status": {
+ "type": "string",
+ "enum": [
+ "completed",
+ "in_progress",
+ "failed",
+ "scheduled"
+ ],
+ "title": "JobStatus"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "job_id",
+ "status"
+ ],
+ "title": "Job"
+ },
"ListAgentSessionsResponse": {
"type": "object",
"properties": {
@@ -9639,19 +9660,6 @@
],
"title": "RunEvalRequest"
},
- "Job": {
- "type": "object",
- "properties": {
- "job_id": {
- "type": "string"
- }
- },
- "additionalProperties": false,
- "required": [
- "job_id"
- ],
- "title": "Job"
- },
"RunShieldRequest": {
"type": "object",
"properties": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 8ea0e1b9c..0b8f90490 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -1491,7 +1491,7 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/JobStatus'
+ $ref: '#/components/schemas/Job'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@@ -5277,21 +5277,19 @@ components:
- checkpoints
title: PostTrainingJobArtifactsResponse
description: Artifacts of a finetuning job.
- JobStatus:
- type: string
- enum:
- - completed
- - in_progress
- - failed
- - scheduled
- title: JobStatus
PostTrainingJobStatusResponse:
type: object
properties:
job_uuid:
type: string
status:
- $ref: '#/components/schemas/JobStatus'
+ type: string
+ enum:
+ - completed
+ - in_progress
+ - failed
+ - scheduled
+ title: JobStatus
scheduled_at:
type: string
format: date-time
@@ -5556,6 +5554,24 @@ components:
- data
title: IterrowsResponse
description: A paginated list of rows from a dataset.
+ Job:
+ type: object
+ properties:
+ job_id:
+ type: string
+ status:
+ type: string
+ enum:
+ - completed
+ - in_progress
+ - failed
+ - scheduled
+ title: JobStatus
+ additionalProperties: false
+ required:
+ - job_id
+ - status
+ title: Job
ListAgentSessionsResponse:
type: object
properties:
@@ -6550,15 +6566,6 @@ components:
required:
- benchmark_config
title: RunEvalRequest
- Job:
- type: object
- properties:
- job_id:
- type: string
- additionalProperties: false
- required:
- - job_id
- title: Job
RunShieldRequest:
type: object
properties:
diff --git a/llama_stack/apis/common/job_types.py b/llama_stack/apis/common/job_types.py
index bc070017b..9acecc154 100644
--- a/llama_stack/apis/common/job_types.py
+++ b/llama_stack/apis/common/job_types.py
@@ -10,14 +10,14 @@ from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type
 
 
-@json_schema_type
-class Job(BaseModel):
- job_id: str
-
-
-@json_schema_type
class JobStatus(Enum):
completed = "completed"
in_progress = "in_progress"
failed = "failed"
scheduled = "scheduled"
+
+
+@json_schema_type
+class Job(BaseModel):
+ job_id: str
+ status: JobStatus
diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py
index d05786321..0e5959c37 100644
--- a/llama_stack/apis/eval/eval.py
+++ b/llama_stack/apis/eval/eval.py
@@ -10,7 +10,7 @@ from pydantic import BaseModel, Field
from typing_extensions import Annotated
 
from llama_stack.apis.agents import AgentConfig
-from llama_stack.apis.common.job_types import Job, JobStatus
+from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams
@@ -115,7 +115,7 @@ class Eval(Protocol):
"""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
- async def job_status(self, benchmark_id: str, job_id: str) -> JobStatus:
+ async def job_status(self, benchmark_id: str, job_id: str) -> Job:
"""Get the status of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 2cf38f544..6ff36a65c 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -14,13 +14,7 @@ from llama_stack.apis.common.content_types import (
)
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import DatasetPurpose, DataSource
-from llama_stack.apis.eval import (
- BenchmarkConfig,
- Eval,
- EvaluateResponse,
- Job,
- JobStatus,
-)
+from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEventType,
@@ -623,7 +617,7 @@ class EvalRouter(Eval):
self,
benchmark_id: str,
job_id: str,
- ) -> Optional[JobStatus]:
+ ) -> Job:
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
 
diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py
index 3630d4c03..7c28f1bb7 100644
--- a/llama_stack/providers/inline/eval/meta_reference/eval.py
+++ b/llama_stack/providers/inline/eval/meta_reference/eval.py
@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
 
from tqdm import tqdm
 
@@ -21,8 +21,8 @@ from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
from llama_stack.providers.utils.kvstore import kvstore_impl
 
-from .....apis.common.job_types import Job
-from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse, JobStatus
+from .....apis.common.job_types import Job, JobStatus
+from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
from .config import MetaReferenceEvalConfig
 
EVAL_TASKS_PREFIX = "benchmarks:"
@@ -102,7 +102,7 @@ class MetaReferenceEvalImpl(
# need job scheduler queue (ray/celery) w/ jobs api
job_id = str(len(self.jobs))
self.jobs[job_id] = res
- return Job(job_id=job_id)
+ return Job(job_id=job_id, status=JobStatus.completed)
 
async def _run_agent_generation(
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
@@ -216,17 +216,18 @@ class MetaReferenceEvalImpl(
 
return EvaluateResponse(generations=generations, scores=score_response.results)
 
- async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]:
+ async def job_status(self, benchmark_id: str, job_id: str) -> Job:
if job_id in self.jobs:
- return JobStatus.completed
+ return Job(job_id=job_id, status=JobStatus.completed)
 
- return None
+ raise ValueError(f"Job {job_id} not found")
 
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet")
 
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
- status = await self.job_status(benchmark_id, job_id)
+ job = await self.job_status(benchmark_id, job_id)
+ status = job.status
if not status or status != JobStatus.completed:
raise ValueError(f"Job is not completed, Status: {status.value}")
 
diff --git a/tests/integration/eval/test_eval.py b/tests/integration/eval/test_eval.py
index c4aa0fa1b..d1c3de519 100644
--- a/tests/integration/eval/test_eval.py
+++ b/tests/integration/eval/test_eval.py
@@ -94,7 +94,7 @@ def test_evaluate_benchmark(llama_stack_client, text_model_id, scoring_fn_id):
)
assert response.job_id == "0"
job_status = llama_stack_client.eval.jobs.status(job_id=response.job_id, benchmark_id=benchmark_id)
- assert job_status and job_status == "completed"
+ assert job_status and job_status.status == "completed"
 
eval_response = llama_stack_client.eval.jobs.retrieve(job_id=response.job_id, benchmark_id=benchmark_id)
assert eval_response is not None
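
A minimal usage sketch of the reworked API, assuming this patch is applied and llama_stack is importable; `impl` below is a hypothetical, already-configured Eval provider, not a name taken from this diff:

    from llama_stack.apis.common.job_types import Job, JobStatus

    # After this patch, Job is a pydantic model with both fields required:
    job = Job(job_id="0", status=JobStatus.completed)
    assert job.status is JobStatus.completed

    # Callers read the status off the returned Job instead of comparing the
    # return value itself (cf. the test change above). Hypothetical caller:
    #
    #     job = await impl.job_status(benchmark_id=benchmark_id, job_id=job_id)
    #     if job.status != JobStatus.completed:
    #         ...  # not done; meta-reference raises ValueError for unknown ids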