Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-05 10:13:05 +00:00

feat(api): define a more coherent jobs api across different flows

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>

Commit 0f50cfa561 (parent: 71ed47ea76)
15 changed files with 1864 additions and 1670 deletions
docs/_static/llama-stack-spec.html (vendored, 1607 lines changed; diff suppressed because it is too large)
docs/_static/llama-stack-spec.yaml (vendored, 1103 lines changed; diff suppressed because it is too large)
llama_stack/apis/batch_inference/batch_inference.py

@@ -4,9 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import List, Optional, Protocol, runtime_checkable
+from typing import List, Literal, Optional, Protocol, runtime_checkable
 
-from llama_stack.apis.common.job_types import Job
+from pydantic import BaseModel
 
+from llama_stack.apis.common.job_types import BaseJob
 from llama_stack.apis.inference import (
     InterleavedContent,
     LogProbConfig,
@@ -20,6 +22,14 @@ from llama_stack.apis.inference import (
 from llama_stack.schema_utils import webmethod
 
 
+class BatchInferenceJob(BaseJob, BaseModel):
+    type: Literal["batch_inference"] = "batch_inference"
+
+
+class ListBatchInferenceJobsResponse(BaseModel):
+    data: list[BatchInferenceJob]
+
+
 @runtime_checkable
 class BatchInference(Protocol):
     """Batch inference API for generating completions and chat completions.
@@ -38,7 +48,7 @@ class BatchInference(Protocol):
         sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Job: ...
+    ) -> BatchInferenceJob: ...
 
     @webmethod(route="/batch-inference/chat-completion", method="POST")
     async def chat_completion(
@@ -52,4 +62,4 @@ class BatchInference(Protocol):
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         response_format: Optional[ResponseFormat] = None,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Job: ...
+    ) -> BatchInferenceJob: ...
llama_stack/apis/common/job_types.py

@@ -3,22 +3,68 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from enum import Enum
+from datetime import datetime, timezone
+from enum import Enum, unique
+from typing import Any
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field, computed_field
 
 from llama_stack.schema_utils import json_schema_type
 
 
+@unique
 class JobStatus(Enum):
-    completed = "completed"
-    in_progress = "in_progress"
-    failed = "failed"
+    unknown = "unknown"
+    new = "new"
     scheduled = "scheduled"
+    running = "running"
+    paused = "paused"
+    resuming = "resuming"
     cancelled = "cancelled"
+    failed = "failed"
+    completed = "completed"
 
 
 @json_schema_type
-class Job(BaseModel):
-    job_id: str
+class JobStatusDetails(BaseModel):
     status: JobStatus
+    message: str | None = None
+    timestamp: datetime
+
+
+@json_schema_type
+class JobArtifact(BaseModel):
+    name: str
+
+    # TODO: should it be a Literal / Enum?
+    type: str
+
+    # Any additional metadata the artifact may have
+    # TODO: is Any the right type here? What happens when the caller passes a value without a __repr__?
+    metadata: dict[str, Any] | None = None
+
+    # TODO: enforce type to be a URI
+    uri: str | None = None  # points to /files
+
+
+def _get_job_status_details(status: JobStatus) -> JobStatusDetails:
+    return JobStatusDetails(status=status, timestamp=datetime.now(timezone.utc))
+
+
+class BaseJob(BaseModel):
+    id: str  # TODO: make it a UUID?
+
+    artifacts: list[JobArtifact] = Field(default_factory=list)
+    events: list[JobStatusDetails] = Field(default_factory=lambda: [_get_job_status_details(JobStatus.new)])
+
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        if "type" not in cls.__annotations__:
+            raise ValueError(f"Class {cls.__name__} must have a type field")
+
+    @computed_field
+    def status(self) -> JobStatus:
+        return self.events[-1].status
+
+    def update_status(self, value: JobStatus):
+        self.events.append(_get_job_status_details(value))
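
A minimal sketch of how the new BaseJob behaves, based solely on the definitions above (DemoJob is illustrative, not part of this change):

    from typing import Literal

    from llama_stack.apis.common.job_types import BaseJob, JobStatus

    class DemoJob(BaseJob):
        # __init_subclass__ above rejects any subclass without a `type` annotation
        type: Literal["demo"] = "demo"

    job = DemoJob(id="job-1")
    assert job.status == JobStatus.new        # seeded by the default events entry
    job.update_status(JobStatus.running)      # appends a timestamped JobStatusDetails
    job.update_status(JobStatus.completed)
    assert job.status == JobStatus.completed  # computed from the latest event
    print([e.status for e in job.events])     # the full status history is retained
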
llama_stack/apis/eval/eval.py

@@ -4,15 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict, List, Literal, Optional, Protocol, Union
+from typing import Dict, Literal, Optional, Protocol, Union
 
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated
 
 from llama_stack.apis.agents import AgentConfig
-from llama_stack.apis.common.job_types import Job
+from llama_stack.apis.common.job_types import BaseJob
 from llama_stack.apis.inference import SamplingParams, SystemMessage
-from llama_stack.apis.scoring import ScoringResult
 from llama_stack.apis.scoring_functions import ScoringFnParams
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
 
@@ -47,6 +46,14 @@ EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discrimin
 register_schema(EvalCandidate, name="EvalCandidate")
 
 
+class EvaluateJob(BaseJob, BaseModel):
+    type: Literal["eval"] = "eval"
+
+
+class ListEvaluateJobsResponse(BaseModel):
+    data: list[EvaluateJob]
+
+
 @json_schema_type
 class BenchmarkConfig(BaseModel):
     """A benchmark configuration for evaluation.
@@ -68,76 +75,30 @@ class BenchmarkConfig(BaseModel):
     # we could optinally add any specific dataset config here
 
 
-@json_schema_type
-class EvaluateResponse(BaseModel):
-    """The response from an evaluation.
-
-    :param generations: The generations from the evaluation.
-    :param scores: The scores from the evaluation.
-    """
-
-    generations: List[Dict[str, Any]]
-    # each key in the dict is a scoring function name
-    scores: Dict[str, ScoringResult]
-
-
 class Eval(Protocol):
     """Llama Stack Evaluation API for running evaluations on model and agent candidates."""
 
-    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST")
-    async def run_eval(
+    @webmethod(route="/eval/benchmarks/{benchmark_id}/evaluate", method="POST")
+    async def evaluate(
         self,
         benchmark_id: str,
         benchmark_config: BenchmarkConfig,
-    ) -> Job:
-        """Run an evaluation on a benchmark.
-
-        :param benchmark_id: The ID of the benchmark to run the evaluation on.
-        :param benchmark_config: The configuration for the benchmark.
-        :return: The job that was created to run the evaluation.
-        """
-
-    @webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
-    async def evaluate_rows(
-        self,
-        benchmark_id: str,
-        input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
-        benchmark_config: BenchmarkConfig,
-    ) -> EvaluateResponse:
-        """Evaluate a list of rows on a benchmark.
-
-        :param benchmark_id: The ID of the benchmark to run the evaluation on.
-        :param input_rows: The rows to evaluate.
-        :param scoring_functions: The scoring functions to use for the evaluation.
-        :param benchmark_config: The configuration for the benchmark.
-        :return: EvaluateResponse object containing generations and scores
-        """
-
-    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
-    async def job_status(self, benchmark_id: str, job_id: str) -> Job:
-        """Get the status of a job.
-
-        :param benchmark_id: The ID of the benchmark to run the evaluation on.
-        :param job_id: The ID of the job to get the status of.
-        :return: The status of the evaluationjob.
-        """
-        ...
-
-    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE")
-    async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
-        """Cancel a job.
-
-        :param benchmark_id: The ID of the benchmark to run the evaluation on.
-        :param job_id: The ID of the job to cancel.
-        """
-        ...
-
-    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
-    async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
-        """Get the result of a job.
-
-        :param benchmark_id: The ID of the benchmark to run the evaluation on.
-        :param job_id: The ID of the job to get the result of.
-        :return: The result of the job.
-        """
+    ) -> EvaluateJob: ...
+
+    # CRUD operations on running jobs
+    @webmethod(route="/evaluate/jobs/{job_id:path}", method="GET")
+    async def get_evaluate_job(self, job_id: str) -> EvaluateJob: ...
+
+    @webmethod(route="/evaluate/jobs", method="GET")
+    async def list_evaluate_jobs(self) -> ListEvaluateJobsResponse: ...
+
+    @webmethod(route="/evaluate/jobs/{job_id:path}", method="POST")
+    async def update_evaluate_job(self, job: EvaluateJob) -> EvaluateJob: ...
+
+    @webmethod(route="/evaluate/job/{job_id:path}", method="DELETE")
+    async def delete_evaluate_job(self, job_id: str) -> None: ...
+
+    # Note: pause/resume/cancel are achieved as follows:
+    # - POST with status=paused
+    # - POST with status=resuming
+    # - POST with status=cancelled
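
Note that pause/resume/cancel get no dedicated endpoints in this design; they are plain status updates POSTed back to the job resource. A hypothetical client-side sketch, to be run inside an async function (the `client` object and its method names are illustrative, not part of this diff):

    from llama_stack.apis.common.job_types import JobStatus

    job = await client.evaluate(benchmark_id="my-benchmark", benchmark_config=config)

    # pause the running job by POSTing it back with status=paused
    job.update_status(JobStatus.paused)
    job = await client.update_evaluate_job(job=job)

    # polling is a plain GET on the same resource
    job = await client.get_evaluate_job(job_id=job.id)
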
llama_stack/apis/post_training/post_training.py

@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from datetime import datetime
 from enum import Enum
 from typing import Any, Dict, List, Literal, Optional, Protocol, Union
 
@@ -12,8 +11,7 @@ from pydantic import BaseModel, Field
 from typing_extensions import Annotated
 
 from llama_stack.apis.common.content_types import URL
-from llama_stack.apis.common.job_types import JobStatus
-from llama_stack.apis.common.training_types import Checkpoint
+from llama_stack.apis.common.job_types import BaseJob
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
 
 
@@ -92,14 +90,6 @@ AlgorithmConfig = Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Fi
 register_schema(AlgorithmConfig, name="AlgorithmConfig")
 
 
-@json_schema_type
-class PostTrainingJobLogStream(BaseModel):
-    """Stream of logs from a finetuning job."""
-
-    job_uuid: str
-    log_lines: List[str]
-
-
 @json_schema_type
 class RLHFAlgorithm(Enum):
     dpo = "dpo"
@@ -135,41 +125,17 @@ class PostTrainingRLHFRequest(BaseModel):
     logger_config: Dict[str, Any]
 
 
-class PostTrainingJob(BaseModel):
-    job_uuid: str
-
-
 @json_schema_type
-class PostTrainingJobStatusResponse(BaseModel):
-    """Status of a finetuning job."""
-
-    job_uuid: str
-    status: JobStatus
-
-    scheduled_at: Optional[datetime] = None
-    started_at: Optional[datetime] = None
-    completed_at: Optional[datetime] = None
-
-    resources_allocated: Optional[Dict[str, Any]] = None
-
-    checkpoints: List[Checkpoint] = Field(default_factory=list)
+class PostTrainingJob(BaseJob, BaseModel):
+    type: Literal["post-training"] = "post-training"
 
 
 class ListPostTrainingJobsResponse(BaseModel):
-    data: List[PostTrainingJob]
-
-
-@json_schema_type
-class PostTrainingJobArtifactsResponse(BaseModel):
-    """Artifacts of a finetuning job."""
-
-    job_uuid: str
-    checkpoints: List[Checkpoint] = Field(default_factory=list)
-
-    # TODO(ashwin): metrics, evals
+    data: list[PostTrainingJob]
 
 
 class PostTraining(Protocol):
+    # This is how you create a new job - POST against the root endpoint
     @webmethod(route="/post-training/supervised-fine-tune", method="POST")
     async def supervised_fine_tune(
         self,
@@ -196,14 +162,20 @@ class PostTraining(Protocol):
         logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...
 
-    @webmethod(route="/post-training/jobs", method="GET")
-    async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
+    # CRUD operations on running jobs
+    @webmethod(route="/post-training/jobs/{job_id:path}", method="GET")
+    async def get_post_training_job(self, job_id: str) -> PostTrainingJob: ...
 
-    @webmethod(route="/post-training/job/status", method="GET")
-    async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: ...
+    @webmethod(route="/post-training/jobs", method="GET")
+    async def list_post_training_jobs(self) -> ListPostTrainingJobsResponse: ...
 
-    @webmethod(route="/post-training/job/cancel", method="POST")
-    async def cancel_training_job(self, job_uuid: str) -> None: ...
+    @webmethod(route="/post-training/jobs/{job_id:path}", method="POST")
+    async def update_post_training_job(self, job: PostTrainingJob) -> PostTrainingJob: ...
 
-    @webmethod(route="/post-training/job/artifacts", method="GET")
-    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: ...
+    @webmethod(route="/post-training/job/{job_id:path}", method="DELETE")
+    async def delete_post_training_job(self, job_id: str) -> None: ...
+
+    # Note: pause/resume/cancel are achieved as follows:
+    # - POST with status=paused
+    # - POST with status=resuming
+    # - POST with status=cancelled
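
Because each flow's job model carries a distinct `type` Literal, the per-flow jobs can be combined into a pydantic tagged union, e.g. for a future cross-flow jobs listing. A sketch (the union itself is an assumption, not part of this diff):

    from typing import Annotated, Union

    from pydantic import Field, TypeAdapter

    from llama_stack.apis.batch_inference.batch_inference import BatchInferenceJob
    from llama_stack.apis.eval.eval import EvaluateJob
    from llama_stack.apis.post_training.post_training import PostTrainingJob

    AnyJob = Annotated[
        Union[BatchInferenceJob, EvaluateJob, PostTrainingJob],
        Field(discriminator="type"),
    ]

    job = TypeAdapter(AnyJob).validate_python({"id": "job-7", "type": "eval"})
    assert isinstance(job, EvaluateJob)
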
llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py

@@ -5,10 +5,11 @@
 # the root directory of this source tree.
 
 from enum import Enum
-from typing import Any, Dict, List, Optional, Protocol, Union
+from typing import List, Literal, Optional, Protocol
 
 from pydantic import BaseModel
 
+from llama_stack.apis.common.job_types import BaseJob
 from llama_stack.apis.inference import Message
 from llama_stack.schema_utils import json_schema_type, webmethod
 
@@ -34,11 +35,13 @@ class SyntheticDataGenerationRequest(BaseModel):
 
 
 @json_schema_type
-class SyntheticDataGenerationResponse(BaseModel):
-    """Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold."""
-
-    synthetic_data: List[Dict[str, Any]]
-    statistics: Optional[Dict[str, Any]] = None
+class SyntheticDataGenerationJob(BaseJob, BaseModel):
+    type: Literal["synthetic-data-generation"] = "synthetic-data-generation"
+
+
+@json_schema_type
+class ListSyntheticDataGenerationJobsResponse(BaseModel):
+    items: list[SyntheticDataGenerationJob]
 
 
 class SyntheticDataGeneration(Protocol):
@@ -48,4 +51,24 @@ class SyntheticDataGeneration(Protocol):
         dialogs: List[Message],
         filtering_function: FilteringFunction = FilteringFunction.none,
         model: Optional[str] = None,
-    ) -> Union[SyntheticDataGenerationResponse]: ...
+    ) -> SyntheticDataGenerationJob: ...
+
+    # CRUD operations on running jobs
+    @webmethod(route="/synthetic-data-generation/jobs/{job_id:path}", method="GET")
+    async def get_synthetic_data_generation_job(self) -> SyntheticDataGenerationJob: ...
+
+    @webmethod(route="/synthetic-data-generation/jobs", method="GET")
+    async def list_synthetic_data_generation_jobs(self) -> ListSyntheticDataGenerationJobsResponse: ...
+
+    @webmethod(route="/synthetic-data-generation/jobs/{job_id:path}", method="POST")
+    async def update_synthetic_data_generation_job(
+        self, job: SyntheticDataGenerationJob
+    ) -> SyntheticDataGenerationJob: ...
+
+    @webmethod(route="/synthetic-data-generation/job/{job_id:path}", method="DELETE")
+    async def delete_synthetic_data_generation_job(self, job_id: str) -> None: ...
+
+    # Note: pause/resume/cancel are achieved as follows:
+    # - POST with status=paused
+    # - POST with status=resuming
+    # - POST with status=cancelled
llama_stack/distribution/routers/routers.py

@@ -16,7 +16,7 @@ from llama_stack.apis.common.content_types import (
 from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import DatasetPurpose, DataSource
-from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
+from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateJob, ListEvaluateJobsResponse
 from llama_stack.apis.inference import (
     BatchChatCompletionResponse,
     BatchCompletionResponse,
@@ -779,61 +779,32 @@ class EvalRouter(Eval):
         logger.debug("EvalRouter.shutdown")
         pass
 
-    async def run_eval(
+    async def evaluate(
         self,
         benchmark_id: str,
         benchmark_config: BenchmarkConfig,
-    ) -> Job:
-        logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
-        return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
+    ) -> EvaluateJob:
+        logger.debug(f"EvalRouter.evaluate: {benchmark_id}")
+        return await self.routing_table.get_provider_impl(benchmark_id).evaluate(
             benchmark_id=benchmark_id,
             benchmark_config=benchmark_config,
         )
 
-    async def evaluate_rows(
-        self,
-        benchmark_id: str,
-        input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
-        benchmark_config: BenchmarkConfig,
-    ) -> EvaluateResponse:
-        logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
-        return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
-            benchmark_id=benchmark_id,
-            input_rows=input_rows,
-            scoring_functions=scoring_functions,
-            benchmark_config=benchmark_config,
-        )
+    async def get_evaluate_job(self, job_id: str) -> EvaluateJob:
+        logger.debug(f"EvalRouter.get_evaluate_job: {job_id}")
+        return await self.routing_table.get_provider_impl("eval").get_evaluate_job(job_id)
 
-    async def job_status(
-        self,
-        benchmark_id: str,
-        job_id: str,
-    ) -> Job:
-        logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
-        return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
+    async def list_evaluate_jobs(self) -> ListEvaluateJobsResponse:
+        logger.debug("EvalRouter.list_evaluate_jobs")
+        return await self.routing_table.get_provider_impl("eval").list_evaluate_jobs()
 
-    async def job_cancel(
-        self,
-        benchmark_id: str,
-        job_id: str,
-    ) -> None:
-        logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
-        await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
-            benchmark_id,
-            job_id,
-        )
+    async def update_evaluate_job(self, job: EvaluateJob) -> EvaluateJob:
+        logger.debug(f"EvalRouter.update_evaluate_job: {job.id}")
+        return await self.routing_table.get_provider_impl("eval").update_evaluate_job(job)
 
-    async def job_result(
-        self,
-        benchmark_id: str,
-        job_id: str,
-    ) -> EvaluateResponse:
-        logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
-        return await self.routing_table.get_provider_impl(benchmark_id).job_result(
-            benchmark_id,
-            job_id,
-        )
+    async def delete_evaluate_job(self, job_id: str) -> None:
+        logger.debug(f"EvalRouter.delete_evaluate_job: {job_id}")
+        return await self.routing_table.get_provider_impl("eval").delete_evaluate_job(job_id)
 
 
 class ToolRuntimeRouter(ToolRuntime):
llama_stack/providers/inline/eval/meta_reference/eval.py

@@ -20,9 +20,10 @@ from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
 )
 from llama_stack.providers.utils.common.data_schema_validator import ColumnName
 from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.schema_utils import webmethod
 
-from .....apis.common.job_types import Job, JobStatus
-from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
+from .....apis.common.job_types import JobArtifact, JobStatus
+from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateJob, ListEvaluateJobsResponse
 from .config import MetaReferenceEvalConfig
 
 EVAL_TASKS_PREFIX = "benchmarks:"
@@ -75,11 +76,11 @@ class MetaReferenceEvalImpl(
             )
             self.benchmarks[task_def.identifier] = task_def
 
-    async def run_eval(
+    async def evaluate(
         self,
         benchmark_id: str,
         benchmark_config: BenchmarkConfig,
-    ) -> Job:
+    ) -> EvaluateJob:
         task_def = self.benchmarks[benchmark_id]
         dataset_id = task_def.dataset_id
         scoring_functions = task_def.scoring_functions
@@ -91,18 +92,35 @@ class MetaReferenceEvalImpl(
             dataset_id=dataset_id,
             limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
         )
-        res = await self.evaluate_rows(
+        generations, scoring_results = await self._evaluate_rows(
             benchmark_id=benchmark_id,
             input_rows=all_rows.data,
             scoring_functions=scoring_functions,
             benchmark_config=benchmark_config,
         )
+        artifacts = [
+            JobArtifact(
+                type="generation",
+                name=f"generation-{i}",
+                metadata=generation,
+            )
+            for i, generation in enumerate(generations)
+        ] + [
+            JobArtifact(
+                type="scoring_results",
+                name="scoring_results",
+                metadata=scoring_results,
+            )
+        ]
+
         # TODO: currently needs to wait for generation before returning
         # need job scheduler queue (ray/celery) w/ jobs api
         job_id = str(len(self.jobs))
-        self.jobs[job_id] = res
-        return Job(job_id=job_id, status=JobStatus.completed)
+        job = EvaluateJob(id=job_id, artifacts=artifacts)
+        job.update_status(JobStatus.completed)
+        self.jobs[job_id] = job
+        return job
 
     async def _run_agent_generation(
         self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
@@ -182,13 +200,13 @@ class MetaReferenceEvalImpl(
 
         return generations
 
-    async def evaluate_rows(
+    async def _evaluate_rows(
         self,
         benchmark_id: str,
         input_rows: List[Dict[str, Any]],
         scoring_functions: List[str],
         benchmark_config: BenchmarkConfig,
-    ) -> EvaluateResponse:
+    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
         candidate = benchmark_config.eval_candidate
         if candidate.type == "agent":
             generations = await self._run_agent_generation(input_rows, benchmark_config)
@@ -214,21 +232,26 @@ class MetaReferenceEvalImpl(
             input_rows=score_input_rows, scoring_functions=scoring_functions_dict
         )
 
-        return EvaluateResponse(generations=generations, scores=score_response.results)
+        return generations, score_response.results
 
-    async def job_status(self, benchmark_id: str, job_id: str) -> Job:
-        if job_id in self.jobs:
-            return Job(job_id=job_id, status=JobStatus.completed)
-
-        raise ValueError(f"Job {job_id} not found")
-
-    async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
-        raise NotImplementedError("Job cancel is not implemented yet")
-
-    async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
-        job = await self.job_status(benchmark_id, job_id)
-        status = job.status
-        if not status or status != JobStatus.completed:
-            raise ValueError(f"Job is not completed, Status: {status.value}")
-
+    # CRUD operations on running jobs
+    @webmethod(route="/evaluate/jobs/{job_id:path}", method="GET")
+    async def get_evaluate_job(self, job_id: str) -> EvaluateJob:
         return self.jobs[job_id]
+
+    @webmethod(route="/evaluate/jobs", method="GET")
+    async def list_evaluate_jobs(self) -> ListEvaluateJobsResponse:
+        return ListEvaluateJobsResponse(data=list(self.jobs.values()))
+
+    @webmethod(route="/evaluate/jobs/{job_id:path}", method="POST")
+    async def update_evaluate_job(self, job: EvaluateJob) -> EvaluateJob:
+        raise NotImplementedError
+
+    @webmethod(route="/evaluate/job/{job_id:path}", method="DELETE")
+    async def delete_evaluate_job(self, job_id: str) -> None:
+        raise NotImplementedError
+
+    # Note: pause/resume/cancel are achieved as follows:
+    # - POST with status=paused
+    # - POST with status=resuming
+    # - POST with status=cancelled
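
Results that the old job_result/EvaluateResponse pair used to return now ride on the job's artifacts. A sketch of consuming them inside an async function, using the artifact types and names assigned above (`eval_impl` stands in for the provider instance):

    job = await eval_impl.get_evaluate_job(job_id="0")
    generations = [a.metadata for a in job.artifacts if a.type == "generation"]
    scores = next(a.metadata for a in job.artifacts if a.type == "scoring_results")
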
llama_stack/providers/inline/post_training/torchtune/post_training.py

@@ -10,14 +10,10 @@ from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.post_training import (
     AlgorithmConfig,
-    Checkpoint,
     DPOAlignmentConfig,
-    JobStatus,
     ListPostTrainingJobsResponse,
     LoraFinetuningConfig,
     PostTrainingJob,
-    PostTrainingJobArtifactsResponse,
-    PostTrainingJobStatusResponse,
     TrainingConfig,
 )
 from llama_stack.providers.inline.post_training.torchtune.config import (
@@ -54,15 +50,6 @@ class TorchtunePostTrainingImpl:
     async def shutdown(self) -> None:
         await self._scheduler.shutdown()
 
-    @staticmethod
-    def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
-        return JobArtifact(
-            type=TrainingArtifactType.CHECKPOINT.value,
-            name=checkpoint.identifier,
-            uri=checkpoint.path,
-            metadata=dict(checkpoint),
-        )
-
     @staticmethod
     def _resources_stats_to_artifact(resources_stats: Dict[str, Any]) -> JobArtifact:
         return JobArtifact(
@@ -98,14 +85,14 @@ class TorchtunePostTrainingImpl:
                     self.datasetio_api,
                     self.datasets_api,
                 )
 
                 await recipe.setup()
 
                 resources_allocated, checkpoints = await recipe.train()
 
                 on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
                 for checkpoint in checkpoints:
-                    artifact = self._checkpoint_to_artifact(checkpoint)
-                    on_artifact_collected_cb(artifact)
+                    on_artifact_collected_cb(checkpoint)
 
                 on_status_change_cb(SchedulerJobStatus.completed)
                 on_log_message_cb("Lora finetuning completed")
@@ -113,6 +100,8 @@ class TorchtunePostTrainingImpl:
             raise NotImplementedError()
 
         job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
+
+        # TODO: initialize with more data from scheduler
         return PostTrainingJob(job_uuid=job_uuid)
 
     async def preference_optimize(
@@ -125,56 +114,31 @@ class TorchtunePostTrainingImpl:
         logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...
 
-    async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
+    # TODO: should these be under post-training/supervised-fine-tune/?
+    # CRUD operations on running jobs
+    @webmethod(route="/post-training/jobs/{job_id:path}", method="GET")
+    async def get_post_training_job(self, job_id: str) -> PostTrainingJob:
+        # TODO: implement
+        raise NotImplementedError
+
+    @webmethod(route="/post-training/jobs", method="GET")
+    async def list_post_training_jobs(self) -> ListPostTrainingJobsResponse:
+        # TODO: populate other data
         return ListPostTrainingJobsResponse(
             data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
         )
 
-    @staticmethod
-    def _get_artifacts_metadata_by_type(job, artifact_type):
-        return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
-
-    @classmethod
-    def _get_checkpoints(cls, job):
-        return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
-
-    @classmethod
-    def _get_resources_allocated(cls, job):
-        data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
-        return data[0] if data else None
-
-    @webmethod(route="/post-training/job/status")
-    async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
-        job = self._scheduler.get_job(job_uuid)
-
-        match job.status:
-            # TODO: Add support for other statuses to API
-            case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
-                status = JobStatus.scheduled
-            case SchedulerJobStatus.running:
-                status = JobStatus.in_progress
-            case SchedulerJobStatus.completed:
-                status = JobStatus.completed
-            case SchedulerJobStatus.failed:
-                status = JobStatus.failed
-            case _:
-                raise NotImplementedError()
-
-        return PostTrainingJobStatusResponse(
-            job_uuid=job_uuid,
-            status=status,
-            scheduled_at=job.scheduled_at,
-            started_at=job.started_at,
-            completed_at=job.completed_at,
-            checkpoints=self._get_checkpoints(job),
-            resources_allocated=self._get_resources_allocated(job),
-        )
-
-    @webmethod(route="/post-training/job/cancel")
-    async def cancel_training_job(self, job_uuid: str) -> None:
-        self._scheduler.cancel(job_uuid)
-
-    @webmethod(route="/post-training/job/artifacts")
-    async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
-        job = self._scheduler.get_job(job_uuid)
-        return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
+    @webmethod(route="/post-training/jobs/{job_id:path}", method="POST")
+    async def update_post_training_job(self, job: PostTrainingJob) -> PostTrainingJob:
+        # TODO: implement
+        raise NotImplementedError
+
+    @webmethod(route="/post-training/job/{job_id:path}", method="DELETE")
+    async def delete_post_training_job(self, job_id: str) -> None:
+        # TODO: implement
+        raise NotImplementedError
+
+    # Note: pause/resume/cancel are achieved as follows:
+    # - POST with status=paused
+    # - POST with status=resuming
+    # - POST with status=cancelled
llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py

@@ -33,11 +33,11 @@ from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
 from torchtune.training.metric_logging import DiskLogger
 from tqdm import tqdm
 
+from llama_stack.apis.common.job_types import JobArtifact
 from llama_stack.apis.common.training_types import PostTrainingMetric
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.post_training import (
-    Checkpoint,
     DataConfig,
     EfficiencyConfig,
     LoraFinetuningConfig,
@@ -457,7 +457,7 @@ class LoraFinetuningSingleDevice:
 
         return loss
 
-    async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
+    async def train(self) -> Tuple[Dict[str, Any], List[JobArtifact]]:
         """
         The core training loop.
         """
@@ -543,13 +543,18 @@ class LoraFinetuningSingleDevice:
             self.epochs_run += 1
             log.info("Starting checkpoint save...")
             checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
-            checkpoint = Checkpoint(
-                identifier=f"{self.model_id}-sft-{curr_epoch}",
-                created_at=datetime.now(timezone.utc),
-                epoch=curr_epoch,
-                post_training_job_id=self.job_uuid,
-                path=checkpoint_path,
+            checkpoint = JobArtifact(
+                name=f"{self.model_id}-sft-{curr_epoch}",
+                type="checkpoint",
+                # TODO: this should be exposed via /files instead
+                uri=checkpoint_path,
             )
 
+            metadata = {
+                "created_at": datetime.now(timezone.utc),
+                "epoch": curr_epoch,
+            }
             if self.training_config.data_config.validation_dataset_id:
                 validation_loss, perplexity = await self.validation()
                 training_metrics = PostTrainingMetric(
@@ -558,7 +563,9 @@ class LoraFinetuningSingleDevice:
                     validation_loss=validation_loss,
                     perplexity=perplexity,
                 )
-                checkpoint.training_metrics = training_metrics
+                metadata["training_metrics"] = training_metrics
+
+            checkpoint.metadata = metadata
 
             checkpoints.append(checkpoint)
 
         # clean up the memory after training finishes
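
Checkpoints are now plain JobArtifact entries rather than a dedicated Checkpoint model, with the epoch and training metrics folded into artifact metadata. Filtering them back out mirrors the deleted _get_checkpoints helper (`job` here is any PostTrainingJob carrying these artifacts):

    checkpoints = [a for a in job.artifacts if a.type == "checkpoint"]
    for ckpt in checkpoints:
        print(ckpt.name, ckpt.uri, ckpt.metadata["epoch"])
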
llama_stack/providers/remote/post_training/nvidia/post_training.py

@@ -4,19 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import warnings
-from datetime import datetime
-from typing import Any, Dict, List, Literal, Optional
+from typing import Any, Dict, Optional
 
 import aiohttp
-from pydantic import BaseModel, ConfigDict
 
+from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.post_training import (
     AlgorithmConfig,
     DPOAlignmentConfig,
-    JobStatus,
+    ListPostTrainingJobsResponse,
     PostTrainingJob,
-    PostTrainingJobArtifactsResponse,
-    PostTrainingJobStatusResponse,
     TrainingConfig,
 )
 from llama_stack.providers.remote.post_training.nvidia.config import NvidiaPostTrainingConfig
@@ -25,36 +22,6 @@ from llama_stack.providers.utils.inference.model_registry import ModelRegistryHe
 
 from .models import _MODEL_ENTRIES
 
-# Map API status to JobStatus enum
-STATUS_MAPPING = {
-    "running": "in_progress",
-    "completed": "completed",
-    "failed": "failed",
-    "cancelled": "cancelled",
-    "pending": "scheduled",
-}
-
-
-class NvidiaPostTrainingJob(PostTrainingJob):
-    """Parse the response from the Customizer API.
-    Inherits job_uuid from PostTrainingJob.
-    Adds status, created_at, updated_at parameters.
-    Passes through all other parameters from data field in the response.
-    """
-
-    model_config = ConfigDict(extra="allow")
-    status: JobStatus
-    created_at: datetime
-    updated_at: datetime
-
-
-class ListNvidiaPostTrainingJobs(BaseModel):
-    data: List[NvidiaPostTrainingJob]
-
-
-class NvidiaPostTrainingJobStatusResponse(PostTrainingJobStatusResponse):
-    model_config = ConfigDict(extra="allow")
-
-
 class NvidiaPostTrainingAdapter(ModelRegistryHelper):
     def __init__(self, config: NvidiaPostTrainingConfig):
@@ -100,102 +67,54 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
                     raise Exception(f"API request failed: {error_data}")
                 return await response.json()
 
-    async def get_training_jobs(
-        self,
-        page: Optional[int] = 1,
-        page_size: Optional[int] = 10,
-        sort: Optional[Literal["created_at", "-created_at"]] = "created_at",
-    ) -> ListNvidiaPostTrainingJobs:
-        """Get all customization jobs.
-        Updated the base class return type from ListPostTrainingJobsResponse to ListNvidiaPostTrainingJobs.
-
-        Returns a ListNvidiaPostTrainingJobs object with the following fields:
-        - data: List[NvidiaPostTrainingJob] - List of NvidiaPostTrainingJob objects
+        raise Exception(f"API request failed after {self.config.max_retries} retries")
+
+    @staticmethod
+    def _get_job_status(job: Dict[str, Any]) -> JobStatus:
+        job_status = job.get("status", "unknown").lower()
+        try:
+            return JobStatus(job_status)
+        except ValueError:
+            return JobStatus.unknown
+
+    # TODO: fetch just the necessary job from remote
+    async def get_post_training_job(self, job_id: str) -> PostTrainingJob:
+        jobs = await self.list_post_training_jobs()
+        for job in jobs.data:
+            if job.id == job_id:
+                return job
+        raise ValueError(f"Job with ID {job_id} not found")
+
+    async def list_post_training_jobs(self) -> ListPostTrainingJobsResponse:
+        """Get all customization jobs.
 
         ToDo: Support for schema input for filtering.
         """
-        params = {"page": page, "page_size": page_size, "sort": sort}
+        # TODO: don't hardcode pagination params
+        params = {"page": 1, "page_size": 10, "sort": "created_at"}
         response = await self._make_request("GET", "/v1/customization/jobs", params=params)
 
         jobs = []
-        for job in response.get("data", []):
-            job_id = job.pop("id")
-            job_status = job.pop("status", "unknown").lower()
-            mapped_status = STATUS_MAPPING.get(job_status, "unknown")
-
-            # Convert string timestamps to datetime objects
-            created_at = (
-                datetime.fromisoformat(job.pop("created_at"))
-                if "created_at" in job
-                else datetime.now(tz=datetime.timezone.utc)
-            )
-            updated_at = (
-                datetime.fromisoformat(job.pop("updated_at"))
-                if "updated_at" in job
-                else datetime.now(tz=datetime.timezone.utc)
-            )
-
-            # Create NvidiaPostTrainingJob instance
-            jobs.append(
-                NvidiaPostTrainingJob(
-                    job_uuid=job_id,
-                    status=JobStatus(mapped_status),
-                    created_at=created_at,
-                    updated_at=updated_at,
-                    **job,
-                )
-            )
-
-        return ListNvidiaPostTrainingJobs(data=jobs)
-
-    async def get_training_job_status(self, job_uuid: str) -> NvidiaPostTrainingJobStatusResponse:
-        """Get the status of a customization job.
-        Updated the base class return type from PostTrainingJobResponse to NvidiaPostTrainingJob.
-
-        Returns a NvidiaPostTrainingJob object with the following fields:
-        - job_uuid: str - Unique identifier for the job
-        - status: JobStatus - Current status of the job (in_progress, completed, failed, cancelled, scheduled)
-        - created_at: datetime - The time when the job was created
-        - updated_at: datetime - The last time the job status was updated
-
-        Additional fields that may be included:
-        - steps_completed: Optional[int] - Number of training steps completed
-        - epochs_completed: Optional[int] - Number of epochs completed
-        - percentage_done: Optional[float] - Percentage of training completed (0-100)
-        - best_epoch: Optional[int] - The epoch with the best performance
-        - train_loss: Optional[float] - Training loss of the best checkpoint
-        - val_loss: Optional[float] - Validation loss of the best checkpoint
-        - metrics: Optional[Dict] - Additional training metrics
-        - status_logs: Optional[List] - Detailed logs of status changes
-        """
-        response = await self._make_request(
-            "GET",
-            f"/v1/customization/jobs/{job_uuid}/status",
-            params={"job_id": job_uuid},
-        )
-
-        api_status = response.pop("status").lower()
-        mapped_status = STATUS_MAPPING.get(api_status, "unknown")
-
-        return NvidiaPostTrainingJobStatusResponse(
-            status=JobStatus(mapped_status),
-            job_uuid=job_uuid,
-            started_at=datetime.fromisoformat(response.pop("created_at")),
-            updated_at=datetime.fromisoformat(response.pop("updated_at")),
-            **response,
-        )
-
-    async def cancel_training_job(self, job_uuid: str) -> None:
+        for job_dict in response.get("data", []):
+            # TODO: expose artifacts
+            job = PostTrainingJob(**job_dict)
+            job.update_status(self._get_job_status(job_dict))
+            jobs.append(job)
+
+        return ListPostTrainingJobsResponse(data=jobs)
+
+    async def update_post_training_job(self, job_id: str, status: JobStatus | None = None) -> PostTrainingJob:
+        if status is None:
+            raise ValueError("Status must be provided")
+        if status not in {JobStatus.cancelled}:
+            raise ValueError(f"Unsupported status: {status}")
         await self._make_request(
-            method="POST", path=f"/v1/customization/jobs/{job_uuid}/cancel", params={"job_id": job_uuid}
+            method="POST", path=f"/v1/customization/jobs/{job_id}/cancel", params={"job_id": job_id}
         )
+        return await self.get_post_training_job(job_id)
 
-    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
-        raise NotImplementedError("Job artifacts are not implemented yet")
-
-    async def get_post_training_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
-        raise NotImplementedError("Job artifacts are not implemented yet")
+    async def delete_post_training_job(self, job_id: str) -> None:
+        raise NotImplementedError("Delete job is not implemented yet")
 
     async def supervised_fine_tune(
         self,
@@ -206,7 +125,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
         model: str,
         checkpoint_dir: Optional[str],
         algorithm_config: Optional[AlgorithmConfig] = None,
-    ) -> NvidiaPostTrainingJob:
+    ) -> PostTrainingJob:
        """
        Fine-tunes a model on a dataset.
        Currently only supports Lora finetuning for standlone docker container.
@@ -409,15 +328,12 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
             headers={"Accept": "application/json"},
             json=job_config,
         )
 
-        job_uuid = response["id"]
         response.pop("status")
-        created_at = datetime.fromisoformat(response.pop("created_at"))
-        updated_at = datetime.fromisoformat(response.pop("updated_at"))
 
-        return NvidiaPostTrainingJob(
-            job_uuid=job_uuid, status=JobStatus.in_progress, created_at=created_at, updated_at=updated_at, **response
-        )
+        # TODO: expose artifacts
+        job = PostTrainingJob(**response)
+        job.update_status(JobStatus.running)
+        return job
 
     async def preference_optimize(
         self,
@@ -430,6 +346,3 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
     ) -> PostTrainingJob:
         """Optimize a model based on preference data."""
         raise NotImplementedError("Preference optimization is not implemented yet")
-
-    async def get_training_job_container_logs(self, job_uuid: str) -> PostTrainingJobStatusResponse:
-        raise NotImplementedError("Job logs are not implemented yet")
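
The hand-written STATUS_MAPPING table is replaced by coercing the Customizer status string straight through the shared JobStatus enum, with `unknown` as the fallback. Two illustrative cases (note the behavior change for "pending", which the old table translated to "scheduled"):

    assert NvidiaPostTrainingAdapter._get_job_status({"status": "RUNNING"}) == JobStatus.running
    # "pending" is not a JobStatus member, so it now falls back to unknown
    assert NvidiaPostTrainingAdapter._get_job_status({"status": "pending"}) == JobStatus.unknown
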
llama_stack/strong_typing/inspection.py

@@ -562,6 +562,15 @@ else:
         return typing.get_type_hints(typ)
 
 
+def get_computed_fields(typ: type) -> dict[str, type]:
+    "Returns all computed fields of a class."
+    pydantic_decorators = getattr(typ, "__pydantic_decorators__", None)
+    if not pydantic_decorators:
+        return {}
+    computed_fields = pydantic_decorators.computed_fields
+    return {field_name: decorator.info.return_type for field_name, decorator in computed_fields.items()}
+
+
 def get_class_properties(typ: type) -> Iterable[Tuple[str, type | str]]:
     "Returns all properties of a class."
 
@@ -569,7 +578,8 @@ def get_class_properties(typ: type) -> Iterable[Tuple[str, type | str]]:
         return ((field.name, field.type) for field in dataclasses.fields(typ))
     else:
         resolved_hints = get_resolved_hints(typ)
-        return resolved_hints.items()
+        computed_fields = get_computed_fields(typ)
+        return (resolved_hints | computed_fields).items()
 
 
 def get_class_property(typ: type, name: str) -> Optional[type | str]:
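
This hook is what lets computed fields such as BaseJob.status appear in generated schemas. A minimal standalone sketch (Point is illustrative, not part of this change):

    from pydantic import BaseModel, computed_field

    from llama_stack.strong_typing.inspection import get_computed_fields

    class Point(BaseModel):
        x: int
        y: int

        @computed_field
        def norm_sq(self) -> int:
            return self.x * self.x + self.y * self.y

    print(get_computed_fields(Point).keys())  # dict_keys(['norm_sq'])
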
@ -8,14 +8,12 @@ from typing import List
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from llama_stack.apis.common.job_types import JobStatus
|
from llama_stack.apis.common.job_types import JobStatus
|
||||||
|
from llama_stack.apis.common.training_types import Checkpoint
|
||||||
from llama_stack.apis.post_training import (
|
from llama_stack.apis.post_training import (
|
||||||
Checkpoint,
|
|
||||||
DataConfig,
|
DataConfig,
|
||||||
LoraFinetuningConfig,
|
LoraFinetuningConfig,
|
||||||
OptimizerConfig,
|
OptimizerConfig,
|
||||||
PostTrainingJob,
|
PostTrainingJob,
|
||||||
PostTrainingJobArtifactsResponse,
|
|
||||||
PostTrainingJobStatusResponse,
|
|
||||||
TrainingConfig,
|
TrainingConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -84,7 +82,6 @@ class TestPostTraining:
     async def test_get_training_job_status(self, post_training_stack):
         post_training_impl = post_training_stack
         job_status = await post_training_impl.get_training_job_status("1234")
-        assert isinstance(job_status, PostTrainingJobStatusResponse)
         assert job_status.job_uuid == "1234"
         assert job_status.status == JobStatus.completed
         assert isinstance(job_status.checkpoints[0], Checkpoint)
@@ -93,7 +90,6 @@ class TestPostTraining:
     async def test_get_training_job_artifacts(self, post_training_stack):
         post_training_impl = post_training_stack
         job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
-        assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
         assert job_artifacts.job_uuid == "1234"
         assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
         assert job_artifacts.checkpoints[0].identifier == "Llama3.2-3B-Instruct-sft-0"

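Both tests now assert on fields rather than on the concrete response wrapper types, so they keep passing whether a provider returns the base job model or a subclass. A rough illustration with stand-in models (not the real API classes):

```python
from pydantic import BaseModel


class PostTrainingJob(BaseModel):
    job_uuid: str
    status: str


class NvidiaPostTrainingJob(PostTrainingJob):
    percentage_done: float = 0.0


job = NvidiaPostTrainingJob(job_uuid="1234", status="completed")
# Field-level assertions hold for the base model and any provider subclass.
assert job.job_uuid == "1234"
assert job.status == "completed"
```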
@@ -17,12 +17,14 @@ from llama_stack_client.types.post_training_supervised_fine_tune_params import (
     TrainingConfigOptimizerConfig,
 )

+from llama_stack.apis.common.job_types import JobStatus
+from llama_stack.apis.post_training import (
+    ListPostTrainingJobsResponse,
+    PostTrainingJob,
+)
 from llama_stack.providers.remote.post_training.nvidia.post_training import (
-    ListNvidiaPostTrainingJobs,
     NvidiaPostTrainingAdapter,
     NvidiaPostTrainingConfig,
-    NvidiaPostTrainingJob,
-    NvidiaPostTrainingJobStatusResponse,
 )

@@ -49,21 +51,25 @@ class TestNvidiaPostTraining(unittest.TestCase):

     def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
         """Helper method to verify request details in mock calls."""
-        call_args = mock_call.call_args
-
-        if expected_method and expected_path:
-            if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
-                assert call_args[0] == (expected_method, expected_path)
-            else:
-                assert call_args[1]["method"] == expected_method
-                assert call_args[1]["path"] == expected_path
-
-        if expected_params:
-            assert call_args[1]["params"] == expected_params
-
-        if expected_json:
-            for key, value in expected_json.items():
-                assert call_args[1]["json"][key] == value
+        found = False
+        for call_args in mock_call.call_args_list:
+            if expected_method and expected_path:
+                if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
+                    if call_args[0] == (expected_method, expected_path):
+                        found = True
+                else:
+                    if call_args[1]["method"] == expected_method and call_args[1]["path"] == expected_path:
+                        found = True
+
+            if expected_params:
+                if call_args[1]["params"] == expected_params:
+                    found = True
+
+            if expected_json:
+                for key, value in expected_json.items():
+                    if call_args[1]["json"][key] == value:
+                        found = True
+        assert found

     def test_supervised_fine_tune(self):
         """Test the supervised fine-tuning API call."""
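The rewritten `_assert_request` scans `mock_call.call_args_list` instead of inspecting only the last call, and accepts either positional `(method, path)` or keyword-style invocations. A condensed, runnable sketch of that matching idea (simplified; the real helper also folds the `params`/`json` checks into the same `found` flag):

```python
from unittest.mock import MagicMock

mock = MagicMock()
mock("GET", "/v1/customization/jobs")                      # positional style
mock(method="POST", path="/v1/customization/jobs/cancel")  # keyword style


def was_called(mock_obj, method: str, path: str) -> bool:
    # Every recorded call unpacks to (args, kwargs); accept either shape.
    for call_args in mock_obj.call_args_list:
        if call_args[0] == (method, path):
            return True
        kwargs = call_args[1]
        if kwargs.get("method") == method and kwargs.get("path") == path:
            return True
    return False


assert was_called(mock, "GET", "/v1/customization/jobs")
assert was_called(mock, "POST", "/v1/customization/jobs/cancel")
```

Scanning all calls matters once an adapter method issues more than one HTTP request per operation, as the cancellation test below does.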
@@ -151,9 +157,8 @@ class TestNvidiaPostTraining(unittest.TestCase):
             )
         )

-        # check the output is a PostTrainingJob
-        assert isinstance(training_job, NvidiaPostTrainingJob)
-        assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
+        assert isinstance(training_job, PostTrainingJob)
+        assert training_job.id == "cust-JGTaMbJMdqjJU8WbQdN9Q2"

         self.mock_make_request.assert_called_once()
         self._assert_request(
@@ -199,38 +204,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
             )
         )

-    def test_get_training_job_status(self):
-        self.mock_make_request.return_value = {
-            "created_at": "2024-12-09T04:06:28.580220",
-            "updated_at": "2024-12-09T04:21:19.852832",
-            "status": "completed",
-            "steps_completed": 1210,
-            "epochs_completed": 2,
-            "percentage_done": 100.0,
-            "best_epoch": 2,
-            "train_loss": 1.718016266822815,
-            "val_loss": 1.8661999702453613,
-        }
-
-        job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
-
-        status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))
-
-        assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
-        assert status.status.value == "completed"
-        assert status.steps_completed == 1210
-        assert status.epochs_completed == 2
-        assert status.percentage_done == 100.0
-        assert status.best_epoch == 2
-        assert status.train_loss == 1.718016266822815
-        assert status.val_loss == 1.8661999702453613
-
-        self.mock_make_request.assert_called_once()
-        self._assert_request(
-            self.mock_make_request, "GET", f"/v1/customization/jobs/{job_id}/status", expected_params={"job_id": job_id}
-        )
-
-    def test_get_training_jobs(self):
+    def test_list_post_training_jobs(self):
         job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
         self.mock_make_request.return_value = {
             "data": [
@@ -258,12 +232,12 @@ class TestNvidiaPostTraining(unittest.TestCase):
             ]
         }

-        jobs = self.run_async(self.adapter.get_training_jobs())
+        jobs = self.run_async(self.adapter.list_post_training_jobs())

-        assert isinstance(jobs, ListNvidiaPostTrainingJobs)
+        assert isinstance(jobs, ListPostTrainingJobsResponse)
         assert len(jobs.data) == 1
         job = jobs.data[0]
-        assert job.job_uuid == job_id
+        assert job.id == job_id
         assert job.status.value == "completed"

         self.mock_make_request.assert_called_once()
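The listing test now goes through the renamed `list_post_training_jobs` and asserts on the generic `id` field. A stand-in sketch of the response shape implied by the mocked payload (illustrative models, not the real ones):

```python
from pydantic import BaseModel


class PostTrainingJob(BaseModel):
    id: str
    status: str  # the real model uses a JobStatus enum, hence `status.value` above


class ListPostTrainingJobsResponse(BaseModel):
    data: list[PostTrainingJob]


payload = {"data": [{"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2", "status": "completed"}]}
jobs = ListPostTrainingJobsResponse(**payload)
assert jobs.data[0].id == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
```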
@@ -275,14 +249,36 @@ class TestNvidiaPostTraining(unittest.TestCase):
         )

     def test_cancel_training_job(self):
-        self.mock_make_request.return_value = {}  # Empty response for successful cancellation
         job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
+        self.mock_make_request.return_value = {
+            "data": [
+                {
+                    "id": job_id,
+                    "created_at": "2024-12-09T04:06:28.542884",
+                    "updated_at": "2024-12-09T04:21:19.852832",
+                    "config": {
+                        "name": "meta-llama/Llama-3.1-8B-Instruct",
+                        "base_model": "meta-llama/Llama-3.1-8B-Instruct",
+                    },
+                    "dataset": {"name": "default/sample-basic-test"},
+                    "hyperparameters": {
+                        "finetuning_type": "lora",
+                        "training_type": "sft",
+                        "batch_size": 16,
+                        "epochs": 2,
+                        "learning_rate": 0.0001,
+                        "lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
+                    },
+                    "output_model": "default/job-1234",
+                    "status": "completed",
+                    "project": "default",
+                }
+            ]
+        }

-        result = self.run_async(self.adapter.cancel_training_job(job_uuid=job_id))
-
-        assert result is None
+        result = self.run_async(self.adapter.update_post_training_job(job_id=job_id, status=JobStatus.cancelled))
+        assert result.id == job_id

-        self.mock_make_request.assert_called_once()
         self._assert_request(
             self.mock_make_request,
             "POST",
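Cancellation now rides the generic update path: the test mocks a listing payload, calls `update_post_training_job(job_id=..., status=JobStatus.cancelled)`, and expects the updated job back instead of `None`. A rough stand-in of that contract (the fetch/POST internals are assumptions, not the adapter's actual code):

```python
from enum import Enum


class JobStatus(Enum):
    running = "running"
    cancelled = "cancelled"


class Job:
    def __init__(self, id: str):
        self.id = id
        self.status = JobStatus.running


def update_post_training_job(job_id: str, status: JobStatus) -> Job:
    job = Job(job_id)    # would be looked up from the provider in reality
    job.status = status  # would POST the transition to the provider
    return job


result = update_post_training_job("cust-JGTaMbJMdqjJU8WbQdN9Q2", JobStatus.cancelled)
assert result.id == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
assert result.status is JobStatus.cancelled
```

Returning the updated job also explains why `assert_called_once` was dropped above: the update flow may legitimately issue more than one request.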