mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
fix: fix jobs api literal return type (#1757)
# What does this PR do?

- We cannot directly return a literal type from an API endpoint, so the eval `job_status` endpoint now returns a `Job` object (which carries a `status` field) instead of a bare `JobStatus` value.

> Note: this is not the final jobs API change.

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan

<img width="837" alt="image" src="https://github.com/user-attachments/assets/18a17561-35f9-443d-987d-54afdd6ff40c" />

[//]: # (## Documentation)
This commit is contained in:
parent
d6887f46c6
commit
baf68c665c
7 changed files with 79 additions and 69 deletions
58
docs/_static/llama-stack-spec.html
vendored
58
docs/_static/llama-stack-spec.html
vendored
|
@ -2183,7 +2183,7 @@
|
|||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/JobStatus"
|
||||
"$ref": "#/components/schemas/Job"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -7648,16 +7648,6 @@
|
|||
"title": "PostTrainingJobArtifactsResponse",
|
||||
"description": "Artifacts of a finetuning job."
|
||||
},
|
||||
"JobStatus": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"completed",
|
||||
"in_progress",
|
||||
"failed",
|
||||
"scheduled"
|
||||
],
|
||||
"title": "JobStatus"
|
||||
},
|
||||
"PostTrainingJobStatusResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -7665,7 +7655,14 @@
|
|||
"type": "string"
|
||||
},
|
||||
"status": {
|
||||
"$ref": "#/components/schemas/JobStatus"
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"completed",
|
||||
"in_progress",
|
||||
"failed",
|
||||
"scheduled"
|
||||
],
|
||||
"title": "JobStatus"
|
||||
},
|
||||
"scheduled_at": {
|
||||
"type": "string",
|
||||
|
@ -8115,6 +8112,30 @@
|
|||
"title": "IterrowsResponse",
|
||||
"description": "A paginated list of rows from a dataset."
|
||||
},
|
||||
"Job": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"job_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"status": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"completed",
|
||||
"in_progress",
|
||||
"failed",
|
||||
"scheduled"
|
||||
],
|
||||
"title": "JobStatus"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"job_id",
|
||||
"status"
|
||||
],
|
||||
"title": "Job"
|
||||
},
|
||||
"ListAgentSessionsResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -9639,19 +9660,6 @@
|
|||
],
|
||||
"title": "RunEvalRequest"
|
||||
},
|
||||
"Job": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"job_id": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"job_id"
|
||||
],
|
||||
"title": "Job"
|
||||
},
|
||||
"RunShieldRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
|
45
docs/_static/llama-stack-spec.yaml
vendored
45
docs/_static/llama-stack-spec.yaml
vendored
|
@ -1491,7 +1491,7 @@ paths:
|
|||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/JobStatus'
|
||||
$ref: '#/components/schemas/Job'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
|
@ -5277,21 +5277,19 @@ components:
|
|||
- checkpoints
|
||||
title: PostTrainingJobArtifactsResponse
|
||||
description: Artifacts of a finetuning job.
|
||||
JobStatus:
|
||||
type: string
|
||||
enum:
|
||||
- completed
|
||||
- in_progress
|
||||
- failed
|
||||
- scheduled
|
||||
title: JobStatus
|
||||
PostTrainingJobStatusResponse:
|
||||
type: object
|
||||
properties:
|
||||
job_uuid:
|
||||
type: string
|
||||
status:
|
||||
$ref: '#/components/schemas/JobStatus'
|
||||
type: string
|
||||
enum:
|
||||
- completed
|
||||
- in_progress
|
||||
- failed
|
||||
- scheduled
|
||||
title: JobStatus
|
||||
scheduled_at:
|
||||
type: string
|
||||
format: date-time
|
||||
|
@ -5556,6 +5554,24 @@ components:
|
|||
- data
|
||||
title: IterrowsResponse
|
||||
description: A paginated list of rows from a dataset.
|
||||
Job:
|
||||
type: object
|
||||
properties:
|
||||
job_id:
|
||||
type: string
|
||||
status:
|
||||
type: string
|
||||
enum:
|
||||
- completed
|
||||
- in_progress
|
||||
- failed
|
||||
- scheduled
|
||||
title: JobStatus
|
||||
additionalProperties: false
|
||||
required:
|
||||
- job_id
|
||||
- status
|
||||
title: Job
|
||||
ListAgentSessionsResponse:
|
||||
type: object
|
||||
properties:
|
||||
|
@ -6550,15 +6566,6 @@ components:
|
|||
required:
|
||||
- benchmark_config
|
||||
title: RunEvalRequest
|
||||
Job:
|
||||
type: object
|
||||
properties:
|
||||
job_id:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
required:
|
||||
- job_id
|
||||
title: Job
|
||||
RunShieldRequest:
|
||||
type: object
|
||||
properties:
|
||||
|
|
|
@ -10,14 +10,14 @@ from pydantic import BaseModel
|
|||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Job(BaseModel):
|
||||
job_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class JobStatus(Enum):
|
||||
completed = "completed"
|
||||
in_progress = "in_progress"
|
||||
failed = "failed"
|
||||
scheduled = "scheduled"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Job(BaseModel):
|
||||
job_id: str
|
||||
status: JobStatus
|
||||
|
|
|
@ -10,7 +10,7 @@ from pydantic import BaseModel, Field
|
|||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.apis.agents import AgentConfig
|
||||
from llama_stack.apis.common.job_types import Job, JobStatus
|
||||
from llama_stack.apis.common.job_types import Job
|
||||
from llama_stack.apis.inference import SamplingParams, SystemMessage
|
||||
from llama_stack.apis.scoring import ScoringResult
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
|
@ -115,7 +115,7 @@ class Eval(Protocol):
|
|||
"""
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
|
||||
async def job_status(self, benchmark_id: str, job_id: str) -> JobStatus:
|
||||
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
|
||||
"""Get the status of a job.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
|
|
|
@ -14,13 +14,7 @@ from llama_stack.apis.common.content_types import (
|
|||
)
|
||||
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
|
||||
from llama_stack.apis.datasets import DatasetPurpose, DataSource
|
||||
from llama_stack.apis.eval import (
|
||||
BenchmarkConfig,
|
||||
Eval,
|
||||
EvaluateResponse,
|
||||
Job,
|
||||
JobStatus,
|
||||
)
|
||||
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEventType,
|
||||
|
@ -623,7 +617,7 @@ class EvalRouter(Eval):
|
|||
self,
|
||||
benchmark_id: str,
|
||||
job_id: str,
|
||||
) -> Optional[JobStatus]:
|
||||
) -> Job:
|
||||
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
|
||||
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
|
@ -21,8 +21,8 @@ from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
|
|||
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from .....apis.common.job_types import Job
|
||||
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse, JobStatus
|
||||
from .....apis.common.job_types import Job, JobStatus
|
||||
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
|
||||
from .config import MetaReferenceEvalConfig
|
||||
|
||||
EVAL_TASKS_PREFIX = "benchmarks:"
|
||||
|
@ -102,7 +102,7 @@ class MetaReferenceEvalImpl(
|
|||
# need job scheduler queue (ray/celery) w/ jobs api
|
||||
job_id = str(len(self.jobs))
|
||||
self.jobs[job_id] = res
|
||||
return Job(job_id=job_id)
|
||||
return Job(job_id=job_id, status=JobStatus.completed)
|
||||
|
||||
async def _run_agent_generation(
|
||||
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
|
||||
|
@ -216,17 +216,18 @@ class MetaReferenceEvalImpl(
|
|||
|
||||
return EvaluateResponse(generations=generations, scores=score_response.results)
|
||||
|
||||
async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]:
|
||||
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
|
||||
if job_id in self.jobs:
|
||||
return JobStatus.completed
|
||||
return Job(job_id=job_id, status=JobStatus.completed)
|
||||
|
||||
return None
|
||||
raise ValueError(f"Job {job_id} not found")
|
||||
|
||||
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
|
||||
raise NotImplementedError("Job cancel is not implemented yet")
|
||||
|
||||
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
|
||||
status = await self.job_status(benchmark_id, job_id)
|
||||
job = await self.job_status(benchmark_id, job_id)
|
||||
status = job.status
|
||||
if not status or status != JobStatus.completed:
|
||||
raise ValueError(f"Job is not completed, Status: {status.value}")
|
||||
|
||||
|
|
|
@ -94,7 +94,7 @@ def test_evaluate_benchmark(llama_stack_client, text_model_id, scoring_fn_id):
|
|||
)
|
||||
assert response.job_id == "0"
|
||||
job_status = llama_stack_client.eval.jobs.status(job_id=response.job_id, benchmark_id=benchmark_id)
|
||||
assert job_status and job_status == "completed"
|
||||
assert job_status and job_status.status == "completed"
|
||||
|
||||
eval_response = llama_stack_client.eval.jobs.retrieve(job_id=response.job_id, benchmark_id=benchmark_id)
|
||||
assert eval_response is not None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue