From 46c5b14a22ee41b4ab46e189831d79605dd15164 Mon Sep 17 00:00:00 2001
From: Charlie Doern
Date: Wed, 14 May 2025 15:43:41 -0400
Subject: [PATCH] feat: handle graceful shutdown

Currently this implementation hangs because `trainer.train()` blocks.
Rewrite the implementation to kick off the model download, device
instantiation, dataset processing, and training in a monitored subprocess.
All of these steps need to run in that subprocess, or else different
devices are used, which causes torch errors.

Signed-off-by: Charlie Doern
---
 .../inline/post_training/common/utils.py      |  35 ++
 .../recipes/finetune_single_device.py         | 519 ++++++++++++------
 .../recipes/lora_finetuning_single_device.py  |   8 +-
 requirements.txt                              | 137 -----
 4 files changed, 387 insertions(+), 312 deletions(-)
 create mode 100644 llama_stack/providers/inline/post_training/common/utils.py

diff --git a/llama_stack/providers/inline/post_training/common/utils.py b/llama_stack/providers/inline/post_training/common/utils.py
new file mode 100644
index 000000000..7840b21e8
--- /dev/null
+++ b/llama_stack/providers/inline/post_training/common/utils.py
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import gc
+
+
+def evacuate_model_from_device(model, device: str):
+    """Safely clear a model from memory and free device resources.
+    This function handles the proper cleanup of a model by:
+    1. Moving the model to CPU if it's on a non-CPU device
+    2. Deleting the model object to free memory
+    3. Running garbage collection
+    4. Clearing CUDA cache if the model was on a CUDA device
+    Args:
+        model: The PyTorch model to clear
+        device: The device type the model is currently on ('cuda', 'mps', 'cpu')
+    Note:
+        - For CUDA devices, this will clear the CUDA cache after moving the model to CPU
+        - For MPS devices, only moves the model to CPU (no cache clearing available)
+        - For CPU devices, only deletes the model object and runs garbage collection
+    """
+    if device != "cpu":
+        model.to("cpu")
+
+    del model
+    gc.collect()
+
+    if device == "cuda":
+        # we need to import such that this is only imported when the method is called
+        import torch
+
+        torch.cuda.empty_cache()
diff --git a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py
index fd1c68655..b6d13b029 100644
--- a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py
+++ b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py
@@ -7,16 +7,26 @@
 import gc
 import json
 import logging
+import multiprocessing
 import os
+import signal
+import sys
 from datetime import datetime, timezone
 from pathlib import Path
 from typing import Any
 
 import psutil
 
+from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
+
 # Set tokenizer parallelism environment variable
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
+# Force MKL onto the GNU threading layer to avoid threading conflicts in spawned subprocesses
+os.environ["MKL_THREADING_LAYER"] = "GNU"
+os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
+os.environ["MKL_NUM_THREADS"] = "1"
+
 import torch
 from datasets import Dataset
 from peft import LoraConfig
@@ -86,10 +96,46 @@ def get_memory_stats(device: torch.device) -> dict[str, Any]:
     return stats
 
 
+def setup_torch_device(device_str: str) -> torch.device:
+    """Initialize and validate a PyTorch device.
+    This function handles device initialization and validation for different device types:
+    - CUDA: Validates CUDA availability and handles device selection
+    - MPS: Validates MPS availability for Apple Silicon
+    - CPU: Basic validation
+    - HPU: Raises error as it's not supported
+    Args:
+        device_str: String specifying the device ('cuda', 'cpu', 'mps')
+    Returns:
+        torch.device: The initialized and validated device
+    Raises:
+        RuntimeError: If device initialization fails or device is not supported
+    """
+    try:
+        device = torch.device(device_str)
+    except RuntimeError as e:
+        raise RuntimeError(f"Error getting Torch Device: {str(e)}") from e
+
+    # Validate device capabilities
+    if device.type == "cuda":
+        if not torch.cuda.is_available():
+            raise RuntimeError(
+                f"{device.type}: Torch has no CUDA/ROCm support or could not detect a compatible device."
+            )
+        if device.index is None:
+            device = torch.device(device.type, torch.cuda.current_device())
+    elif device.type == "mps":
+        if not torch.backends.mps.is_available():
+            raise RuntimeError(f"{device.type}: Torch has no MPS support or could not detect a compatible device.")
+    elif device.type == "hpu":
+        raise RuntimeError(f"{device.type}: training does not support Intel Gaudi.")
+
+    return device
+
+
 class HFFinetuningSingleDevice:
     def __init__(
         self,
-        job_uuid,
+        job_uuid: str,
         datasetio_api: DatasetIO,
         datasets_api: Datasets,
     ):
@@ -216,58 +262,120 @@ class HFFinetuningSingleDevice:
             remove_columns=ds.column_names,
         )
 
+    async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]:
+        """Load dataset from llama stack dataset provider"""
+        try:
+            all_rows = await self.datasetio_api.iterrows(
+                dataset_id=dataset_id,
+                limit=-1,
+            )
+            if not isinstance(all_rows.data, list):
+                raise RuntimeError("Expected dataset data to be a list")
+            return all_rows.data
+        except Exception as e:
+            raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
+
+    def _run_training_sync(
+        self,
+        model: str,
+        provider_config: dict[str, Any],
+        peft_config: LoraConfig | None,
+        config: dict[str, Any],
+        output_dir_path: Path | None,
+    ) -> None:
+        """Synchronous wrapper for running training process.
+        This method serves as a bridge between the multiprocessing Process and the async training function.
+        It creates a new event loop to run the async training process.
+        Args:
+            model: The model identifier to load
+            provider_config: Configuration specific to the HuggingFace provider
+            peft_config: Optional LoRA configuration
+            config: General training configuration
+            output_dir_path: Optional path to save the model
+        """
+        import asyncio
+
+        logger.info("Starting training process with async wrapper")
+        asyncio.run(
+            self._run_training(
+                model=model,
+                provider_config=provider_config,
+                peft_config=peft_config,
+                config=config,
+                output_dir_path=output_dir_path,
+            )
+        )
+
     async def load_dataset(
         self,
         model: str,
         config: TrainingConfig,
         provider_config: HuggingFacePostTrainingConfig,
     ) -> tuple[Dataset, Dataset, AutoTokenizer]:
-        """Load and preprocess the dataset for training.
-
+        """Load and prepare the dataset for training.
         Args:
             model: The model identifier to load
-            config: Training configuration containing dataset settings
+            config: Training configuration
             provider_config: Provider-specific configuration
-
         Returns:
-            tuple containing:
-            - Training dataset
-            - Evaluation dataset
-            - Tokenizer
-
-        Raises:
-            ValueError: If dataset is missing required fields
-            RuntimeError: If tokenizer initialization fails
+            tuple: (train_dataset, eval_dataset, tokenizer)
         """
-        assert isinstance(config.data_config, DataConfig), "DataConfig must be initialized"
-        rows = await self._setup_data(config.data_config.dataset_id)
+        # Validate data config
+        if not config.data_config:
+            raise ValueError("DataConfig is required for training")
 
-        # Validate that the dataset has the required fields for training
+        # Load dataset
+        logger.info(f"Loading dataset: {config.data_config.dataset_id}")
+        rows = await self._setup_data(config.data_config.dataset_id)
         if not self.validate_dataset_format(rows):
             raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
+        logger.info(f"Loaded {len(rows)} rows from dataset")
 
-        # Initialize tokenizer with model-specific config
+        # Initialize tokenizer
+        logger.info(f"Initializing tokenizer for model: {model}")
         try:
             tokenizer = AutoTokenizer.from_pretrained(model, **provider_config.model_specific_config)
-            # Set up tokenizer defaults
+
+            # Set pad token to eos token if not present
+            # This is common for models that don't have a dedicated pad token
             if not tokenizer.pad_token:
                 tokenizer.pad_token = tokenizer.eos_token
+
+            # Set padding side to right for causal language modeling
+            # This ensures that padding tokens don't interfere with the model's ability
+            # to predict the next token in the sequence
             tokenizer.padding_side = "right"
+
+            # Set truncation side to right to keep the beginning of the sequence
+            # This is important for maintaining context and instruction format
             tokenizer.truncation_side = "right"
+
+            # Set model max length to match provider config
+            # This ensures consistent sequence lengths across the training process
             tokenizer.model_max_length = provider_config.max_seq_length
+
+            logger.info("Tokenizer initialized successfully")
         except Exception as e:
             raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") from e
 
         # Create and preprocess dataset
+        logger.info("Creating and preprocessing dataset")
         try:
             ds = self._create_dataset(rows, config, provider_config)
             ds = self._preprocess_dataset(ds, tokenizer, provider_config)
+            logger.info(f"Dataset created with {len(ds)} examples")
         except Exception as e:
             raise ValueError(f"Failed to create dataset: {str(e)}") from e
 
-        # Split dataset into train and validation
+        # Split dataset
+        logger.info("Splitting dataset into train and validation sets")
         train_val_split = ds.train_test_split(test_size=0.1, seed=42)
-        return train_val_split["train"], train_val_split["test"], tokenizer
+        train_dataset = train_val_split["train"]
+        eval_dataset = train_val_split["test"]
+        logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples")
+
+        return train_dataset, eval_dataset, tokenizer
 
     def load_model(
         self,
@@ -276,15 +384,12 @@ class HFFinetuningSingleDevice:
         provider_config: HuggingFacePostTrainingConfig,
     ) -> AutoModelForCausalLM:
         """Load and initialize the model for training.
-
         Args:
             model: The model identifier to load
            device: The device to load the model onto
             provider_config: Provider-specific configuration
-
         Returns:
             The loaded and initialized model
-
         Raises:
             RuntimeError: If model loading fails
         """
@@ -293,18 +398,201 @@ class HFFinetuningSingleDevice:
             model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
             model_obj = AutoModelForCausalLM.from_pretrained(
                 model,
-                torch_dtype="auto",
+                torch_dtype="auto" if device.type != "cpu" else "float32",
                 quantization_config=None,
                 config=model_config,
                 **provider_config.model_specific_config,
             )
-            if model_obj.device != device:
-                model_obj = model_obj.to(device)
+            # Always move model to specified device
+            model_obj = model_obj.to(device)
             logger.info(f"Model loaded and moved to device: {model_obj.device}")
             return model_obj
         except Exception as e:
             raise RuntimeError(f"Failed to load model: {str(e)}") from e
 
+    def setup_training_args(
+        self,
+        config: TrainingConfig,
+        provider_config: HuggingFacePostTrainingConfig,
+        device: torch.device,
+        output_dir_path: Path | None,
+        steps_per_epoch: int,
+    ) -> SFTConfig:
+        """Setup training arguments.
+        Args:
+            config: Training configuration
+            provider_config: Provider-specific configuration
+            device: The device to train on
+            output_dir_path: Optional path to save the model
+            steps_per_epoch: Number of steps per epoch
+        Returns:
+            Configured SFTConfig object
+        """
+        logger.info("Configuring training arguments")
+        lr = 2e-5
+        if config.optimizer_config:
+            lr = config.optimizer_config.lr
+            logger.info(f"Using custom learning rate: {lr}")
+
+        # Validate data config
+        if not config.data_config:
+            raise ValueError("DataConfig is required for training")
+        data_config = config.data_config
+
+        # Calculate steps
+        total_steps = steps_per_epoch * config.n_epochs
+        max_steps = min(config.max_steps_per_epoch, total_steps)
+        eval_steps = max(1, steps_per_epoch // 10)  # Evaluate 10 times per epoch
+        save_steps = max(1, steps_per_epoch // 5)  # Save 5 times per epoch
+        logging_steps = max(1, steps_per_epoch // 50)  # Log 50 times per epoch
+
+        logger.info("Training configuration:")
+        logger.info(f"- Steps per epoch: {steps_per_epoch}")
+        logger.info(f"- Total steps: {total_steps}")
+        logger.info(f"- Max steps: {max_steps}")
+        logger.info(f"- Eval steps: {eval_steps}")
+        logger.info(f"- Save steps: {save_steps}")
+        logger.info(f"- Logging steps: {logging_steps}")
+
+        # Configure save strategy
+        save_strategy = "no"
+        if output_dir_path:
+            save_strategy = "steps"
+            logger.info(f"Will save checkpoints to {output_dir_path}")
+
+        return SFTConfig(
+            max_steps=max_steps,
+            output_dir=str(output_dir_path) if output_dir_path is not None else None,
+            num_train_epochs=config.n_epochs,
+            per_device_train_batch_size=data_config.batch_size,
+            fp16=device.type == "cuda",
+            bf16=False,  # Causes CPU issues.
+            eval_strategy="steps",
+            use_cpu=True if device.type == "cpu" and not torch.backends.mps.is_available() else False,
+            save_strategy=save_strategy,
+            report_to="none",
+            max_seq_length=provider_config.max_seq_length,
+            gradient_accumulation_steps=config.gradient_accumulation_steps,
+            gradient_checkpointing=provider_config.gradient_checkpointing,
+            learning_rate=lr,
+            warmup_ratio=provider_config.warmup_ratio,
+            weight_decay=provider_config.weight_decay,
+            remove_unused_columns=False,
+            dataloader_pin_memory=provider_config.dataloader_pin_memory,
+            dataloader_num_workers=provider_config.dataloader_num_workers,
+            dataset_text_field="text",
+            packing=False,
+            load_best_model_at_end=True if output_dir_path else False,
+            metric_for_best_model="eval_loss",
+            greater_is_better=False,
+            eval_steps=eval_steps,
+            save_steps=save_steps,
+            logging_steps=logging_steps,
+        )
+
+    def save_model(
+        self,
+        model_obj: AutoModelForCausalLM,
+        trainer: SFTTrainer,
+        peft_config: LoraConfig | None,
+        output_dir_path: Path,
+    ) -> None:
+        """Save the trained model.
+        Args:
+            model_obj: The model to save
+            trainer: The trainer instance
+            peft_config: Optional LoRA configuration
+            output_dir_path: Path to save the model
+        """
+        logger.info("Saving final model")
+        model_obj.config.use_cache = True
+
+        if peft_config:
+            logger.info("Merging LoRA weights with base model")
+            model_obj = trainer.model.merge_and_unload()
+        else:
+            model_obj = trainer.model
+
+        save_path = output_dir_path / "merged_model"
+        logger.info(f"Saving model to {save_path}")
+        model_obj.save_pretrained(save_path)
+
+    async def _run_training(
+        self,
+        model: str,
+        provider_config: dict[str, Any],
+        peft_config: LoraConfig | None,
+        config: dict[str, Any],
+        output_dir_path: Path | None,
+    ) -> None:
+        """Run the training process with signal handling."""
+
+        def signal_handler(signum, frame):
+            """Handle termination signals gracefully."""
+            logger.info(f"Received signal {signum}, initiating graceful shutdown")
+            sys.exit(0)
+
+        signal.signal(signal.SIGTERM, signal_handler)
+        signal.signal(signal.SIGINT, signal_handler)
+
+        # Convert config dicts back to objects
+        logger.info("Initializing configuration objects")
+        provider_config_obj = HuggingFacePostTrainingConfig(**provider_config)
+        config_obj = TrainingConfig(**config)
+
+        # Initialize and validate device
+        device = setup_torch_device(provider_config_obj.device)
+        logger.info(f"Using device '{device}'")
+
+        # Load dataset and tokenizer
+        train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj)
+
+        # Calculate steps per epoch
+        if not config_obj.data_config:
+            raise ValueError("DataConfig is required for training")
+        steps_per_epoch = len(train_dataset) // config_obj.data_config.batch_size
+
+        # Setup training arguments
+        training_args = self.setup_training_args(
+            config_obj,
+            provider_config_obj,
+            device,
+            output_dir_path,
+            steps_per_epoch,
+        )
+
+        # Load model
+        model_obj = self.load_model(model, device, provider_config_obj)
+
+        # Initialize trainer
+        logger.info("Initializing SFTTrainer")
+        trainer = SFTTrainer(
+            model=model_obj,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            peft_config=peft_config,
+            args=training_args,
+        )
+
+        try:
+            # Train
+            logger.info("Starting training")
+            trainer.train()
+            logger.info("Training completed successfully")
+
+            # Save final model if output directory is provided
+            if output_dir_path:
+                self.save_model(model_obj, trainer, peft_config, output_dir_path)
+
+        finally:
+            # Clean up resources
+            logger.info("Cleaning up resources")
+            if hasattr(trainer, "model"):
+                evacuate_model_from_device(trainer.model, device.type)
+            del trainer
+            gc.collect()
+            logger.info("Cleanup completed")
+
     async def train(
         self,
         model: str,
@@ -315,50 +603,21 @@ class HFFinetuningSingleDevice:
         provider_config: HuggingFacePostTrainingConfig,
     ) -> tuple[dict[str, Any], list[Checkpoint] | None]:
         """Train a model using HuggingFace's SFTTrainer"""
-        try:
-            device = torch.device(provider_config.device)
-        except RuntimeError as e:
-            raise RuntimeError(f"Error getting Torch Device {str(e)}") from e
-
-        # Detect device type and validate
-        if device.type == "cuda":
-            if not torch.cuda.is_available():
-                raise RuntimeError(
-                    f"{device.type}: Torch has no CUDA/ROCm support or could not detect a compatible device."
-                )
-            # map unqualified 'cuda' to current device
-            if device.index is None:
-                device = torch.device(device.type, torch.cuda.current_device())
-        elif device.type == "mps":
-            if not torch.backends.mps.is_available():
-                raise RuntimeError(f"{device.type}: Torch has no MPS support or could not detect a compatible device.")
-        elif device.type == "hpu":
-            raise RuntimeError(f"{device.type}: training does not support Intel Gaudi.")
-
+        # Initialize and validate device
+        device = setup_torch_device(provider_config.device)
         logger.info(f"Using device '{device}'")
+
         output_dir_path = None
         if output_dir:
             output_dir_path = Path(output_dir)
 
-        # Track memory stats throughout training
+        # Track memory stats
         memory_stats = {
             "initial": get_memory_stats(device),
-            "after_model_load": None,
             "after_training": None,
             "final": None,
         }
 
-        # Validate data config
-        if not config.data_config:
-            raise ValueError("DataConfig is required for training")
-
-        # Load dataset and tokenizer
-        train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config, provider_config)
-
-        # Load model with model-specific config
-        model_obj = self.load_model(model, device, provider_config)
-        memory_stats["after_model_load"] = get_memory_stats(device)
-
         # Configure LoRA
         peft_config = None
         if lora_config:
@@ -371,93 +630,43 @@ class HFFinetuningSingleDevice:
                 target_modules=lora_config.lora_attn_modules,
             )
 
-        # Setup training arguments
-        lr = 2e-5
-        if config.optimizer_config:
-            lr = config.optimizer_config.lr
+        # Validate data config
+        if not config.data_config:
+            raise ValueError("DataConfig is required for training")
 
-        # Calculate steps per epoch and appropriate intervals
-        steps_per_epoch = len(train_dataset) // config.data_config.batch_size
-        eval_steps = max(1, steps_per_epoch // 10)  # Evaluate 10 times per epoch
-        save_steps = max(1, steps_per_epoch // 5)  # Save 5 times per epoch
-        logging_steps = max(1, steps_per_epoch // 50)  # Log 50 times per epoch
-
-        logger.info(f"Dataset size: {len(train_dataset)} examples")
-        logger.info(f"Batch size: {config.data_config.batch_size}")
-        logger.info(f"Steps per epoch: {steps_per_epoch}")
-        logger.info(f"Will evaluate every {eval_steps} steps")
-        logger.info(f"Will save every {save_steps} steps")
-        logger.info(f"Will log every {logging_steps} steps")
-
-        # save_strategy should be none if output dir is none
-        save_strategy = "no"
-        if output_dir_path:
-            save_strategy = "steps"
-        training_arguments = SFTConfig(
-            max_steps=config.max_steps_per_epoch,
-            output_dir=str(output_dir_path) if output_dir_path is not None else None,
-            num_train_epochs=config.n_epochs,
-            per_device_train_batch_size=config.data_config.batch_size,
-            fp16=device.type == "cuda",
-            bf16=device.type != "cuda",
-            # use_cpu should only be set if we are on a "True" CPU machine, not a MPS enabled Mac due to stability issues.
-            use_cpu=True if device.type == "cpu" and not torch.backends.mps.is_available() else False,
-            save_strategy=save_strategy,
-            save_steps=save_steps,
-            report_to="none",
-            max_seq_length=provider_config.max_seq_length,
-            gradient_accumulation_steps=config.gradient_accumulation_steps,
-            gradient_checkpointing=provider_config.gradient_checkpointing,
-            learning_rate=lr,
-            warmup_ratio=provider_config.warmup_ratio,
-            weight_decay=provider_config.weight_decay,
-            logging_steps=logging_steps,
-            # Enable validation
-            eval_strategy="steps",
-            eval_steps=eval_steps,
-            save_total_limit=provider_config.save_total_limit,
-            remove_unused_columns=False,
-            dataloader_pin_memory=provider_config.dataloader_pin_memory,
-            dataloader_num_workers=provider_config.dataloader_num_workers,
-            dataset_text_field="text",
-            packing=False,
-            # Add evaluation metrics
-            # loading the best model can only happen if we have saved a model
-            load_best_model_at_end=True if output_dir_path else False,
-            metric_for_best_model="eval_loss",
-            greater_is_better=False,
-        )
-
-        # Initialize trainer with both train and eval datasets
-        trainer = SFTTrainer(
-            model=model_obj,
-            train_dataset=train_dataset,
-            eval_dataset=eval_dataset,
-            peft_config=peft_config,
-            args=training_arguments,
-        )
-
-        # Train
-        logger.info("Starting training")
+        # Train in a separate process
+        logger.info("Starting training in separate process")
         try:
-            trainer.train()
+            # Set multiprocessing start method to 'spawn' for CUDA/MPS compatibility
+            if device.type in ["cuda", "mps"]:
+                multiprocessing.set_start_method("spawn", force=True)
+
+            process = multiprocessing.Process(
+                target=self._run_training_sync,
+                kwargs={
+                    "model": model,
+                    "provider_config": provider_config.model_dump(),
+                    "peft_config": peft_config,
+                    "config": config.model_dump(),
+                    "output_dir_path": output_dir_path,
+                },
+            )
+            process.start()
+
+            # Monitor the process
+            while process.is_alive():
+                process.join(timeout=1)  # Check every second
+                if not process.is_alive():
+                    break
+
+            # Get the return code
+            if process.exitcode != 0:
+                raise RuntimeError(f"Training failed with exit code {process.exitcode}")
+
             memory_stats["after_training"] = get_memory_stats(device)
 
-            # Save final model
-            model_obj.config.use_cache = True
-
-            # if we have LoRA we need to do `merge_and_unload`
-            if lora_config:
-                model_obj = trainer.model.merge_and_unload()
-            else:
-                model_obj = trainer.model
-
-            checkpoint = None
             checkpoints = None
 
-            # only save a final model if checkpoint dir is specified
-            # this is especially useful to test training rather than saving of checkpoints
             if output_dir_path:
-                model_obj.save_pretrained(output_dir_path / "merged_model")
-
                 # Create checkpoint
                 checkpoint = Checkpoint(
                     identifier=f"{model}-sft-{config.n_epochs}",
@@ -470,33 +679,5 @@ class HFFinetuningSingleDevice:
             return memory_stats, checkpoints
         finally:
-            # Clean up resources
-            if hasattr(trainer, "model"):
-                if device.type != "cpu":
-                    trainer.model.to("cpu")
-                if device.type == "cuda":
-                    torch.cuda.empty_cache()
-                del trainer.model
-            del trainer
-            gc.collect()
             memory_stats["final"] = get_memory_stats(device)
-
-    async def _setup_data(
-        self,
-        dataset_id: str,
-    ) -> list[dict[str, Any]]:
-        """Load dataset from llama stack dataset provider"""
-        try:
-
-            async def fetch_rows(dataset_id: str):
-                return await self.datasetio_api.iterrows(
-                    dataset_id=dataset_id,
-                    limit=-1,
-                )
-
-            all_rows = await fetch_rows(dataset_id)
-            if not isinstance(all_rows.data, list):
-                raise RuntimeError("Expected dataset data to be a list")
-            return all_rows.data
-        except Exception as e:
-            raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
+            gc.collect()
diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index b5a495935..f56dd2499 100644
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import gc
 import logging
 import os
 import time
@@ -47,6 +46,7 @@ from llama_stack.apis.post_training import (
 from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.models.llama.sku_list import resolve_model
+from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
 from llama_stack.providers.inline.post_training.torchtune.common import utils
 from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
     TorchtuneCheckpointer,
@@ -554,11 +554,7 @@ class LoraFinetuningSingleDevice:
             checkpoints.append(checkpoint)
 
         # clean up the memory after training finishes
-        if self._device.type != "cpu":
-            self._model.to("cpu")
-            torch.cuda.empty_cache()
-        del self._model
-        gc.collect()
+        evacuate_model_from_device(self._model, self._device.type)
 
         return (memory_stats, checkpoints)
 
diff --git a/requirements.txt b/requirements.txt
index 0857a9886..c2571025a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,206 +1,69 @@
 # This file was autogenerated by uv via the following command:
 #    uv export --frozen --no-hashes --no-emit-project --output-file=requirements.txt
 annotated-types==0.7.0
-    # via pydantic
 anyio==4.8.0
-    # via
-    #   httpx
-    #   llama-stack-client
-    #   openai
 attrs==25.1.0
-    # via
-    #   jsonschema
-    #   referencing
 blobfile==3.0.0
-    # via llama-stack
 cachetools==5.5.2
-    # via google-auth
 certifi==2025.1.31
-    # via
-    #   httpcore
-    #   httpx
-    #   kubernetes
-    #   requests
 charset-normalizer==3.4.1
-    # via requests
 click==8.1.8
-    # via llama-stack-client
 colorama==0.4.6 ; sys_platform == 'win32'
-    # via
-    #   click
-    #   tqdm
 distro==1.9.0
-    # via
-    #   llama-stack-client
-    #   openai
 durationpy==0.9
-    # via kubernetes
 exceptiongroup==1.2.2 ; python_full_version < '3.11'
-    # via anyio
 filelock==3.17.0
-    # via
-    #   blobfile
-    #   huggingface-hub
 fire==0.7.0
-    # via llama-stack
 fsspec==2024.12.0
-    # via huggingface-hub
 google-auth==2.38.0
-    # via kubernetes
 h11==0.16.0
-    # via
-    #   httpcore
-    #   llama-stack
 httpcore==1.0.9
-    # via httpx
 httpx==0.28.1
-    # via
-    #   llama-stack
-    #   llama-stack-client
-    #   openai
 huggingface-hub==0.29.0
-    # via llama-stack
 idna==3.10
-    # via
-    #   anyio
-    #   httpx
-    #   requests
 jinja2==3.1.6
-    # via llama-stack
 jiter==0.8.2
-    # via openai
 jsonschema==4.23.0
-    # via llama-stack
 jsonschema-specifications==2024.10.1
-    # via jsonschema
 kubernetes==32.0.1
-    # via llama-stack
 llama-stack-client==0.2.7
-    # via llama-stack
 lxml==5.3.1
-    # via blobfile
 markdown-it-py==3.0.0
-    # via rich
 markupsafe==3.0.2
-    # via jinja2
 mdurl==0.1.2
-    # via markdown-it-py
 numpy==2.2.3
-    # via pandas
 oauthlib==3.2.2
-    # via
-    #   kubernetes
-    #   requests-oauthlib
 openai==1.71.0
-    # via llama-stack
 packaging==24.2
-    # via huggingface-hub
 pandas==2.2.3
-    # via llama-stack-client
 pillow==11.1.0
-    # via llama-stack
 prompt-toolkit==3.0.50
-    # via
-    #   llama-stack
-    #   llama-stack-client
 pyaml==25.1.0
-    # via llama-stack-client
 pyasn1==0.6.1
-    # via
-    #   pyasn1-modules
-    #   rsa
 pyasn1-modules==0.4.2
-    # via google-auth
 pycryptodomex==3.21.0
-    # via blobfile
 pydantic==2.10.6
-    # via
-    #   llama-stack
-    #   llama-stack-client
-    #   openai
 pydantic-core==2.27.2
-    # via pydantic
 pygments==2.19.1
-    # via rich
 python-dateutil==2.9.0.post0
-    # via
-    #   kubernetes
-    #   pandas
 python-dotenv==1.0.1
-    # via llama-stack
 pytz==2025.1
-    # via pandas
 pyyaml==6.0.2
-    # via
-    #   huggingface-hub
-    #   kubernetes
-    #   pyaml
 referencing==0.36.2
-    # via
-    #   jsonschema
-    #   jsonschema-specifications
 regex==2024.11.6
-    # via tiktoken
 requests==2.32.3
-    # via
-    #   huggingface-hub
-    #   kubernetes
-    #   llama-stack
-    #   requests-oauthlib
-    #   tiktoken
 requests-oauthlib==2.0.0
-    # via kubernetes
 rich==13.9.4
-    # via
-    #   llama-stack
-    #   llama-stack-client
 rpds-py==0.22.3
-    # via
-    #   jsonschema
-    #   referencing
 rsa==4.9
-    # via google-auth
 setuptools==75.8.0
-    # via llama-stack
 six==1.17.0
-    # via
-    #   kubernetes
-    #   python-dateutil
 sniffio==1.3.1
-    # via
-    #   anyio
-    #   llama-stack-client
-    #   openai
 termcolor==2.5.0
-    # via
-    #   fire
-    #   llama-stack
-    #   llama-stack-client
 tiktoken==0.9.0
-    # via llama-stack
 tqdm==4.67.1
-    # via
-    #   huggingface-hub
-    #   llama-stack-client
-    #   openai
 typing-extensions==4.12.2
-    # via
-    #   anyio
-    #   huggingface-hub
-    #   llama-stack-client
-    #   openai
-    #   pydantic
-    #   pydantic-core
-    #   referencing
-    #   rich
 tzdata==2025.1
-    # via pandas
 urllib3==2.3.0
-    # via
-    #   blobfile
-    #   kubernetes
-    #   requests
 wcwidth==0.2.13
-    # via prompt-toolkit
 websocket-client==1.8.0
-    # via kubernetes
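
A note for reviewers: the graceful-shutdown behavior in this patch hinges on three pieces working together, the 'spawn' start method, the signal handlers installed inside the child, and the parent's polling join. Below is a minimal, self-contained sketch of that pattern with toy names (worker, handle); it is not the provider code itself, just the shape of it:

import multiprocessing
import signal
import sys
import time


def worker() -> None:
    # Install handlers first, so a SIGTERM/SIGINT arriving during the
    # long-running work (trainer.train() in the real recipe) exits the
    # child cleanly instead of leaving the parent hanging.
    def handle(signum, frame):
        print(f"worker: received signal {signum}, shutting down")
        sys.exit(0)

    signal.signal(signal.SIGTERM, handle)
    signal.signal(signal.SIGINT, handle)
    time.sleep(60)  # stand-in for the blocking training loop


if __name__ == "__main__":
    # 'spawn' starts the child with a fresh interpreter, so it initializes
    # its own CUDA/MPS context rather than inheriting the parent's
    # (forking an already-initialized CUDA context is unsafe).
    multiprocessing.set_start_method("spawn", force=True)
    process = multiprocessing.Process(target=worker)
    process.start()
    while process.is_alive():
        process.join(timeout=1)  # short join keeps the parent responsive
    if process.exitcode != 0:
        raise RuntimeError(f"worker failed with exit code {process.exitcode}")

Because the child installs its own handlers and exits with status 0, the parent's exitcode check can distinguish a graceful stop from a crash.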
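The _run_training_sync wrapper exists because a multiprocessing target must be a plain callable while _run_training is a coroutine. The same bridge in isolation (hypothetical names; the sleep stands in for the real async dataset and training steps):

import asyncio
import multiprocessing


async def run_training() -> None:
    # Stand-in for the real async pipeline: dataset fetch, preprocessing,
    # then the blocking trainer.train() call.
    await asyncio.sleep(0.1)
    print("training finished")


def run_training_sync() -> None:
    # The spawned child has no running event loop, so create one and
    # drive the coroutine to completion.
    asyncio.run(run_training())


if __name__ == "__main__":
    multiprocessing.set_start_method("spawn", force=True)
    p = multiprocessing.Process(target=run_training_sync)
    p.start()
    p.join()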
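Finally, a usage sketch for the new evacuate_model_from_device helper, with a toy torch.nn.Linear standing in for the fine-tuned model. One caveat worth knowing as a caller: the helper's del only drops its own local reference, so the caller should drop theirs as well:

import gc

import torch

from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(8, 8).to(device)

# Moves the model to CPU (if needed), runs gc, and empties the CUDA cache.
evacuate_model_from_device(model, device)

# The helper deleted only its own reference; drop the caller's reference
# too so garbage collection can actually reclaim the weights.
del model
gc.collect()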