diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py
index 462cbc21e..f2a2edae5 100644
--- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py
+++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py
@@ -14,11 +14,10 @@ from enum import Enum
 from typing import Any, Callable, Dict, List
 
 import torch
-from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.common.type_system import *  # noqa
 from llama_models.datatypes import Model
 from llama_models.sku_list import resolve_model
-from llama_stack.apis.common.type_system import ParamType
+from llama_stack.apis.common.type_system import ParamType, StringType
+from llama_stack.apis.datasets import Datasets
 from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
 from torchtune.models.llama3._tokenizer import Llama3Tokenizer
diff --git a/llama_stack/providers/inline/post_training/torchtune/post_training.py b/llama_stack/providers/inline/post_training/torchtune/post_training.py
index 9b1269f16..90fbf7026 100644
--- a/llama_stack/providers/inline/post_training/torchtune/post_training.py
+++ b/llama_stack/providers/inline/post_training/torchtune/post_training.py
@@ -3,11 +3,26 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from datetime import datetime
+from typing import Any, Dict, List, Optional
+
+from llama_models.schema_utils import webmethod
+
 from llama_stack.apis.datasetio import DatasetIO
+from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.post_training import (
+    AlgorithmConfig,
+    DPOAlignmentConfig,
+    JobStatus,
+    LoraFinetuningConfig,
+    PostTrainingJob,
+    PostTrainingJobArtifactsResponse,
+    PostTrainingJobStatusResponse,
+    TrainingConfig,
+)
 from llama_stack.providers.inline.post_training.torchtune.config import (
     TorchtunePostTrainingConfig,
 )
-from llama_stack.apis.post_training import *  # noqa
 from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
     LoraFinetuningSingleDevice,
 )
diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index 71b8bf759..517be6d89 100644
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -14,27 +14,33 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from llama_models.sku_list import resolve_model
+from llama_stack.apis.common.training_types import PostTrainingMetric
 from llama_stack.apis.datasetio import DatasetIO
+from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.post_training import (
+    AlgorithmConfig,
+    Checkpoint,
+    LoraFinetuningConfig,
+    OptimizerConfig,
+    TrainingConfig,
+)
 from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
-from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
-    TorchtuneCheckpointer,
-)
-from torch import nn
-from torchtune import utils as torchtune_utils
-from torchtune.training.metric_logging import DiskLogger
-from tqdm import tqdm
-from llama_stack.apis.post_training import *  # noqa
+
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.inline.post_training.torchtune.common import utils
+from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
+    TorchtuneCheckpointer,
+)
 from llama_stack.providers.inline.post_training.torchtune.config import (
     TorchtunePostTrainingConfig,
 )
 from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
+from torch import nn
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
-from torchtune import modules, training
+from torchtune import modules, training, utils as torchtune_utils
 from torchtune.data import AlpacaToMessages, padded_collate_sft
 from torchtune.modules.loss import CEWithChunkedOutputLoss
@@ -47,6 +53,8 @@ from torchtune.modules.peft import (
     validate_missing_and_unexpected_for_lora,
 )
 from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
+from torchtune.training.metric_logging import DiskLogger
+from tqdm import tqdm
 
 log = logging.getLogger(__name__)
diff --git a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py
index 46b5e57da..87d68f74c 100644
--- a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py
+++ b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py
@@ -7,8 +7,14 @@ import logging
 from typing import Any, Dict, List
 
-from llama_stack.apis.safety import *  # noqa: F403
 from llama_stack.apis.inference import Message
+from llama_stack.apis.safety import (
+    RunShieldResponse,
+    Safety,
+    SafetyViolation,
+    ViolationLevel,
+)
+from llama_stack.apis.shields import Shield
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )