mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-02 11:50:01 +00:00
feat(api): define a more coherent jobs api across different flows
Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
71ed47ea76
commit
0f50cfa561
15 changed files with 1864 additions and 1670 deletions
|
|
@ -4,9 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List, Optional, Protocol, runtime_checkable
|
||||
from typing import List, Literal, Optional, Protocol, runtime_checkable
|
||||
|
||||
from llama_stack.apis.common.job_types import Job
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.common.job_types import BaseJob
|
||||
from llama_stack.apis.inference import (
|
||||
InterleavedContent,
|
||||
LogProbConfig,
|
||||
|
|
@ -20,6 +22,14 @@ from llama_stack.apis.inference import (
|
|||
from llama_stack.schema_utils import webmethod
|
||||
|
||||
|
||||
class BatchInferenceJob(BaseJob, BaseModel):
|
||||
type: Literal["batch_inference"] = "batch_inference"
|
||||
|
||||
|
||||
class ListBatchInferenceJobsResponse(BaseModel):
|
||||
data: list[BatchInferenceJob]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class BatchInference(Protocol):
|
||||
"""Batch inference API for generating completions and chat completions.
|
||||
|
|
@ -38,7 +48,7 @@ class BatchInference(Protocol):
|
|||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Job: ...
|
||||
) -> BatchInferenceJob: ...
|
||||
|
||||
@webmethod(route="/batch-inference/chat-completion", method="POST")
|
||||
async def chat_completion(
|
||||
|
|
@ -52,4 +62,4 @@ class BatchInference(Protocol):
|
|||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Job: ...
|
||||
) -> BatchInferenceJob: ...
|
||||
|
|
|
|||
|
|
@ -3,22 +3,68 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from enum import Enum
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum, unique
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@unique
|
||||
class JobStatus(Enum):
|
||||
completed = "completed"
|
||||
in_progress = "in_progress"
|
||||
failed = "failed"
|
||||
unknown = "unknown"
|
||||
new = "new"
|
||||
scheduled = "scheduled"
|
||||
running = "running"
|
||||
paused = "paused"
|
||||
resuming = "resuming"
|
||||
cancelled = "cancelled"
|
||||
failed = "failed"
|
||||
completed = "completed"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Job(BaseModel):
|
||||
job_id: str
|
||||
class JobStatusDetails(BaseModel):
|
||||
status: JobStatus
|
||||
message: str | None = None
|
||||
timestamp: datetime
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class JobArtifact(BaseModel):
|
||||
name: str
|
||||
|
||||
# TODO: should it be a Literal / Enum?
|
||||
type: str
|
||||
|
||||
# Any additional metadata the artifact may have
|
||||
# TODO: is Any the right type here? What happens when the caller passes a value without a __repr__?
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
# TODO: enforce type to be a URI
|
||||
uri: str | None = None # points to /files
|
||||
|
||||
|
||||
def _get_job_status_details(status: JobStatus) -> JobStatusDetails:
|
||||
return JobStatusDetails(status=status, timestamp=datetime.now(timezone.utc))
|
||||
|
||||
|
||||
class BaseJob(BaseModel):
|
||||
id: str # TODO: make it a UUID?
|
||||
|
||||
artifacts: list[JobArtifact] = Field(default_factory=list)
|
||||
events: list[JobStatusDetails] = Field(default_factory=lambda: [_get_job_status_details(JobStatus.new)])
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if "type" not in cls.__annotations__:
|
||||
raise ValueError(f"Class {cls.__name__} must have a type field")
|
||||
|
||||
@computed_field
|
||||
def status(self) -> JobStatus:
|
||||
return self.events[-1].status
|
||||
|
||||
def update_status(self, value: JobStatus):
|
||||
self.events.append(_get_job_status_details(value))
|
||||
|
|
|
|||
|
|
@ -4,15 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
||||
from typing import Dict, Literal, Optional, Protocol, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.apis.agents import AgentConfig
|
||||
from llama_stack.apis.common.job_types import Job
|
||||
from llama_stack.apis.common.job_types import BaseJob
|
||||
from llama_stack.apis.inference import SamplingParams, SystemMessage
|
||||
from llama_stack.apis.scoring import ScoringResult
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
|
@ -47,6 +46,14 @@ EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discrimin
|
|||
register_schema(EvalCandidate, name="EvalCandidate")
|
||||
|
||||
|
||||
class EvaluateJob(BaseJob, BaseModel):
|
||||
type: Literal["eval"] = "eval"
|
||||
|
||||
|
||||
class ListEvaluateJobsResponse(BaseModel):
|
||||
data: list[EvaluateJob]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BenchmarkConfig(BaseModel):
|
||||
"""A benchmark configuration for evaluation.
|
||||
|
|
@ -68,76 +75,30 @@ class BenchmarkConfig(BaseModel):
|
|||
# we could optionally add any specific dataset config here
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class EvaluateResponse(BaseModel):
|
||||
"""The response from an evaluation.
|
||||
|
||||
:param generations: The generations from the evaluation.
|
||||
:param scores: The scores from the evaluation.
|
||||
"""
|
||||
|
||||
generations: List[Dict[str, Any]]
|
||||
# each key in the dict is a scoring function name
|
||||
scores: Dict[str, ScoringResult]
|
||||
|
||||
|
||||
class Eval(Protocol):
|
||||
"""Llama Stack Evaluation API for running evaluations on model and agent candidates."""
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST")
|
||||
async def run_eval(
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluate", method="POST")
|
||||
async def evaluate(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> Job:
|
||||
"""Run an evaluation on a benchmark.
|
||||
) -> EvaluateJob: ...
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param benchmark_config: The configuration for the benchmark.
|
||||
:return: The job that was created to run the evaluation.
|
||||
"""
|
||||
# CRUD operations on running jobs
|
||||
@webmethod(route="/evaluate/jobs/{job_id:path}", method="GET")
|
||||
async def get_evaluate_job(self, job_id: str) -> EvaluateJob: ...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
|
||||
async def evaluate_rows(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: List[str],
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
"""Evaluate a list of rows on a benchmark.
|
||||
@webmethod(route="/evaluate/jobs", method="GET")
|
||||
async def list_evaluate_jobs(self) -> ListEvaluateJobsResponse: ...
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param input_rows: The rows to evaluate.
|
||||
:param scoring_functions: The scoring functions to use for the evaluation.
|
||||
:param benchmark_config: The configuration for the benchmark.
|
||||
:return: EvaluateResponse object containing generations and scores
|
||||
"""
|
||||
@webmethod(route="/evaluate/jobs/{job_id:path}", method="POST")
|
||||
async def update_evaluate_job(self, job: EvaluateJob) -> EvaluateJob: ...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
|
||||
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
|
||||
"""Get the status of a job.
|
||||
@webmethod(route="/evaluate/job/{job_id:path}", method="DELETE")
|
||||
async def delete_evaluate_job(self, job_id: str) -> None: ...
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param job_id: The ID of the job to get the status of.
|
||||
:return: The status of the evaluation job.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE")
|
||||
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
|
||||
"""Cancel a job.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param job_id: The ID of the job to cancel.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
|
||||
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
|
||||
"""Get the result of a job.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param job_id: The ID of the job to get the result of.
|
||||
:return: The result of the job.
|
||||
"""
|
||||
# Note: pause/resume/cancel are achieved as follows:
|
||||
# - POST with status=paused
|
||||
# - POST with status=resuming
|
||||
# - POST with status=cancelled
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
||||
|
||||
|
|
@ -12,8 +11,7 @@ from pydantic import BaseModel, Field
|
|||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.common.job_types import JobStatus
|
||||
from llama_stack.apis.common.training_types import Checkpoint
|
||||
from llama_stack.apis.common.job_types import BaseJob
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
|
|
@ -92,14 +90,6 @@ AlgorithmConfig = Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Fi
|
|||
register_schema(AlgorithmConfig, name="AlgorithmConfig")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PostTrainingJobLogStream(BaseModel):
|
||||
"""Stream of logs from a finetuning job."""
|
||||
|
||||
job_uuid: str
|
||||
log_lines: List[str]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RLHFAlgorithm(Enum):
|
||||
dpo = "dpo"
|
||||
|
|
@ -135,41 +125,17 @@ class PostTrainingRLHFRequest(BaseModel):
|
|||
logger_config: Dict[str, Any]
|
||||
|
||||
|
||||
class PostTrainingJob(BaseModel):
|
||||
job_uuid: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PostTrainingJobStatusResponse(BaseModel):
|
||||
"""Status of a finetuning job."""
|
||||
|
||||
job_uuid: str
|
||||
status: JobStatus
|
||||
|
||||
scheduled_at: Optional[datetime] = None
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
|
||||
resources_allocated: Optional[Dict[str, Any]] = None
|
||||
|
||||
checkpoints: List[Checkpoint] = Field(default_factory=list)
|
||||
class PostTrainingJob(BaseJob, BaseModel):
|
||||
type: Literal["post-training"] = "post-training"
|
||||
|
||||
|
||||
class ListPostTrainingJobsResponse(BaseModel):
|
||||
data: List[PostTrainingJob]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PostTrainingJobArtifactsResponse(BaseModel):
|
||||
"""Artifacts of a finetuning job."""
|
||||
|
||||
job_uuid: str
|
||||
checkpoints: List[Checkpoint] = Field(default_factory=list)
|
||||
|
||||
# TODO(ashwin): metrics, evals
|
||||
data: list[PostTrainingJob]
|
||||
|
||||
|
||||
class PostTraining(Protocol):
|
||||
# This is how you create a new job - POST against the root endpoint
|
||||
@webmethod(route="/post-training/supervised-fine-tune", method="POST")
|
||||
async def supervised_fine_tune(
|
||||
self,
|
||||
|
|
@ -196,14 +162,20 @@ class PostTraining(Protocol):
|
|||
logger_config: Dict[str, Any],
|
||||
) -> PostTrainingJob: ...
|
||||
|
||||
# CRUD operations on running jobs
|
||||
@webmethod(route="/post-training/jobs/{job_id:path}", method="GET")
|
||||
async def get_post_training_job(self, job_id: str) -> PostTrainingJob: ...
|
||||
|
||||
@webmethod(route="/post-training/jobs", method="GET")
|
||||
async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
|
||||
async def list_post_training_jobs(self) -> ListPostTrainingJobsResponse: ...
|
||||
|
||||
@webmethod(route="/post-training/job/status", method="GET")
|
||||
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: ...
|
||||
@webmethod(route="/post-training/jobs/{job_id:path}", method="POST")
|
||||
async def update_post_training_job(self, job: PostTrainingJob) -> PostTrainingJob: ...
|
||||
|
||||
@webmethod(route="/post-training/job/cancel", method="POST")
|
||||
async def cancel_training_job(self, job_uuid: str) -> None: ...
|
||||
@webmethod(route="/post-training/job/{job_id:path}", method="DELETE")
|
||||
async def delete_post_training_job(self, job_id: str) -> None: ...
|
||||
|
||||
@webmethod(route="/post-training/job/artifacts", method="GET")
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: ...
|
||||
# Note: pause/resume/cancel are achieved as follows:
|
||||
# - POST with status=paused
|
||||
# - POST with status=resuming
|
||||
# - POST with status=cancelled
|
||||
|
|
|
|||
|
|
@ -5,10 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Protocol, Union
|
||||
from typing import List, Literal, Optional, Protocol
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.common.job_types import BaseJob
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
|
@ -34,11 +35,13 @@ class SyntheticDataGenerationRequest(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class SyntheticDataGenerationResponse(BaseModel):
|
||||
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""
|
||||
class SyntheticDataGenerationJob(BaseJob, BaseModel):
|
||||
type: Literal["synthetic-data-generation"] = "synthetic-data-generation"
|
||||
|
||||
synthetic_data: List[Dict[str, Any]]
|
||||
statistics: Optional[Dict[str, Any]] = None
|
||||
|
||||
@json_schema_type
|
||||
class ListSyntheticDataGenerationJobsResponse(BaseModel):
|
||||
items: list[SyntheticDataGenerationJob]
|
||||
|
||||
|
||||
class SyntheticDataGeneration(Protocol):
|
||||
|
|
@ -48,4 +51,24 @@ class SyntheticDataGeneration(Protocol):
|
|||
dialogs: List[Message],
|
||||
filtering_function: FilteringFunction = FilteringFunction.none,
|
||||
model: Optional[str] = None,
|
||||
) -> Union[SyntheticDataGenerationResponse]: ...
|
||||
) -> SyntheticDataGenerationJob: ...
|
||||
|
||||
# CRUD operations on running jobs
|
||||
@webmethod(route="/synthetic-data-generation/jobs/{job_id:path}", method="GET")
|
||||
async def get_synthetic_data_generation_job(self) -> SyntheticDataGenerationJob: ...
|
||||
|
||||
@webmethod(route="/synthetic-data-generation/jobs", method="GET")
|
||||
async def list_synthetic_data_generation_jobs(self) -> ListSyntheticDataGenerationJobsResponse: ...
|
||||
|
||||
@webmethod(route="/synthetic-data-generation/jobs/{job_id:path}", method="POST")
|
||||
async def update_synthetic_data_generation_job(
|
||||
self, job: SyntheticDataGenerationJob
|
||||
) -> SyntheticDataGenerationJob: ...
|
||||
|
||||
@webmethod(route="/synthetic-data-generation/job/{job_id:path}", method="DELETE")
|
||||
async def delete_synthetic_data_generation_job(self, job_id: str) -> None: ...
|
||||
|
||||
# Note: pause/resume/cancel are achieved as follows:
|
||||
# - POST with status=paused
|
||||
# - POST with status=resuming
|
||||
# - POST with status=cancelled
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue