mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
finetuning
This commit is contained in:
parent
956f07b04c
commit
69ecf55de2
5 changed files with 1334 additions and 28 deletions
|
@ -1,6 +1,7 @@
|
|||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Protocol, Set, Union, Tuple
|
||||
from typing import Any, Dict, List, Optional, Protocol, Set, Tuple, Union
|
||||
|
||||
import yaml
|
||||
from agentic_system_types import (
|
||||
|
@ -10,6 +11,18 @@ from agentic_system_types import (
|
|||
SafetyViolation,
|
||||
)
|
||||
|
||||
from finetuning_types import (
|
||||
Dataset,
|
||||
DoraFinetuningConfig,
|
||||
FinetuningAlgorithm,
|
||||
FinetuningJobLogStream,
|
||||
FinetuningJobStatus,
|
||||
LoraFinetuningConfig,
|
||||
OptimizerConfig,
|
||||
QLoraFinetuningConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
|
||||
from model_types import (
|
||||
BuiltinTool,
|
||||
Content,
|
||||
|
@ -22,6 +35,7 @@ from model_types import (
|
|||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolResponse,
|
||||
URL,
|
||||
)
|
||||
|
||||
from pyopenapi import Info, Options, Server, Specification, webmethod
|
||||
|
@ -205,6 +219,7 @@ class AgenticSystem(Protocol):
|
|||
|
||||
@dataclass
|
||||
class PromptGeneration:
|
||||
# TODO(ashwin): probably create a Dialog type which is used everywhere including chat completion
|
||||
prompt: Message
|
||||
message_history: List[Message]
|
||||
generation: Message
|
||||
|
@ -286,8 +301,99 @@ class SyntheticDataGeneration(Protocol):
|
|||
) -> Union[SyntheticDataGenerationResponse]: ...
|
||||
|
||||
|
||||
@json_schema_type
|
||||
@dataclass
|
||||
class CreateDatasetRequest:
|
||||
"""Request to create a dataset."""
|
||||
|
||||
uuid: str
|
||||
dataset: Dataset
|
||||
|
||||
|
||||
class Datasets(Protocol):
|
||||
@webmethod(route="/datasets/create")
|
||||
def create_dataset(
|
||||
self,
|
||||
request: CreateDatasetRequest,
|
||||
) -> None: ...
|
||||
|
||||
@webmethod(route="/datasets/get")
|
||||
def get_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
) -> Dataset: ...
|
||||
|
||||
@webmethod(route="/datasets/delete")
|
||||
def delete_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
@json_schema_type
|
||||
@dataclass
|
||||
class FinetuningTrainRequest:
|
||||
"""Request to finetune a model."""
|
||||
|
||||
job_uuid: str
|
||||
|
||||
model: PretrainedModel
|
||||
dataset: Dataset
|
||||
validation_dataset: Dataset
|
||||
|
||||
algorithm: FinetuningAlgorithm
|
||||
algorithm_config: Union[
|
||||
LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
|
||||
]
|
||||
|
||||
optimizer_config: OptimizerConfig
|
||||
training_config: TrainingConfig
|
||||
|
||||
# TODO: define these
|
||||
hyperparam_search_config: Dict[str, Any]
|
||||
logger_config: Dict[str, Any]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
@dataclass
|
||||
class FinetuningJobStatusResponse:
|
||||
"""Status of a finetuning job."""
|
||||
|
||||
job_uuid: str
|
||||
status: FinetuningJobStatus
|
||||
|
||||
scheduled_at: Optional[datetime] = None
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
|
||||
resources_allocated: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class Finetuning(Protocol):
|
||||
@webmethod(route="/finetuning/text_generation/train")
|
||||
def post_train(
|
||||
self,
|
||||
request: FinetuningTrainRequest,
|
||||
) -> None: ...
|
||||
|
||||
# sends SSE stream of logs
|
||||
@webmethod(route="/finetuning/job/logs")
|
||||
def get_training_log_stream(self, job_uuid: str) -> FinetuningJobLogStream: ...
|
||||
|
||||
@webmethod(route="/finetuning/job/status")
|
||||
def get_training_job_status(self, job_uuid: str) -> FinetuningJobStatusResponse: ...
|
||||
|
||||
@webmethod(route="/finetuning/job/cancel")
|
||||
def cancel_training_job(self, job_uuid: str) -> None: ...
|
||||
|
||||
|
||||
class LlamaStackEndpoints(
|
||||
Inference, AgenticSystem, RewardScoring, SyntheticDataGeneration
|
||||
Inference,
|
||||
AgenticSystem,
|
||||
RewardScoring,
|
||||
SyntheticDataGeneration,
|
||||
Datasets,
|
||||
Finetuning,
|
||||
): ...
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue