feat(api): define a more coherent jobs api across different flows

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-03-24 20:54:04 -04:00
parent 71ed47ea76
commit 0f50cfa561
15 changed files with 1864 additions and 1670 deletions

View file

@ -4,9 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional, Protocol, runtime_checkable
from typing import List, Literal, Optional, Protocol, runtime_checkable
from llama_stack.apis.common.job_types import Job
from pydantic import BaseModel
from llama_stack.apis.common.job_types import BaseJob
from llama_stack.apis.inference import (
InterleavedContent,
LogProbConfig,
@ -20,6 +22,14 @@ from llama_stack.apis.inference import (
from llama_stack.schema_utils import webmethod
class BatchInferenceJob(BaseJob, BaseModel):
type: Literal["batch_inference"] = "batch_inference"
class ListBatchInferenceJobsResponse(BaseModel):
data: list[BatchInferenceJob]
@runtime_checkable
class BatchInference(Protocol):
"""Batch inference API for generating completions and chat completions.
@ -38,7 +48,7 @@ class BatchInference(Protocol):
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> Job: ...
) -> BatchInferenceJob: ...
@webmethod(route="/batch-inference/chat-completion", method="POST")
async def chat_completion(
@ -52,4 +62,4 @@ class BatchInference(Protocol):
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> Job: ...
) -> BatchInferenceJob: ...

View file

@ -3,22 +3,68 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from datetime import datetime, timezone
from enum import Enum, unique
from typing import Any
from pydantic import BaseModel
from pydantic import BaseModel, Field, computed_field
from llama_stack.schema_utils import json_schema_type
@unique
class JobStatus(Enum):
completed = "completed"
in_progress = "in_progress"
failed = "failed"
unknown = "unknown"
new = "new"
scheduled = "scheduled"
running = "running"
paused = "paused"
resuming = "resuming"
cancelled = "cancelled"
failed = "failed"
completed = "completed"
@json_schema_type
class Job(BaseModel):
job_id: str
class JobStatusDetails(BaseModel):
status: JobStatus
message: str | None = None
timestamp: datetime
@json_schema_type
class JobArtifact(BaseModel):
name: str
# TODO: should it be a Literal / Enum?
type: str
# Any additional metadata the artifact may have
# TODO: is Any the right type here? What happens when the caller passes a value without a __repr__?
metadata: dict[str, Any] | None = None
# TODO: enforce type to be a URI
uri: str | None = None # points to /files
def _get_job_status_details(status: JobStatus) -> JobStatusDetails:
return JobStatusDetails(status=status, timestamp=datetime.now(timezone.utc))
class BaseJob(BaseModel):
id: str # TODO: make it a UUID?
artifacts: list[JobArtifact] = Field(default_factory=list)
events: list[JobStatusDetails] = Field(default_factory=lambda: [_get_job_status_details(JobStatus.new)])
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if "type" not in cls.__annotations__:
raise ValueError(f"Class {cls.__name__} must have a type field")
@computed_field
def status(self) -> JobStatus:
return self.events[-1].status
def update_status(self, value: JobStatus):
self.events.append(_get_job_status_details(value))

View file

@ -4,15 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from typing import Dict, Literal, Optional, Protocol, Union
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.common.job_types import BaseJob
from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@ -47,6 +46,14 @@ EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discrimin
register_schema(EvalCandidate, name="EvalCandidate")
class EvaluateJob(BaseJob, BaseModel):
type: Literal["eval"] = "eval"
class ListEvaluateJobsResponse(BaseModel):
data: list[EvaluateJob]
@json_schema_type
class BenchmarkConfig(BaseModel):
"""A benchmark configuration for evaluation.
@ -68,76 +75,30 @@ class BenchmarkConfig(BaseModel):
# we could optinally add any specific dataset config here
@json_schema_type
class EvaluateResponse(BaseModel):
"""The response from an evaluation.
:param generations: The generations from the evaluation.
:param scores: The scores from the evaluation.
"""
generations: List[Dict[str, Any]]
# each key in the dict is a scoring function name
scores: Dict[str, ScoringResult]
class Eval(Protocol):
"""Llama Stack Evaluation API for running evaluations on model and agent candidates."""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST")
async def run_eval(
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluate", method="POST")
async def evaluate(
self,
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
"""Run an evaluation on a benchmark.
) -> EvaluateJob: ...
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param benchmark_config: The configuration for the benchmark.
:return: The job that was created to run the evaluation.
"""
# CRUD operations on running jobs
@webmethod(route="/evaluate/jobs/{job_id:path}", method="GET")
async def get_evaluate_job(self, job_id: str) -> EvaluateJob: ...
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
async def evaluate_rows(
self,
benchmark_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
"""Evaluate a list of rows on a benchmark.
@webmethod(route="/evaluate/jobs", method="GET")
async def list_evaluate_jobs(self) -> ListEvaluateJobsResponse: ...
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param input_rows: The rows to evaluate.
:param scoring_functions: The scoring functions to use for the evaluation.
:param benchmark_config: The configuration for the benchmark.
:return: EvaluateResponse object containing generations and scores
"""
@webmethod(route="/evaluate/jobs/{job_id:path}", method="POST")
async def update_evaluate_job(self, job: EvaluateJob) -> EvaluateJob: ...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
"""Get the status of a job.
@webmethod(route="/evaluate/job/{job_id:path}", method="DELETE")
async def delete_evaluate_job(self, job_id: str) -> None: ...
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the status of.
:return: The status of the evaluationjob.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE")
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
"""Cancel a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to cancel.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
"""Get the result of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the result of.
:return: The result of the job.
"""
# Note: pause/resume/cancel are achieved as follows:
# - POST with status=paused
# - POST with status=resuming
# - POST with status=cancelled

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
@ -12,8 +11,7 @@ from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.common.training_types import Checkpoint
from llama_stack.apis.common.job_types import BaseJob
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@ -92,14 +90,6 @@ AlgorithmConfig = Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Fi
register_schema(AlgorithmConfig, name="AlgorithmConfig")
@json_schema_type
class PostTrainingJobLogStream(BaseModel):
"""Stream of logs from a finetuning job."""
job_uuid: str
log_lines: List[str]
@json_schema_type
class RLHFAlgorithm(Enum):
dpo = "dpo"
@ -135,41 +125,17 @@ class PostTrainingRLHFRequest(BaseModel):
logger_config: Dict[str, Any]
class PostTrainingJob(BaseModel):
job_uuid: str
@json_schema_type
class PostTrainingJobStatusResponse(BaseModel):
"""Status of a finetuning job."""
job_uuid: str
status: JobStatus
scheduled_at: Optional[datetime] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
resources_allocated: Optional[Dict[str, Any]] = None
checkpoints: List[Checkpoint] = Field(default_factory=list)
class PostTrainingJob(BaseJob, BaseModel):
type: Literal["post-training"] = "post-training"
class ListPostTrainingJobsResponse(BaseModel):
data: List[PostTrainingJob]
@json_schema_type
class PostTrainingJobArtifactsResponse(BaseModel):
"""Artifacts of a finetuning job."""
job_uuid: str
checkpoints: List[Checkpoint] = Field(default_factory=list)
# TODO(ashwin): metrics, evals
data: list[PostTrainingJob]
class PostTraining(Protocol):
# This is how you create a new job - POST against the root endpoint
@webmethod(route="/post-training/supervised-fine-tune", method="POST")
async def supervised_fine_tune(
self,
@ -196,14 +162,20 @@ class PostTraining(Protocol):
logger_config: Dict[str, Any],
) -> PostTrainingJob: ...
# CRUD operations on running jobs
@webmethod(route="/post-training/jobs/{job_id:path}", method="GET")
async def get_post_training_job(self, job_id: str) -> PostTrainingJob: ...
@webmethod(route="/post-training/jobs", method="GET")
async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
async def list_post_training_jobs(self) -> ListPostTrainingJobsResponse: ...
@webmethod(route="/post-training/job/status", method="GET")
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: ...
@webmethod(route="/post-training/jobs/{job_id:path}", method="POST")
async def update_post_training_job(self, job: PostTrainingJob) -> PostTrainingJob: ...
@webmethod(route="/post-training/job/cancel", method="POST")
async def cancel_training_job(self, job_uuid: str) -> None: ...
@webmethod(route="/post-training/job/{job_id:path}", method="DELETE")
async def delete_post_training_job(self, job_id: str) -> None: ...
@webmethod(route="/post-training/job/artifacts", method="GET")
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: ...
# Note: pause/resume/cancel are achieved as follows:
# - POST with status=paused
# - POST with status=resuming
# - POST with status=cancelled

View file

@ -5,10 +5,11 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Union
from typing import List, Literal, Optional, Protocol
from pydantic import BaseModel
from llama_stack.apis.common.job_types import BaseJob
from llama_stack.apis.inference import Message
from llama_stack.schema_utils import json_schema_type, webmethod
@ -34,11 +35,13 @@ class SyntheticDataGenerationRequest(BaseModel):
@json_schema_type
class SyntheticDataGenerationResponse(BaseModel):
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""
class SyntheticDataGenerationJob(BaseJob, BaseModel):
type: Literal["synthetic-data-generation"] = "synthetic-data-generation"
synthetic_data: List[Dict[str, Any]]
statistics: Optional[Dict[str, Any]] = None
@json_schema_type
class ListSyntheticDataGenerationJobsResponse(BaseModel):
items: list[SyntheticDataGenerationJob]
class SyntheticDataGeneration(Protocol):
@ -48,4 +51,24 @@ class SyntheticDataGeneration(Protocol):
dialogs: List[Message],
filtering_function: FilteringFunction = FilteringFunction.none,
model: Optional[str] = None,
) -> Union[SyntheticDataGenerationResponse]: ...
) -> SyntheticDataGenerationJob: ...
# CRUD operations on running jobs
@webmethod(route="/synthetic-data-generation/jobs/{job_id:path}", method="GET")
async def get_synthetic_data_generation_job(self) -> SyntheticDataGenerationJob: ...
@webmethod(route="/synthetic-data-generation/jobs", method="GET")
async def list_synthetic_data_generation_jobs(self) -> ListSyntheticDataGenerationJobsResponse: ...
@webmethod(route="/synthetic-data-generation/jobs/{job_id:path}", method="POST")
async def update_synthetic_data_generation_job(
self, job: SyntheticDataGenerationJob
) -> SyntheticDataGenerationJob: ...
@webmethod(route="/synthetic-data-generation/job/{job_id:path}", method="DELETE")
async def delete_synthetic_data_generation_job(self, job_id: str) -> None: ...
# Note: pause/resume/cancel are achieved as follows:
# - POST with status=paused
# - POST with status=resuming
# - POST with status=cancelled