finetuning

This commit is contained in:
Ashwin Bharambe 2024-07-10 20:47:05 -07:00
parent 956f07b04c
commit 69ecf55de2
5 changed files with 1334 additions and 28 deletions

View file

@ -1,6 +1,7 @@
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Set, Union, Tuple
from typing import Any, Dict, List, Optional, Protocol, Set, Tuple, Union
import yaml
from agentic_system_types import (
@ -10,6 +11,18 @@ from agentic_system_types import (
SafetyViolation,
)
from finetuning_types import (
Dataset,
DoraFinetuningConfig,
FinetuningAlgorithm,
FinetuningJobLogStream,
FinetuningJobStatus,
LoraFinetuningConfig,
OptimizerConfig,
QLoraFinetuningConfig,
TrainingConfig,
)
from model_types import (
BuiltinTool,
Content,
@ -22,6 +35,7 @@ from model_types import (
ToolCall,
ToolDefinition,
ToolResponse,
URL,
)
from pyopenapi import Info, Options, Server, Specification, webmethod
@ -205,6 +219,7 @@ class AgenticSystem(Protocol):
@dataclass
class PromptGeneration:
# TODO(ashwin): probably create a Dialog type which is used everywhere including chat completion
prompt: Message
message_history: List[Message]
generation: Message
@ -286,8 +301,99 @@ class SyntheticDataGeneration(Protocol):
) -> Union[SyntheticDataGenerationResponse]: ...
@json_schema_type
@dataclass
class CreateDatasetRequest:
"""Request to create a dataset."""
uuid: str
dataset: Dataset
class Datasets(Protocol):
@webmethod(route="/datasets/create")
def create_dataset(
self,
request: CreateDatasetRequest,
) -> None: ...
@webmethod(route="/datasets/get")
def get_dataset(
self,
dataset_id: str,
) -> Dataset: ...
@webmethod(route="/datasets/delete")
def delete_dataset(
self,
dataset_id: str,
) -> None: ...
@json_schema_type
@dataclass
class FinetuningTrainRequest:
"""Request to finetune a model."""
job_uuid: str
model: PretrainedModel
dataset: Dataset
validation_dataset: Dataset
algorithm: FinetuningAlgorithm
algorithm_config: Union[
LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
]
optimizer_config: OptimizerConfig
training_config: TrainingConfig
# TODO: define these
hyperparam_search_config: Dict[str, Any]
logger_config: Dict[str, Any]
@json_schema_type
@dataclass
class FinetuningJobStatusResponse:
"""Status of a finetuning job."""
job_uuid: str
status: FinetuningJobStatus
scheduled_at: Optional[datetime] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
resources_allocated: Optional[Dict[str, Any]] = None
class Finetuning(Protocol):
@webmethod(route="/finetuning/text_generation/train")
def post_train(
self,
request: FinetuningTrainRequest,
) -> None: ...
# sends SSE stream of logs
@webmethod(route="/finetuning/job/logs")
def get_training_log_stream(self, job_uuid: str) -> FinetuningJobLogStream: ...
@webmethod(route="/finetuning/job/status")
def get_training_job_status(self, job_uuid: str) -> FinetuningJobStatusResponse: ...
@webmethod(route="/finetuning/job/cancel")
def cancel_training_job(self, job_uuid: str) -> None: ...
class LlamaStackEndpoints(
Inference, AgenticSystem, RewardScoring, SyntheticDataGeneration
Inference,
AgenticSystem,
RewardScoring,
SyntheticDataGeneration,
Datasets,
Finetuning,
): ...