temp commit

This commit is contained in:
Botao Chen 2024-12-09 20:24:30 -08:00
parent 9c1ae088f9
commit c9a009b5e7
7 changed files with 268 additions and 53 deletions

View file

@ -18,3 +18,5 @@ class Job(BaseModel):
class JobStatus(Enum):
completed = "completed"
in_progress = "in_progress"
failed = "failed"
scheduled = "scheduled"

View file

@ -4,13 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.llama3.api.datatypes import URL
from datetime import datetime
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
@json_schema_type
class PostTrainingMetric(BaseModel):
epoch: int
train_loss: float
validation_loss: float
perplexity: float
@json_schema_type(schema={"description": "Checkpoint created during training runs"})
class Checkpoint(BaseModel):
iters: int
path: URL
identifier: str
created_at: datetime
epoch: int
post_training_job_id: str
path: str
training_metrics: Optional[PostTrainingMetric]

View file

@ -14,6 +14,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.common.training_types import * # noqa: F403
@ -87,14 +88,6 @@ class PostTrainingJobLogStream(BaseModel):
log_lines: List[str]
@json_schema_type
class PostTrainingJobStatus(Enum):
running = "running"
completed = "completed"
failed = "failed"
scheduled = "scheduled"
@json_schema_type
class RLHFAlgorithm(Enum):
dpo = "dpo"
@ -139,7 +132,7 @@ class PostTrainingJobStatusResponse(BaseModel):
"""Status of a finetuning job."""
job_uuid: str
status: PostTrainingJobStatus
status: JobStatus
scheduled_at: Optional[datetime] = None
started_at: Optional[datetime] = None
@ -192,16 +185,10 @@ class PostTraining(Protocol):
@webmethod(route="/post-training/jobs")
async def get_training_jobs(self) -> List[PostTrainingJob]: ...
# sends SSE stream of logs
@webmethod(route="/post-training/job/logs")
async def get_training_job_logstream(
self, job_uuid: str
) -> PostTrainingJobLogStream: ...
@webmethod(route="/post-training/job/status")
async def get_training_job_status(
self, job_uuid: str
) -> PostTrainingJobStatusResponse: ...
) -> Optional[PostTrainingJobStatusResponse]: ...
@webmethod(route="/post-training/job/cancel")
async def cancel_training_job(self, job_uuid: str) -> None: ...