Mirror of https://github.com/meta-llama/llama-stack.git

temp commit

commit 79c525be94 (parent 6c709abc4d)
5 changed files with 115 additions and 163 deletions
@@ -41,7 +41,7 @@ class TrainingConfig(BaseModel):
     gradient_accumulation_steps: int
     batch_size: int
     shuffle: bool
-    # n_iters: int
+    optimizer_config: OptimizerConfig
 
     enable_activation_checkpointing: bool
     memory_efficient_fsdp_wrap: Optional[bool]
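With this hunk, the optimizer settings become part of TrainingConfig itself, replacing the commented-out n_iters placeholder. A minimal construction sketch using pydantic stand-ins: OptimizerConfig's fields (optimizer_type, lr) are assumptions, since the diff never shows that class, and TrainingConfig is trimmed to the fields visible in the hunk above.

from typing import Optional
from pydantic import BaseModel

class OptimizerConfig(BaseModel):
    # Stand-in: the real field names are not shown in this diff.
    optimizer_type: str
    lr: float

class TrainingConfig(BaseModel):
    # Trimmed to the fields visible in the hunk above.
    gradient_accumulation_steps: int
    batch_size: int
    shuffle: bool
    optimizer_config: OptimizerConfig  # nested here as of this commit
    enable_activation_checkpointing: bool
    memory_efficient_fsdp_wrap: Optional[bool]

config = TrainingConfig(
    gradient_accumulation_steps=1,
    batch_size=4,
    shuffle=True,
    optimizer_config=OptimizerConfig(optimizer_type="adamw", lr=3e-4),
    enable_activation_checkpointing=True,
    memory_efficient_fsdp_wrap=True,
)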
@@ -63,6 +63,7 @@ class LoraFinetuningConfig(BaseModel):
     apply_lora_to_output: bool
     rank: int
     alpha: int
+    use_dora: bool
 
 
 @json_schema_type
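The new use_dora flag sits alongside the existing LoRA knobs. A hedged construction sketch; the stand-in class is trimmed to the fields this hunk shows (the real LoraFinetuningConfig carries further fields above apply_lora_to_output):

from pydantic import BaseModel

class LoraFinetuningConfig(BaseModel):
    # Trimmed stand-in; only the fields visible in the hunk above.
    apply_lora_to_output: bool
    rank: int       # rank of the low-rank adapter matrices
    alpha: int      # scaling factor applied to the adapter output
    use_dora: bool  # added in this commit

lora_config = LoraFinetuningConfig(
    apply_lora_to_output=False,
    rank=8,
    alpha=16,
    use_dora=False,  # opt into DoRA (weight-decomposed LoRA) when True
)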
@@ -116,7 +117,6 @@ class PostTrainingSFTRequest(BaseModel):
 
     algorithm: FinetuningAlgorithm
     algorithm_config: LoraFinetuningConfig
-    optimizer_config: OptimizerConfig
     training_config: TrainingConfig
 
     # TODO: define these
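Read together with the first hunk, this moves the optimizer settings one level down: code that read request.optimizer_config would now read request.training_config.optimizer_config. A minimal stand-in sketch of the new shape; every class here is trimmed to the relevant field, and OptimizerConfig's lr is an assumption:

from pydantic import BaseModel

class OptimizerConfig(BaseModel):
    lr: float  # assumed field

class TrainingConfig(BaseModel):
    optimizer_config: OptimizerConfig  # per the first hunk

class PostTrainingSFTRequest(BaseModel):
    training_config: TrainingConfig  # optimizer_config removed as a sibling field

req = PostTrainingSFTRequest(
    training_config=TrainingConfig(optimizer_config=OptimizerConfig(lr=3e-4)),
)
# Old access path, removed by this commit: req.optimizer_config
print(req.training_config.optimizer_config.lr)  # new access path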
@@ -178,7 +178,7 @@ class PostTrainingJobArtifactsResponse(BaseModel):
 
 class PostTraining(Protocol):
     @webmethod(route="/post-training/supervised-fine-tune")
-    def supervised_fine_tune(
+    async def supervised_fine_tune(
         self,
         job_uuid: str,
         model: str,
@@ -186,14 +186,14 @@ class PostTraining(Protocol):
         validation_dataset_id: str,
         algorithm: FinetuningAlgorithm,
         algorithm_config: LoraFinetuningConfig,
-        optimizer_config: OptimizerConfig,
+        # optimizer_config: OptimizerConfig,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...
 
     @webmethod(route="/post-training/preference-optimize")
-    def preference_optimize(
+    async def preference_optimize(
         self,
         job_uuid: str,
         finetuned_model: URL,
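These two hunks turn supervised_fine_tune and preference_optimize into coroutines, so callers and implementers go through await. A hedged caller sketch: client is any object implementing the PostTraining protocol, the config arguments are the stand-ins from the sketches above, and dataset_id is assumed (the hunk starts at validation_dataset_id, so the parameters before it are inferred).

from typing import Any

async def run_sft(client: Any, algorithm: Any, algorithm_config: Any,
                  training_config: Any) -> Any:
    # Parameter names follow the signature in the diff above.
    return await client.supervised_fine_tune(
        job_uuid="job-123",
        model="my-base-model",      # illustrative identifier
        dataset_id="my-train-set",  # assumed parameter, not visible in the hunk
        validation_dataset_id="my-val-set",
        algorithm=algorithm,
        algorithm_config=algorithm_config,
        # optimizer_config no longer appears here as of this commit;
        # it rides along inside training_config instead.
        training_config=training_config,
        hyperparam_search_config={},
        logger_config={},
    )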
@@ -208,21 +208,23 @@ class PostTraining(Protocol):
     ) -> PostTrainingJob: ...
 
     @webmethod(route="/post-training/jobs")
-    def get_training_jobs(self) -> List[PostTrainingJob]: ...
+    async def get_training_jobs(self) -> List[PostTrainingJob]: ...
 
     # sends SSE stream of logs
     @webmethod(route="/post-training/job/logs")
-    def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
+    async def get_training_job_logstream(
+        self, job_uuid: str
+    ) -> PostTrainingJobLogStream: ...
 
     @webmethod(route="/post-training/job/status")
-    def get_training_job_status(
+    async def get_training_job_status(
         self, job_uuid: str
     ) -> PostTrainingJobStatusResponse: ...
 
     @webmethod(route="/post-training/job/cancel")
-    def cancel_training_job(self, job_uuid: str) -> None: ...
+    async def cancel_training_job(self, job_uuid: str) -> None: ...
 
     @webmethod(route="/post-training/job/artifacts")
-    def get_training_job_artifacts(
+    async def get_training_job_artifacts(
         self, job_uuid: str
     ) -> PostTrainingJobArtifactsResponse: ...
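The final hunk makes the job-management surface async as well: listing jobs, streaming logs, checking status, cancelling, and fetching artifacts all become awaitables. A small polling sketch under the same assumptions (hypothetical client; the status attribute and its values are guesses, since the diff shows only method signatures):

import asyncio
from typing import Any

async def wait_for_job(client: Any, job_uuid: str, poll_secs: float = 10.0) -> Any:
    # Poll status until the job settles, then fetch its artifacts.
    while True:
        response = await client.get_training_job_status(job_uuid)
        if getattr(response, "status", None) in ("completed", "failed", "cancelled"):
            return await client.get_training_job_artifacts(job_uuid)
        await asyncio.sleep(poll_secs)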