feat(api): define a more coherent jobs api across different flows

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Ihar Hrachyshka 2025-03-24 20:54:04 -04:00
parent 71ed47ea76
commit 0f50cfa561
15 changed files with 1864 additions and 1670 deletions

File diff suppressed because it is too large

File diff suppressed because it is too large

View file

@@ -4,9 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import List, Optional, Protocol, runtime_checkable
-from llama_stack.apis.common.job_types import Job
+from typing import List, Literal, Optional, Protocol, runtime_checkable
+from pydantic import BaseModel
+from llama_stack.apis.common.job_types import BaseJob
 from llama_stack.apis.inference import (
     InterleavedContent,
     LogProbConfig,
@@ -20,6 +22,14 @@ from llama_stack.apis.inference import (
 from llama_stack.schema_utils import webmethod
+class BatchInferenceJob(BaseJob, BaseModel):
+    type: Literal["batch_inference"] = "batch_inference"
+class ListBatchInferenceJobsResponse(BaseModel):
+    data: list[BatchInferenceJob]
 @runtime_checkable
 class BatchInference(Protocol):
     """Batch inference API for generating completions and chat completions.
@@ -38,7 +48,7 @@ class BatchInference(Protocol):
         sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Job: ...
+    ) -> BatchInferenceJob: ...
     @webmethod(route="/batch-inference/chat-completion", method="POST")
     async def chat_completion(
@@ -52,4 +62,4 @@ class BatchInference(Protocol):
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         response_format: Optional[ResponseFormat] = None,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Job: ...
+    ) -> BatchInferenceJob: ...

View file

@@ -3,22 +3,68 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from enum import Enum
+from datetime import datetime, timezone
+from enum import Enum, unique
+from typing import Any
-from pydantic import BaseModel
+from pydantic import BaseModel, Field, computed_field
 from llama_stack.schema_utils import json_schema_type
+@unique
 class JobStatus(Enum):
-    completed = "completed"
-    in_progress = "in_progress"
-    failed = "failed"
+    unknown = "unknown"
+    new = "new"
     scheduled = "scheduled"
+    running = "running"
+    paused = "paused"
+    resuming = "resuming"
     cancelled = "cancelled"
+    failed = "failed"
+    completed = "completed"
 @json_schema_type
-class Job(BaseModel):
-    job_id: str
+class JobStatusDetails(BaseModel):
     status: JobStatus
+    message: str | None = None
+    timestamp: datetime
+@json_schema_type
+class JobArtifact(BaseModel):
+    name: str
+    # TODO: should it be a Literal / Enum?
+    type: str
+    # Any additional metadata the artifact may have
+    # TODO: is Any the right type here? What happens when the caller passes a value without a __repr__?
+    metadata: dict[str, Any] | None = None
+    # TODO: enforce type to be a URI
+    uri: str | None = None  # points to /files
+def _get_job_status_details(status: JobStatus) -> JobStatusDetails:
+    return JobStatusDetails(status=status, timestamp=datetime.now(timezone.utc))
+class BaseJob(BaseModel):
+    id: str  # TODO: make it a UUID?
+    artifacts: list[JobArtifact] = Field(default_factory=list)
+    events: list[JobStatusDetails] = Field(default_factory=lambda: [_get_job_status_details(JobStatus.new)])
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        if "type" not in cls.__annotations__:
+            raise ValueError(f"Class {cls.__name__} must have a type field")
+    @computed_field
+    def status(self) -> JobStatus:
+        return self.events[-1].status
+    def update_status(self, value: JobStatus):
+        self.events.append(_get_job_status_details(value))
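
For illustration, a minimal sketch of how a flow-specific job type built on the new BaseJob behaves (ExampleJob is a hypothetical subclass; the real ones added in this commit are BatchInferenceJob, EvaluateJob, PostTrainingJob and SyntheticDataGenerationJob):

    from typing import Literal

    from llama_stack.apis.common.job_types import BaseJob, JobStatus

    class ExampleJob(BaseJob):
        # __init_subclass__ above rejects subclasses that do not declare a `type` field
        type: Literal["example"] = "example"

    job = ExampleJob(id="job-1")
    assert job.status == JobStatus.new        # computed from the initial "new" event
    job.update_status(JobStatus.running)
    job.update_status(JobStatus.completed)
    assert job.status == JobStatus.completed  # always the status of the most recent event
    assert len(job.events) == 3               # the full status history is preserved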

View file

@@ -4,15 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict, List, Literal, Optional, Protocol, Union
+from typing import Dict, Literal, Optional, Protocol, Union
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated
 from llama_stack.apis.agents import AgentConfig
-from llama_stack.apis.common.job_types import Job
+from llama_stack.apis.common.job_types import BaseJob
 from llama_stack.apis.inference import SamplingParams, SystemMessage
-from llama_stack.apis.scoring import ScoringResult
 from llama_stack.apis.scoring_functions import ScoringFnParams
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@@ -47,6 +46,14 @@ EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discrimin
 register_schema(EvalCandidate, name="EvalCandidate")
+class EvaluateJob(BaseJob, BaseModel):
+    type: Literal["eval"] = "eval"
+class ListEvaluateJobsResponse(BaseModel):
+    data: list[EvaluateJob]
 @json_schema_type
 class BenchmarkConfig(BaseModel):
     """A benchmark configuration for evaluation.
@@ -68,76 +75,30 @@ class BenchmarkConfig(BaseModel):
     # we could optinally add any specific dataset config here
-@json_schema_type
-class EvaluateResponse(BaseModel):
-    """The response from an evaluation.
-    :param generations: The generations from the evaluation.
-    :param scores: The scores from the evaluation.
-    """
-    generations: List[Dict[str, Any]]
-    # each key in the dict is a scoring function name
-    scores: Dict[str, ScoringResult]
 class Eval(Protocol):
     """Llama Stack Evaluation API for running evaluations on model and agent candidates."""
-    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST")
-    async def run_eval(
+    @webmethod(route="/eval/benchmarks/{benchmark_id}/evaluate", method="POST")
+    async def evaluate(
         self,
         benchmark_id: str,
         benchmark_config: BenchmarkConfig,
-    ) -> Job:
-        """Run an evaluation on a benchmark.
-        :param benchmark_id: The ID of the benchmark to run the evaluation on.
-        :param benchmark_config: The configuration for the benchmark.
-        :return: The job that was created to run the evaluation.
-        """
-    @webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
-    async def evaluate_rows(
-        self,
-        benchmark_id: str,
-        input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
-        benchmark_config: BenchmarkConfig,
-    ) -> EvaluateResponse:
-        """Evaluate a list of rows on a benchmark.
-        :param benchmark_id: The ID of the benchmark to run the evaluation on.
-        :param input_rows: The rows to evaluate.
-        :param scoring_functions: The scoring functions to use for the evaluation.
-        :param benchmark_config: The configuration for the benchmark.
-        :return: EvaluateResponse object containing generations and scores
-        """
-    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
-    async def job_status(self, benchmark_id: str, job_id: str) -> Job:
-        """Get the status of a job.
-        :param benchmark_id: The ID of the benchmark to run the evaluation on.
-        :param job_id: The ID of the job to get the status of.
-        :return: The status of the evaluationjob.
-        """
-        ...
-    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE")
-    async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
-        """Cancel a job.
-        :param benchmark_id: The ID of the benchmark to run the evaluation on.
-        :param job_id: The ID of the job to cancel.
-        """
-        ...
-    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
-    async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
-        """Get the result of a job.
-        :param benchmark_id: The ID of the benchmark to run the evaluation on.
-        :param job_id: The ID of the job to get the result of.
-        :return: The result of the job.
-        """
+    ) -> EvaluateJob: ...
+    # CRUD operations on running jobs
+    @webmethod(route="/evaluate/jobs/{job_id:path}", method="GET")
+    async def get_evaluate_job(self, job_id: str) -> EvaluateJob: ...
+    @webmethod(route="/evaluate/jobs", method="GET")
+    async def list_evaluate_jobs(self) -> ListEvaluateJobsResponse: ...
+    @webmethod(route="/evaluate/jobs/{job_id:path}", method="POST")
+    async def update_evaluate_job(self, job: EvaluateJob) -> EvaluateJob: ...
+    @webmethod(route="/evaluate/job/{job_id:path}", method="DELETE")
+    async def delete_evaluate_job(self, job_id: str) -> None: ...
+    # Note: pause/resume/cancel are achieved as follows:
+    # - POST with status=paused
+    # - POST with status=resuming
+    # - POST with status=cancelled

View file

@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from datetime import datetime
 from enum import Enum
 from typing import Any, Dict, List, Literal, Optional, Protocol, Union
@@ -12,8 +11,7 @@ from pydantic import BaseModel, Field
 from typing_extensions import Annotated
 from llama_stack.apis.common.content_types import URL
-from llama_stack.apis.common.job_types import JobStatus
-from llama_stack.apis.common.training_types import Checkpoint
+from llama_stack.apis.common.job_types import BaseJob
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@@ -92,14 +90,6 @@ AlgorithmConfig = Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Fi
 register_schema(AlgorithmConfig, name="AlgorithmConfig")
-@json_schema_type
-class PostTrainingJobLogStream(BaseModel):
-    """Stream of logs from a finetuning job."""
-    job_uuid: str
-    log_lines: List[str]
 @json_schema_type
 class RLHFAlgorithm(Enum):
     dpo = "dpo"
@@ -135,41 +125,17 @@ class PostTrainingRLHFRequest(BaseModel):
     logger_config: Dict[str, Any]
-class PostTrainingJob(BaseModel):
-    job_uuid: str
 @json_schema_type
-class PostTrainingJobStatusResponse(BaseModel):
-    """Status of a finetuning job."""
-    job_uuid: str
-    status: JobStatus
-    scheduled_at: Optional[datetime] = None
-    started_at: Optional[datetime] = None
-    completed_at: Optional[datetime] = None
-    resources_allocated: Optional[Dict[str, Any]] = None
-    checkpoints: List[Checkpoint] = Field(default_factory=list)
+class PostTrainingJob(BaseJob, BaseModel):
+    type: Literal["post-training"] = "post-training"
 class ListPostTrainingJobsResponse(BaseModel):
-    data: List[PostTrainingJob]
+    data: list[PostTrainingJob]
-@json_schema_type
-class PostTrainingJobArtifactsResponse(BaseModel):
-    """Artifacts of a finetuning job."""
-    job_uuid: str
-    checkpoints: List[Checkpoint] = Field(default_factory=list)
-    # TODO(ashwin): metrics, evals
 class PostTraining(Protocol):
+    # This is how you create a new job - POST against the root endpoint
     @webmethod(route="/post-training/supervised-fine-tune", method="POST")
     async def supervised_fine_tune(
         self,
@@ -196,14 +162,20 @@ class PostTraining(Protocol):
         logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...
+    # CRUD operations on running jobs
+    @webmethod(route="/post-training/jobs/{job_id:path}", method="GET")
+    async def get_post_training_job(self, job_id: str) -> PostTrainingJob: ...
     @webmethod(route="/post-training/jobs", method="GET")
-    async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
+    async def list_post_training_jobs(self) -> ListPostTrainingJobsResponse: ...
-    @webmethod(route="/post-training/job/status", method="GET")
-    async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: ...
+    @webmethod(route="/post-training/jobs/{job_id:path}", method="POST")
+    async def update_post_training_job(self, job: PostTrainingJob) -> PostTrainingJob: ...
-    @webmethod(route="/post-training/job/cancel", method="POST")
-    async def cancel_training_job(self, job_uuid: str) -> None: ...
+    @webmethod(route="/post-training/job/{job_id:path}", method="DELETE")
+    async def delete_post_training_job(self, job_id: str) -> None: ...
-    @webmethod(route="/post-training/job/artifacts", method="GET")
-    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: ...
+    # Note: pause/resume/cancel are achieved as follows:
+    # - POST with status=paused
+    # - POST with status=resuming
+    # - POST with status=cancelled

View file

@@ -5,10 +5,11 @@
 # the root directory of this source tree.
 from enum import Enum
-from typing import Any, Dict, List, Optional, Protocol, Union
+from typing import List, Literal, Optional, Protocol
 from pydantic import BaseModel
+from llama_stack.apis.common.job_types import BaseJob
 from llama_stack.apis.inference import Message
 from llama_stack.schema_utils import json_schema_type, webmethod
@@ -34,11 +35,13 @@ class SyntheticDataGenerationRequest(BaseModel):
 @json_schema_type
-class SyntheticDataGenerationResponse(BaseModel):
-    """Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""
-    synthetic_data: List[Dict[str, Any]]
-    statistics: Optional[Dict[str, Any]] = None
+class SyntheticDataGenerationJob(BaseJob, BaseModel):
+    type: Literal["synthetic-data-generation"] = "synthetic-data-generation"
+@json_schema_type
+class ListSyntheticDataGenerationJobsResponse(BaseModel):
+    items: list[SyntheticDataGenerationJob]
 class SyntheticDataGeneration(Protocol):
@@ -48,4 +51,24 @@ class SyntheticDataGeneration(Protocol):
         dialogs: List[Message],
         filtering_function: FilteringFunction = FilteringFunction.none,
         model: Optional[str] = None,
-    ) -> Union[SyntheticDataGenerationResponse]: ...
+    ) -> SyntheticDataGenerationJob: ...
+    # CRUD operations on running jobs
+    @webmethod(route="/synthetic-data-generation/jobs/{job_id:path}", method="GET")
+    async def get_synthetic_data_generation_job(self) -> SyntheticDataGenerationJob: ...
+    @webmethod(route="/synthetic-data-generation/jobs", method="GET")
+    async def list_synthetic_data_generation_jobs(self) -> ListSyntheticDataGenerationJobsResponse: ...
+    @webmethod(route="/synthetic-data-generation/jobs/{job_id:path}", method="POST")
+    async def update_synthetic_data_generation_job(
+        self, job: SyntheticDataGenerationJob
+    ) -> SyntheticDataGenerationJob: ...
+    @webmethod(route="/synthetic-data-generation/job/{job_id:path}", method="DELETE")
+    async def delete_synthetic_data_generation_job(self, job_id: str) -> None: ...
+    # Note: pause/resume/cancel are achieved as follows:
+    # - POST with status=paused
+    # - POST with status=resuming
+    # - POST with status=cancelled

View file

@@ -16,7 +16,7 @@ from llama_stack.apis.common.content_types import (
 from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import DatasetPurpose, DataSource
-from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
+from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateJob, ListEvaluateJobsResponse
 from llama_stack.apis.inference import (
     BatchChatCompletionResponse,
     BatchCompletionResponse,
@@ -779,61 +779,32 @@ class EvalRouter(Eval):
         logger.debug("EvalRouter.shutdown")
         pass
-    async def run_eval(
+    async def evaluate(
         self,
         benchmark_id: str,
         benchmark_config: BenchmarkConfig,
-    ) -> Job:
-        logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
-        return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
+    ) -> EvaluateJob:
+        logger.debug(f"EvalRouter.evaluate: {benchmark_id}")
+        return await self.routing_table.get_provider_impl(benchmark_id).evaluate(
             benchmark_id=benchmark_id,
             benchmark_config=benchmark_config,
         )
-    async def evaluate_rows(
-        self,
-        benchmark_id: str,
-        input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
-        benchmark_config: BenchmarkConfig,
-    ) -> EvaluateResponse:
-        logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
-        return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
-            benchmark_id=benchmark_id,
-            input_rows=input_rows,
-            scoring_functions=scoring_functions,
-            benchmark_config=benchmark_config,
-        )
-    async def job_status(
-        self,
-        benchmark_id: str,
-        job_id: str,
-    ) -> Job:
-        logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
-        return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
-    async def job_cancel(
-        self,
-        benchmark_id: str,
-        job_id: str,
-    ) -> None:
-        logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
-        await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
-            benchmark_id,
-            job_id,
-        )
-    async def job_result(
-        self,
-        benchmark_id: str,
-        job_id: str,
-    ) -> EvaluateResponse:
-        logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
-        return await self.routing_table.get_provider_impl(benchmark_id).job_result(
-            benchmark_id,
-            job_id,
-        )
+    async def get_evaluate_job(self, job_id: str) -> EvaluateJob:
+        logger.debug(f"EvalRouter.get_evaluate_job: {job_id}")
+        return await self.routing_table.get_provider_impl("eval").get_evaluate_job(job_id)
+    async def list_evaluate_jobs(self) -> ListEvaluateJobsResponse:
+        logger.debug("EvalRouter.list_evaluate_jobs")
+        return await self.routing_table.get_provider_impl("eval").list_evaluate_jobs()
+    async def update_evaluate_job(self, job: EvaluateJob) -> EvaluateJob:
+        logger.debug(f"EvalRouter.update_evaluate_job: {job.id}")
+        return await self.routing_table.get_provider_impl("eval").update_evaluate_job(job)
+    async def delete_evaluate_job(self, job_id: str) -> None:
+        logger.debug(f"EvalRouter.delete_evaluate_job: {job_id}")
+        return await self.routing_table.get_provider_impl("eval").delete_evaluate_job(job_id)
 class ToolRuntimeRouter(ToolRuntime):

View file

@@ -20,9 +20,10 @@ from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
 )
 from llama_stack.providers.utils.common.data_schema_validator import ColumnName
 from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.schema_utils import webmethod
-from .....apis.common.job_types import Job, JobStatus
-from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
+from .....apis.common.job_types import JobArtifact, JobStatus
+from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateJob, ListEvaluateJobsResponse
 from .config import MetaReferenceEvalConfig
 EVAL_TASKS_PREFIX = "benchmarks:"
@@ -75,11 +76,11 @@ class MetaReferenceEvalImpl(
         )
         self.benchmarks[task_def.identifier] = task_def
-    async def run_eval(
+    async def evaluate(
         self,
         benchmark_id: str,
         benchmark_config: BenchmarkConfig,
-    ) -> Job:
+    ) -> EvaluateJob:
         task_def = self.benchmarks[benchmark_id]
         dataset_id = task_def.dataset_id
         scoring_functions = task_def.scoring_functions
@@ -91,18 +92,35 @@ class MetaReferenceEvalImpl(
             dataset_id=dataset_id,
             limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
         )
-        res = await self.evaluate_rows(
+        generations, scoring_results = await self._evaluate_rows(
             benchmark_id=benchmark_id,
             input_rows=all_rows.data,
             scoring_functions=scoring_functions,
             benchmark_config=benchmark_config,
         )
+        artifacts = [
+            JobArtifact(
+                type="generation",
+                name=f"generation-{i}",
+                metadata=generation,
+            )
+            for i, generation in enumerate(generations)
+        ] + [
+            JobArtifact(
+                type="scoring_results",
+                name="scoring_results",
+                metadata=scoring_results,
+            )
+        ]
         # TODO: currently needs to wait for generation before returning
         # need job scheduler queue (ray/celery) w/ jobs api
         job_id = str(len(self.jobs))
-        self.jobs[job_id] = res
-        return Job(job_id=job_id, status=JobStatus.completed)
+        job = EvaluateJob(id=job_id, artifacts=artifacts)
+        job.update_status(JobStatus.completed)
+        self.jobs[job_id] = job
+        return job
     async def _run_agent_generation(
         self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
@@ -182,13 +200,13 @@ class MetaReferenceEvalImpl(
         return generations
-    async def evaluate_rows(
+    async def _evaluate_rows(
         self,
         benchmark_id: str,
         input_rows: List[Dict[str, Any]],
         scoring_functions: List[str],
         benchmark_config: BenchmarkConfig,
-    ) -> EvaluateResponse:
+    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
         candidate = benchmark_config.eval_candidate
         if candidate.type == "agent":
             generations = await self._run_agent_generation(input_rows, benchmark_config)
@@ -214,21 +232,26 @@ class MetaReferenceEvalImpl(
             input_rows=score_input_rows, scoring_functions=scoring_functions_dict
         )
-        return EvaluateResponse(generations=generations, scores=score_response.results)
+        return generations, score_response.results
-    async def job_status(self, benchmark_id: str, job_id: str) -> Job:
-        if job_id in self.jobs:
-            return Job(job_id=job_id, status=JobStatus.completed)
-        raise ValueError(f"Job {job_id} not found")
-    async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
-        raise NotImplementedError("Job cancel is not implemented yet")
-    async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
-        job = await self.job_status(benchmark_id, job_id)
-        status = job.status
-        if not status or status != JobStatus.completed:
-            raise ValueError(f"Job is not completed, Status: {status.value}")
+    # CRUD operations on running jobs
+    @webmethod(route="/evaluate/jobs/{job_id:path}", method="GET")
+    async def get_evaluate_job(self, job_id: str) -> EvaluateJob:
         return self.jobs[job_id]
+    @webmethod(route="/evaluate/jobs", method="GET")
+    async def list_evaluate_jobs(self) -> ListEvaluateJobsResponse:
+        return ListEvaluateJobsResponse(data=list(self.jobs.values()))
+    @webmethod(route="/evaluate/jobs/{job_id:path}", method="POST")
+    async def update_evaluate_job(self, job: EvaluateJob) -> EvaluateJob:
+        raise NotImplementedError
+    @webmethod(route="/evaluate/job/{job_id:path}", method="DELETE")
+    async def delete_evaluate_job(self, job_id: str) -> None:
+        raise NotImplementedError
+    # Note: pause/resume/cancel are achieved as follows:
+    # - POST with status=paused
+    # - POST with status=resuming
+    # - POST with status=cancelled

View file

@@ -10,14 +10,10 @@ from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.post_training import (
     AlgorithmConfig,
-    Checkpoint,
     DPOAlignmentConfig,
-    JobStatus,
     ListPostTrainingJobsResponse,
     LoraFinetuningConfig,
     PostTrainingJob,
-    PostTrainingJobArtifactsResponse,
-    PostTrainingJobStatusResponse,
     TrainingConfig,
 )
 from llama_stack.providers.inline.post_training.torchtune.config import (
@@ -54,15 +50,6 @@ class TorchtunePostTrainingImpl:
     async def shutdown(self) -> None:
         await self._scheduler.shutdown()
-    @staticmethod
-    def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
-        return JobArtifact(
-            type=TrainingArtifactType.CHECKPOINT.value,
-            name=checkpoint.identifier,
-            uri=checkpoint.path,
-            metadata=dict(checkpoint),
-        )
     @staticmethod
     def _resources_stats_to_artifact(resources_stats: Dict[str, Any]) -> JobArtifact:
         return JobArtifact(
@@ -98,14 +85,14 @@ class TorchtunePostTrainingImpl:
                 self.datasetio_api,
                 self.datasets_api,
             )
            await recipe.setup()
            resources_allocated, checkpoints = await recipe.train()
            on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
            for checkpoint in checkpoints:
-                artifact = self._checkpoint_to_artifact(checkpoint)
-                on_artifact_collected_cb(artifact)
+                on_artifact_collected_cb(checkpoint)
            on_status_change_cb(SchedulerJobStatus.completed)
            on_log_message_cb("Lora finetuning completed")
@@ -113,6 +100,8 @@ class TorchtunePostTrainingImpl:
             raise NotImplementedError()
         job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
+        # TODO: initialize with more data from scheduler
         return PostTrainingJob(job_uuid=job_uuid)
     async def preference_optimize(
@@ -125,56 +114,31 @@ class TorchtunePostTrainingImpl:
         logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...
-    async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
+    # TODO: should these be under post-training/supervised-fine-tune/?
+    # CRUD operations on running jobs
+    @webmethod(route="/post-training/jobs/{job_id:path}", method="GET")
+    async def get_post_training_job(self, job_id: str) -> PostTrainingJob:
+        # TODO: implement
+        raise NotImplementedError
+    @webmethod(route="/post-training/jobs", method="GET")
+    async def list_post_training_jobs(self) -> ListPostTrainingJobsResponse:
+        # TODO: populate other data
         return ListPostTrainingJobsResponse(
             data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
         )
-    @staticmethod
-    def _get_artifacts_metadata_by_type(job, artifact_type):
-        return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
-    @classmethod
-    def _get_checkpoints(cls, job):
-        return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
-    @classmethod
-    def _get_resources_allocated(cls, job):
-        data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
-        return data[0] if data else None
-    @webmethod(route="/post-training/job/status")
-    async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
-        job = self._scheduler.get_job(job_uuid)
-        match job.status:
-            # TODO: Add support for other statuses to API
-            case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
-                status = JobStatus.scheduled
-            case SchedulerJobStatus.running:
-                status = JobStatus.in_progress
-            case SchedulerJobStatus.completed:
-                status = JobStatus.completed
-            case SchedulerJobStatus.failed:
-                status = JobStatus.failed
-            case _:
-                raise NotImplementedError()
-        return PostTrainingJobStatusResponse(
-            job_uuid=job_uuid,
-            status=status,
-            scheduled_at=job.scheduled_at,
-            started_at=job.started_at,
-            completed_at=job.completed_at,
-            checkpoints=self._get_checkpoints(job),
-            resources_allocated=self._get_resources_allocated(job),
-        )
-    @webmethod(route="/post-training/job/cancel")
-    async def cancel_training_job(self, job_uuid: str) -> None:
-        self._scheduler.cancel(job_uuid)
-    @webmethod(route="/post-training/job/artifacts")
-    async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
-        job = self._scheduler.get_job(job_uuid)
-        return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
+    @webmethod(route="/post-training/jobs/{job_id:path}", method="POST")
+    async def update_post_training_job(self, job: PostTrainingJob) -> PostTrainingJob:
+        # TODO: implement
+        raise NotImplementedError
+    @webmethod(route="/post-training/job/{job_id:path}", method="DELETE")
+    async def delete_post_training_job(self, job_id: str) -> None:
+        # TODO: implement
+        raise NotImplementedError
+    # Note: pause/resume/cancel are achieved as follows:
+    # - POST with status=paused
+    # - POST with status=resuming
+    # - POST with status=cancelled

View file

@@ -33,11 +33,11 @@ from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
 from torchtune.training.metric_logging import DiskLogger
 from tqdm import tqdm
+from llama_stack.apis.common.job_types import JobArtifact
 from llama_stack.apis.common.training_types import PostTrainingMetric
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.post_training import (
-    Checkpoint,
     DataConfig,
     EfficiencyConfig,
     LoraFinetuningConfig,
@@ -457,7 +457,7 @@ class LoraFinetuningSingleDevice:
         return loss
-    async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
+    async def train(self) -> Tuple[Dict[str, Any], List[JobArtifact]]:
         """
         The core training loop.
         """
@@ -543,13 +543,18 @@ class LoraFinetuningSingleDevice:
             self.epochs_run += 1
             log.info("Starting checkpoint save...")
             checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
-            checkpoint = Checkpoint(
-                identifier=f"{self.model_id}-sft-{curr_epoch}",
-                created_at=datetime.now(timezone.utc),
-                epoch=curr_epoch,
-                post_training_job_id=self.job_uuid,
-                path=checkpoint_path,
+            checkpoint = JobArtifact(
+                name=f"{self.model_id}-sft-{curr_epoch}",
+                type="checkpoint",
+                # TODO: this should be exposed via /files instead
+                uri=checkpoint_path,
             )
+            metadata = {
+                "created_at": datetime.now(timezone.utc),
+                "epoch": curr_epoch,
+            }
             if self.training_config.data_config.validation_dataset_id:
                 validation_loss, perplexity = await self.validation()
                 training_metrics = PostTrainingMetric(
@@ -558,7 +563,9 @@ class LoraFinetuningSingleDevice:
                     validation_loss=validation_loss,
                     perplexity=perplexity,
                 )
-                checkpoint.training_metrics = training_metrics
+                metadata["training_metrics"] = training_metrics
+            checkpoint.metadata = metadata
             checkpoints.append(checkpoint)
         # clean up the memory after training finishes

View file

@@ -4,19 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import warnings
-from datetime import datetime
-from typing import Any, Dict, List, Literal, Optional
+from typing import Any, Dict, Optional
 import aiohttp
-from pydantic import BaseModel, ConfigDict
+from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.post_training import (
     AlgorithmConfig,
     DPOAlignmentConfig,
-    JobStatus,
+    ListPostTrainingJobsResponse,
     PostTrainingJob,
-    PostTrainingJobArtifactsResponse,
-    PostTrainingJobStatusResponse,
     TrainingConfig,
 )
 from llama_stack.providers.remote.post_training.nvidia.config import NvidiaPostTrainingConfig
@@ -25,36 +22,6 @@ from llama_stack.providers.utils.inference.model_registry import ModelRegistryHe
 from .models import _MODEL_ENTRIES
-# Map API status to JobStatus enum
-STATUS_MAPPING = {
-    "running": "in_progress",
-    "completed": "completed",
-    "failed": "failed",
-    "cancelled": "cancelled",
-    "pending": "scheduled",
-}
-class NvidiaPostTrainingJob(PostTrainingJob):
-    """Parse the response from the Customizer API.
-    Inherits job_uuid from PostTrainingJob.
-    Adds status, created_at, updated_at parameters.
-    Passes through all other parameters from data field in the response.
-    """
-    model_config = ConfigDict(extra="allow")
-    status: JobStatus
-    created_at: datetime
-    updated_at: datetime
-class ListNvidiaPostTrainingJobs(BaseModel):
-    data: List[NvidiaPostTrainingJob]
-class NvidiaPostTrainingJobStatusResponse(PostTrainingJobStatusResponse):
-    model_config = ConfigDict(extra="allow")
 class NvidiaPostTrainingAdapter(ModelRegistryHelper):
     def __init__(self, config: NvidiaPostTrainingConfig):
@@ -100,102 +67,54 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
                     raise Exception(f"API request failed: {error_data}")
                 return await response.json()
+        raise Exception(f"API request failed after {self.config.max_retries} retries")
-    async def get_training_jobs(
-        self,
-        page: Optional[int] = 1,
-        page_size: Optional[int] = 10,
-        sort: Optional[Literal["created_at", "-created_at"]] = "created_at",
-    ) -> ListNvidiaPostTrainingJobs:
-        """Get all customization jobs.
-        Updated the base class return type from ListPostTrainingJobsResponse to ListNvidiaPostTrainingJobs.
-        Returns a ListNvidiaPostTrainingJobs object with the following fields:
-        - data: List[NvidiaPostTrainingJob] - List of NvidiaPostTrainingJob objects
+    @staticmethod
+    def _get_job_status(job: Dict[str, Any]) -> JobStatus:
+        job_status = job.get("status", "unknown").lower()
+        try:
+            return JobStatus(job_status)
+        except ValueError:
+            return JobStatus.unknown
+    # TODO: fetch just the necessary job from remote
+    async def get_post_training_job(self, job_id: str) -> PostTrainingJob:
+        jobs = await self.list_post_training_jobs()
+        for job in jobs.data:
+            if job.id == job_id:
+                return job
+        raise ValueError(f"Job with ID {job_id} not found")
+    async def list_post_training_jobs(self) -> ListPostTrainingJobsResponse:
+        """Get all customization jobs.
         ToDo: Support for schema input for filtering.
         """
-        params = {"page": page, "page_size": page_size, "sort": sort}
+        # TODO: don't hardcode pagination params
+        params = {"page": 1, "page_size": 10, "sort": "created_at"}
         response = await self._make_request("GET", "/v1/customization/jobs", params=params)
         jobs = []
-        for job in response.get("data", []):
-            job_id = job.pop("id")
-            job_status = job.pop("status", "unknown").lower()
-            mapped_status = STATUS_MAPPING.get(job_status, "unknown")
-            # Convert string timestamps to datetime objects
-            created_at = (
-                datetime.fromisoformat(job.pop("created_at"))
-                if "created_at" in job
-                else datetime.now(tz=datetime.timezone.utc)
-            )
-            updated_at = (
-                datetime.fromisoformat(job.pop("updated_at"))
-                if "updated_at" in job
-                else datetime.now(tz=datetime.timezone.utc)
-            )
-            # Create NvidiaPostTrainingJob instance
-            jobs.append(
-                NvidiaPostTrainingJob(
-                    job_uuid=job_id,
-                    status=JobStatus(mapped_status),
-                    created_at=created_at,
-                    updated_at=updated_at,
-                    **job,
-                )
-            )
-        return ListNvidiaPostTrainingJobs(data=jobs)
+        for job_dict in response.get("data", []):
+            # TODO: expose artifacts
+            job = PostTrainingJob(**job_dict)
+            job.update_status(self._get_job_status(job_dict))
+            jobs.append(job)
+        return ListPostTrainingJobsResponse(data=jobs)
-    async def get_training_job_status(self, job_uuid: str) -> NvidiaPostTrainingJobStatusResponse:
-        """Get the status of a customization job.
-        Updated the base class return type from PostTrainingJobResponse to NvidiaPostTrainingJob.
-        Returns a NvidiaPostTrainingJob object with the following fields:
-        - job_uuid: str - Unique identifier for the job
-        - status: JobStatus - Current status of the job (in_progress, completed, failed, cancelled, scheduled)
-        - created_at: datetime - The time when the job was created
-        - updated_at: datetime - The last time the job status was updated
-        Additional fields that may be included:
-        - steps_completed: Optional[int] - Number of training steps completed
-        - epochs_completed: Optional[int] - Number of epochs completed
-        - percentage_done: Optional[float] - Percentage of training completed (0-100)
-        - best_epoch: Optional[int] - The epoch with the best performance
-        - train_loss: Optional[float] - Training loss of the best checkpoint
-        - val_loss: Optional[float] - Validation loss of the best checkpoint
-        - metrics: Optional[Dict] - Additional training metrics
-        - status_logs: Optional[List] - Detailed logs of status changes
-        """
-        response = await self._make_request(
-            "GET",
-            f"/v1/customization/jobs/{job_uuid}/status",
-            params={"job_id": job_uuid},
-        )
-        api_status = response.pop("status").lower()
-        mapped_status = STATUS_MAPPING.get(api_status, "unknown")
-        return NvidiaPostTrainingJobStatusResponse(
-            status=JobStatus(mapped_status),
-            job_uuid=job_uuid,
-            started_at=datetime.fromisoformat(response.pop("created_at")),
-            updated_at=datetime.fromisoformat(response.pop("updated_at")),
-            **response,
-        )
-    async def cancel_training_job(self, job_uuid: str) -> None:
+    async def update_post_training_job(self, job_id: str, status: JobStatus | None = None) -> PostTrainingJob:
+        if status is None:
+            raise ValueError("Status must be provided")
+        if status not in {JobStatus.cancelled}:
+            raise ValueError(f"Unsupported status: {status}")
         await self._make_request(
-            method="POST", path=f"/v1/customization/jobs/{job_uuid}/cancel", params={"job_id": job_uuid}
+            method="POST", path=f"/v1/customization/jobs/{job_id}/cancel", params={"job_id": job_id}
         )
+        return await self.get_post_training_job(job_id)
-    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
-        raise NotImplementedError("Job artifacts are not implemented yet")
-    async def get_post_training_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
-        raise NotImplementedError("Job artifacts are not implemented yet")
+    async def delete_post_training_job(self, job_id: str) -> None:
+        raise NotImplementedError("Delete job is not implemented yet")
     async def supervised_fine_tune(
         self,
@@ -206,7 +125,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
         model: str,
         checkpoint_dir: Optional[str],
         algorithm_config: Optional[AlgorithmConfig] = None,
-    ) -> NvidiaPostTrainingJob:
+    ) -> PostTrainingJob:
         """
         Fine-tunes a model on a dataset.
         Currently only supports Lora finetuning for standlone docker container.
@@ -409,15 +328,12 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
             headers={"Accept": "application/json"},
             json=job_config,
         )
-        job_uuid = response["id"]
         response.pop("status")
-        created_at = datetime.fromisoformat(response.pop("created_at"))
-        updated_at = datetime.fromisoformat(response.pop("updated_at"))
-        return NvidiaPostTrainingJob(
-            job_uuid=job_uuid, status=JobStatus.in_progress, created_at=created_at, updated_at=updated_at, **response
-        )
+        # TODO: expose artifacts
+        job = PostTrainingJob(**response)
+        job.update_status(JobStatus.running)
+        return job
     async def preference_optimize(
         self,
@@ -430,6 +346,3 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
     ) -> PostTrainingJob:
         """Optimize a model based on preference data."""
         raise NotImplementedError("Preference optimization is not implemented yet")
-    async def get_training_job_container_logs(self, job_uuid: str) -> PostTrainingJobStatusResponse:
-        raise NotImplementedError("Job logs are not implemented yet")

View file

@@ -562,6 +562,15 @@ else:
         return typing.get_type_hints(typ)
+def get_computed_fields(typ: type) -> dict[str, type]:
+    "Returns all computed fields of a class."
+    pydantic_decorators = getattr(typ, "__pydantic_decorators__", None)
+    if not pydantic_decorators:
+        return {}
+    computed_fields = pydantic_decorators.computed_fields
+    return {field_name: decorator.info.return_type for field_name, decorator in computed_fields.items()}
 def get_class_properties(typ: type) -> Iterable[Tuple[str, type | str]]:
     "Returns all properties of a class."
@@ -569,7 +578,8 @@ def get_class_properties(typ: type) -> Iterable[Tuple[str, type | str]]:
         return ((field.name, field.type) for field in dataclasses.fields(typ))
     else:
         resolved_hints = get_resolved_hints(typ)
-        return resolved_hints.items()
+        computed_fields = get_computed_fields(typ)
+        return (resolved_hints | computed_fields).items()
 def get_class_property(typ: type, name: str) -> Optional[type | str]:
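
The helper above makes Pydantic computed fields visible to the schema generator, which is presumably what lets the computed BaseJob.status show up in generated schemas alongside regular fields. A small sketch of the resulting behavior (the Example model is hypothetical):

    from pydantic import BaseModel, computed_field

    class Example(BaseModel):
        first: str
        last: str

        @computed_field
        @property
        def full(self) -> str:
            return f"{self.first} {self.last}"

    # get_class_properties(Example) now yields ("first", str), ("last", str) and
    # ("full", str): the computed field's declared return type is merged into the
    # resolved type hints by get_computed_fields().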

View file

@@ -8,14 +8,12 @@ from typing import List
 import pytest
 from llama_stack.apis.common.job_types import JobStatus
+from llama_stack.apis.common.training_types import Checkpoint
 from llama_stack.apis.post_training import (
-    Checkpoint,
     DataConfig,
     LoraFinetuningConfig,
     OptimizerConfig,
     PostTrainingJob,
-    PostTrainingJobArtifactsResponse,
-    PostTrainingJobStatusResponse,
     TrainingConfig,
 )
@@ -84,7 +82,6 @@ class TestPostTraining:
     async def test_get_training_job_status(self, post_training_stack):
         post_training_impl = post_training_stack
         job_status = await post_training_impl.get_training_job_status("1234")
-        assert isinstance(job_status, PostTrainingJobStatusResponse)
         assert job_status.job_uuid == "1234"
         assert job_status.status == JobStatus.completed
         assert isinstance(job_status.checkpoints[0], Checkpoint)
@@ -93,7 +90,6 @@ class TestPostTraining:
     async def test_get_training_job_artifacts(self, post_training_stack):
         post_training_impl = post_training_stack
         job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
-        assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
         assert job_artifacts.job_uuid == "1234"
         assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
         assert job_artifacts.checkpoints[0].identifier == "Llama3.2-3B-Instruct-sft-0"

View file

@@ -17,12 +17,14 @@ from llama_stack_client.types.post_training_supervised_fine_tune_params import (
     TrainingConfigOptimizerConfig,
 )
+from llama_stack.apis.common.job_types import JobStatus
+from llama_stack.apis.post_training import (
+    ListPostTrainingJobsResponse,
+    PostTrainingJob,
+)
 from llama_stack.providers.remote.post_training.nvidia.post_training import (
-    ListNvidiaPostTrainingJobs,
     NvidiaPostTrainingAdapter,
     NvidiaPostTrainingConfig,
-    NvidiaPostTrainingJob,
-    NvidiaPostTrainingJobStatusResponse,
 )
@@ -49,21 +51,25 @@ class TestNvidiaPostTraining(unittest.TestCase):
     def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
         """Helper method to verify request details in mock calls."""
-        call_args = mock_call.call_args
-        if expected_method and expected_path:
-            if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
-                assert call_args[0] == (expected_method, expected_path)
-            else:
-                assert call_args[1]["method"] == expected_method
-                assert call_args[1]["path"] == expected_path
-        if expected_params:
-            assert call_args[1]["params"] == expected_params
-        if expected_json:
-            for key, value in expected_json.items():
-                assert call_args[1]["json"][key] == value
+        found = False
+        for call_args in mock_call.call_args_list:
+            if expected_method and expected_path:
+                if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
+                    if call_args[0] == (expected_method, expected_path):
+                        found = True
+                else:
+                    if call_args[1]["method"] == expected_method and call_args[1]["path"] == expected_path:
+                        found = True
+            if expected_params:
+                if call_args[1]["params"] == expected_params:
+                    found = True
+            if expected_json:
+                for key, value in expected_json.items():
+                    if call_args[1]["json"][key] == value:
+                        found = True
+        assert found
     def test_supervised_fine_tune(self):
         """Test the supervised fine-tuning API call."""
@@ -151,9 +157,8 @@ class TestNvidiaPostTraining(unittest.TestCase):
             )
         )
-        # check the output is a PostTrainingJob
-        assert isinstance(training_job, NvidiaPostTrainingJob)
-        assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
+        assert isinstance(training_job, PostTrainingJob)
+        assert training_job.id == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
         self.mock_make_request.assert_called_once()
         self._assert_request(
@@ -199,38 +204,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
             )
         )
-    def test_get_training_job_status(self):
-        self.mock_make_request.return_value = {
-            "created_at": "2024-12-09T04:06:28.580220",
-            "updated_at": "2024-12-09T04:21:19.852832",
-            "status": "completed",
-            "steps_completed": 1210,
-            "epochs_completed": 2,
-            "percentage_done": 100.0,
-            "best_epoch": 2,
-            "train_loss": 1.718016266822815,
-            "val_loss": 1.8661999702453613,
-        }
-        job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
-        status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))
-        assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
-        assert status.status.value == "completed"
-        assert status.steps_completed == 1210
-        assert status.epochs_completed == 2
-        assert status.percentage_done == 100.0
-        assert status.best_epoch == 2
-        assert status.train_loss == 1.718016266822815
-        assert status.val_loss == 1.8661999702453613
-        self.mock_make_request.assert_called_once()
-        self._assert_request(
-            self.mock_make_request, "GET", f"/v1/customization/jobs/{job_id}/status", expected_params={"job_id": job_id}
-        )
-    def test_get_training_jobs(self):
+    def test_list_post_training_jobs(self):
         job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
         self.mock_make_request.return_value = {
             "data": [
@@ -258,12 +232,12 @@ class TestNvidiaPostTraining(unittest.TestCase):
             ]
         }
-        jobs = self.run_async(self.adapter.get_training_jobs())
-        assert isinstance(jobs, ListNvidiaPostTrainingJobs)
+        jobs = self.run_async(self.adapter.list_post_training_jobs())
+        assert isinstance(jobs, ListPostTrainingJobsResponse)
         assert len(jobs.data) == 1
         job = jobs.data[0]
-        assert job.job_uuid == job_id
+        assert job.id == job_id
         assert job.status.value == "completed"
         self.mock_make_request.assert_called_once()
@@ -275,14 +249,36 @@ class TestNvidiaPostTraining(unittest.TestCase):
         )
     def test_cancel_training_job(self):
-        self.mock_make_request.return_value = {}  # Empty response for successful cancellation
         job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
+        self.mock_make_request.return_value = {
+            "data": [
+                {
+                    "id": job_id,
+                    "created_at": "2024-12-09T04:06:28.542884",
+                    "updated_at": "2024-12-09T04:21:19.852832",
+                    "config": {
+                        "name": "meta-llama/Llama-3.1-8B-Instruct",
+                        "base_model": "meta-llama/Llama-3.1-8B-Instruct",
+                    },
+                    "dataset": {"name": "default/sample-basic-test"},
+                    "hyperparameters": {
+                        "finetuning_type": "lora",
+                        "training_type": "sft",
+                        "batch_size": 16,
+                        "epochs": 2,
+                        "learning_rate": 0.0001,
+                        "lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
+                    },
+                    "output_model": "default/job-1234",
+                    "status": "completed",
+                    "project": "default",
+                }
+            ]
+        }
-        result = self.run_async(self.adapter.cancel_training_job(job_uuid=job_id))
-        assert result is None
-        self.mock_make_request.assert_called_once()
+        result = self.run_async(self.adapter.update_post_training_job(job_id=job_id, status=JobStatus.cancelled))
+        assert result.id == job_id
         self._assert_request(
             self.mock_make_request,
             "POST",