Merge branch 'main' into enable-vector-stores-files-api-tests

Francisco Arceo 2025-07-31 07:44:31 -04:00 committed by GitHub
commit 8732103995
218 changed files with 1264 additions and 563 deletions


@@ -6,7 +6,7 @@
from typing import Any
from llama_stack.distribution.datatypes import AccessRule, Api
from llama_stack.core.datatypes import AccessRule, Api
from .config import MetaReferenceAgentsImplConfig


@@ -61,7 +61,7 @@ from llama_stack.apis.inference import (
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.datatypes import AccessRule
from llama_stack.core.datatypes import AccessRule
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import (
BuiltinTool,


@@ -41,7 +41,7 @@ from llama_stack.apis.inference import (
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.datatypes import AccessRule
from llama_stack.core.datatypes import AccessRule
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records
from llama_stack.providers.utils.responses.responses_store import ResponsesStore


@@ -10,10 +10,10 @@ import uuid
from datetime import UTC, datetime
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.distribution.access_control.datatypes import AccessRule
from llama_stack.distribution.datatypes import User
from llama_stack.distribution.request_headers import get_authenticated_user
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.core.access_control.datatypes import AccessRule
from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__)


@@ -5,7 +5,7 @@
# the root directory of this source tree.
from typing import Any
from llama_stack.distribution.datatypes import Api
from llama_stack.core.datatypes import Api
from .config import MetaReferenceEvalConfig


@@ -6,7 +6,7 @@
from typing import Any
from llama_stack.distribution.datatypes import AccessRule, Api
from llama_stack.core.datatypes import AccessRule, Api
from .config import LocalfsFilesImplConfig
from .files import LocalfsFilesImpl


@@ -19,7 +19,7 @@ from llama_stack.apis.files import (
OpenAIFileObject,
OpenAIFilePurpose,
)
from llama_stack.distribution.datatypes import AccessRule
from llama_stack.core.datatypes import AccessRule
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl


@@ -6,7 +6,7 @@
from pathlib import Path
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.core.utils.model_utils import model_local_dir
def model_checkpoint_dir(model_id) -> str:


@@ -6,7 +6,7 @@
from typing import Any
from llama_stack.distribution.datatypes import Api
from llama_stack.core.datatypes import Api
from .config import HuggingFacePostTrainingConfig


@@ -67,6 +67,12 @@ class HuggingFacePostTrainingConfig(BaseModel):
# Can improve data transfer speed to GPU but uses more memory
dataloader_pin_memory: bool = True
# DPO-specific parameters
dpo_beta: float = 0.1
use_reference_model: bool = True
dpo_loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid"
dpo_output_dir: str = "./checkpoints/dpo"
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu"}


@@ -25,6 +25,9 @@ from llama_stack.providers.inline.post_training.huggingface.config import (
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
HFFinetuningSingleDevice,
)
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device_dpo import (
HFDPOAlignmentSingleDevice,
)
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
from llama_stack.schema_utils import webmethod
@@ -36,6 +39,7 @@ class TrainingArtifactType(Enum):
_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
_JOB_TYPE_DPO_TRAINING = "dpo-training"
class HuggingFacePostTrainingImpl:
@@ -119,12 +123,37 @@ class HuggingFacePostTrainingImpl:
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
) -> PostTrainingJob:
raise NotImplementedError("DPO alignment is not implemented yet")
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
on_log_message_cb("Starting HF DPO alignment")
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
return ListPostTrainingJobsResponse(
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
)
recipe = HFDPOAlignmentSingleDevice(
job_uuid=job_uuid,
datasetio_api=self.datasetio_api,
datasets_api=self.datasets_api,
)
resources_allocated, checkpoints = await recipe.train(
model=finetuned_model,
output_dir=f"{self.config.dpo_output_dir}/{job_uuid}",
job_uuid=job_uuid,
dpo_config=algorithm_config,
config=training_config,
provider_config=self.config,
)
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
if checkpoints:
for checkpoint in checkpoints:
artifact = self._checkpoint_to_artifact(checkpoint)
on_artifact_collected_cb(artifact)
else:
on_log_message_cb("Warning: No checkpoints were saved during DPO training")
on_status_change_cb(SchedulerJobStatus.completed)
on_log_message_cb("HF DPO alignment completed")
job_uuid = self._scheduler.schedule(_JOB_TYPE_DPO_TRAINING, job_uuid, handler)
return PostTrainingJob(job_uuid=job_uuid)
@staticmethod
def _get_artifacts_metadata_by_type(job, artifact_type):
@@ -174,3 +203,9 @@ class HuggingFacePostTrainingImpl:
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
job = self._scheduler.get_job(job_uuid)
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
@webmethod(route="/post-training/jobs", method="GET")
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
return ListPostTrainingJobsResponse(
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
)
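A rough sketch of driving this new code path, assuming the method above is the post-training API's preference_optimize endpoint on a constructed HuggingFacePostTrainingImpl; the model and dataset identifiers are placeholders, and required config fields not visible in this diff are elided:

from llama_stack.apis.post_training import DataConfig, DPOAlignmentConfig, TrainingConfig

# Inside an async function, with `impl` a HuggingFacePostTrainingImpl instance:
job = await impl.preference_optimize(
    job_uuid="dpo-job-001",
    finetuned_model="meta-llama/Llama-3.2-1B-Instruct",  # placeholder model id
    algorithm_config=DPOAlignmentConfig(beta=0.1),       # beta is read in setup_training_args
    training_config=TrainingConfig(
        n_epochs=1,
        data_config=DataConfig(dataset_id="my-preference-dataset", batch_size=1, shuffle=False),
    ),
    hyperparam_search_config={},
    logger_config={},
)
print(job.job_uuid)  # the scheduler reuses the supplied job_uuid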


@@ -8,30 +8,13 @@ import gc
import json
import logging
import multiprocessing
import os
import signal
import sys
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
import psutil
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
# Set tokenizer parallelism environment variable
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Force PyTorch to use OpenBLAS instead of MKL
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
os.environ["MKL_NUM_THREADS"] = "1"
import torch
from datasets import Dataset
from peft import LoraConfig
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
)
@@ -45,93 +28,25 @@ from llama_stack.apis.post_training import (
LoraFinetuningConfig,
TrainingConfig,
)
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from ..config import HuggingFacePostTrainingConfig
from ..utils import (
calculate_training_steps,
create_checkpoints,
get_memory_stats,
get_save_strategy,
load_model,
load_rows_from_dataset,
setup_environment,
setup_signal_handlers,
setup_torch_device,
split_dataset,
)
logger = logging.getLogger(__name__)
def get_gb(to_convert: int) -> str:
"""Converts memory stats to GB and formats to 2 decimal places.
Args:
to_convert: Memory value in bytes
Returns:
str: Memory value in GB formatted to 2 decimal places
"""
return f"{(to_convert / (1024**3)):.2f}"
def get_memory_stats(device: torch.device) -> dict[str, Any]:
"""Get memory statistics for the given device."""
stats = {
"system_memory": {
"total": get_gb(psutil.virtual_memory().total),
"available": get_gb(psutil.virtual_memory().available),
"used": get_gb(psutil.virtual_memory().used),
"percent": psutil.virtual_memory().percent,
}
}
if device.type == "cuda":
stats["device_memory"] = {
"allocated": get_gb(torch.cuda.memory_allocated(device)),
"reserved": get_gb(torch.cuda.memory_reserved(device)),
"max_allocated": get_gb(torch.cuda.max_memory_allocated(device)),
}
elif device.type == "mps":
# MPS doesn't provide direct memory stats, but we can track system memory
stats["device_memory"] = {
"note": "MPS memory stats not directly available",
"system_memory_used": get_gb(psutil.virtual_memory().used),
}
elif device.type == "cpu":
# For CPU, we track process memory usage
process = psutil.Process()
stats["device_memory"] = {
"process_rss": get_gb(process.memory_info().rss),
"process_vms": get_gb(process.memory_info().vms),
"process_percent": process.memory_percent(),
}
return stats
def setup_torch_device(device_str: str) -> torch.device:
"""Initialize and validate a PyTorch device.
This function handles device initialization and validation for different device types:
- CUDA: Validates CUDA availability and handles device selection
- MPS: Validates MPS availability for Apple Silicon
- CPU: Basic validation
- HPU: Raises error as it's not supported
Args:
device_str: String specifying the device ('cuda', 'cpu', 'mps')
Returns:
torch.device: The initialized and validated device
Raises:
RuntimeError: If device initialization fails or device is not supported
"""
try:
device = torch.device(device_str)
except RuntimeError as e:
raise RuntimeError(f"Error getting Torch Device {str(e)}") from e
# Validate device capabilities
if device.type == "cuda":
if not torch.cuda.is_available():
raise RuntimeError(
f"{device.type}: Torch has no CUDA/ROCm support or could not detect a compatible device."
)
if device.index is None:
device = torch.device(device.type, torch.cuda.current_device())
elif device.type == "mps":
if not torch.backends.mps.is_available():
raise RuntimeError(f"{device.type}: Torch has no MPS support or could not detect a compatible device.")
elif device.type == "hpu":
raise RuntimeError(f"{device.type}: training does not support Intel Gaudi.")
return device
class HFFinetuningSingleDevice:
def __init__(
self,
@@ -262,19 +177,6 @@ class HFFinetuningSingleDevice:
remove_columns=ds.column_names,
)
async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]:
"""Load dataset from llama stack dataset provider"""
try:
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
limit=-1,
)
if not isinstance(all_rows.data, list):
raise RuntimeError("Expected dataset data to be a list")
return all_rows.data
except Exception as e:
raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
def _run_training_sync(
self,
model: str,
@@ -327,7 +229,7 @@
# Load dataset
logger.info(f"Loading dataset: {config.data_config.dataset_id}")
rows = await self._setup_data(config.data_config.dataset_id)
rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id)
if not self.validate_dataset_format(rows):
raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input")
logger.info(f"Loaded {len(rows)} rows from dataset")
@@ -369,47 +271,10 @@
raise ValueError(f"Failed to create dataset: {str(e)}") from e
# Split dataset
logger.info("Splitting dataset into train and validation sets")
train_val_split = ds.train_test_split(test_size=0.1, seed=42)
train_dataset = train_val_split["train"]
eval_dataset = train_val_split["test"]
logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples")
train_dataset, eval_dataset = split_dataset(ds)
return train_dataset, eval_dataset, tokenizer
def load_model(
self,
model: str,
device: torch.device,
provider_config: HuggingFacePostTrainingConfig,
) -> AutoModelForCausalLM:
"""Load and initialize the model for training.
Args:
model: The model identifier to load
device: The device to load the model onto
provider_config: Provider-specific configuration
Returns:
The loaded and initialized model
Raises:
RuntimeError: If model loading fails
"""
logger.info("Loading the base model")
try:
model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
model_obj = AutoModelForCausalLM.from_pretrained(
model,
torch_dtype="auto" if device.type != "cpu" else "float32",
quantization_config=None,
config=model_config,
**provider_config.model_specific_config,
)
# Always move model to specified device
model_obj = model_obj.to(device)
logger.info(f"Model loaded and moved to device: {model_obj.device}")
return model_obj
except Exception as e:
raise RuntimeError(f"Failed to load model: {str(e)}") from e
def setup_training_args(
self,
config: TrainingConfig,
@@ -439,27 +304,12 @@
raise ValueError("DataConfig is required for training")
data_config = config.data_config
# Calculate steps
total_steps = steps_per_epoch * config.n_epochs
max_steps = min(config.max_steps_per_epoch, total_steps)
logging_steps = max(1, steps_per_epoch // 50) # Log 50 times per epoch
logger.info("Training configuration:")
logger.info(f"- Steps per epoch: {steps_per_epoch}")
logger.info(f"- Total steps: {total_steps}")
logger.info(f"- Max steps: {max_steps}")
logger.info(f"- Logging steps: {logging_steps}")
# Configure save strategy
save_strategy = "no"
eval_strategy = "no"
if output_dir_path:
save_strategy = "epoch"
eval_strategy = "epoch"
logger.info(f"Will save checkpoints to {output_dir_path}")
# Calculate steps and get save strategy
step_info = calculate_training_steps(steps_per_epoch, config)
save_strategy, eval_strategy = get_save_strategy(output_dir_path)
return SFTConfig(
max_steps=max_steps,
max_steps=step_info["max_steps"],
output_dir=str(output_dir_path) if output_dir_path is not None else None,
num_train_epochs=config.n_epochs,
per_device_train_batch_size=data_config.batch_size,
@@ -483,7 +333,7 @@
load_best_model_at_end=True if output_dir_path else False,
metric_for_best_model="eval_loss",
greater_is_better=False,
logging_steps=logging_steps,
logging_steps=step_info["logging_steps"],
)
def save_model(
@@ -523,13 +373,11 @@
) -> None:
"""Run the training process with signal handling."""
def signal_handler(signum, frame):
"""Handle termination signals gracefully."""
logger.info(f"Received signal {signum}, initiating graceful shutdown")
sys.exit(0)
# Setup environment variables
setup_environment()
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
# Setup signal handlers
setup_signal_handlers()
# Convert config dicts back to objects
logger.info("Initializing configuration objects")
@@ -558,7 +406,7 @@
)
# Load model
model_obj = self.load_model(model, device, provider_config_obj)
model_obj = load_model(model, device, provider_config_obj)
# Initialize trainer
logger.info("Initializing SFTTrainer")
@@ -633,7 +481,7 @@
# Train in a separate process
logger.info("Starting training in separate process")
try:
# Set multiprocessing start method to 'spawn' for CUDA/MPS compatibility
# Setup multiprocessing for device
if device.type in ["cuda", "mps"]:
multiprocessing.set_start_method("spawn", force=True)
@@ -663,37 +511,7 @@
checkpoints = []
if output_dir_path:
# Get all checkpoint directories and sort them numerically
checkpoint_dirs = sorted(
[d for d in output_dir_path.glob("checkpoint-*") if d.is_dir()],
key=lambda x: int(x.name.split("-")[1]),
)
# Add all checkpoint directories
for epoch_number, checkpoint_dir in enumerate(checkpoint_dirs, start=1):
# Get the creation time of the directory
created_time = datetime.fromtimestamp(os.path.getctime(checkpoint_dir), tz=UTC)
checkpoint = Checkpoint(
identifier=checkpoint_dir.name,
created_at=created_time,
epoch=epoch_number,
post_training_job_id=job_uuid,
path=str(checkpoint_dir),
)
checkpoints.append(checkpoint)
# Add the merged model as a checkpoint
merged_model_path = output_dir_path / "merged_model"
if merged_model_path.exists():
checkpoint = Checkpoint(
identifier=f"{model}-sft-{config.n_epochs}",
created_at=datetime.now(UTC),
epoch=config.n_epochs,
post_training_job_id=job_uuid,
path=str(merged_model_path),
)
checkpoints.append(checkpoint)
checkpoints = create_checkpoints(output_dir_path, job_uuid, model, config, "merged_model")
return memory_stats, checkpoints if checkpoints else None
finally:
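The checkpoint-collection logic deleted above now lives in create_checkpoints, defined in the new shared utils module later in this diff. As a sketch of the behavior it preserves, assuming the output directory holds checkpoint-100, checkpoint-200, and merged_model:

checkpoints = create_checkpoints(output_dir_path, job_uuid, model, config, "merged_model")
# Yields, in order:
#   Checkpoint(identifier="checkpoint-100", epoch=1, path=".../checkpoint-100", ...)
#   Checkpoint(identifier="checkpoint-200", epoch=2, path=".../checkpoint-200", ...)
#   Checkpoint(identifier=f"{model}-sft-{config.n_epochs}", epoch=config.n_epochs,
#              path=".../merged_model", ...)  # "sft" because final_model_name == "merged_model"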


@@ -0,0 +1,485 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import gc
import logging
import multiprocessing
from pathlib import Path
from typing import Any
import torch
from datasets import Dataset
from transformers import (
AutoTokenizer,
)
from trl import DPOConfig, DPOTrainer
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
Checkpoint,
DPOAlignmentConfig,
TrainingConfig,
)
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from ..config import HuggingFacePostTrainingConfig
from ..utils import (
calculate_training_steps,
create_checkpoints,
get_memory_stats,
get_save_strategy,
load_model,
load_rows_from_dataset,
setup_environment,
setup_signal_handlers,
setup_torch_device,
split_dataset,
)
logger = logging.getLogger(__name__)
class HFDPOAlignmentSingleDevice:
def __init__(
self,
job_uuid: str,
datasetio_api: DatasetIO,
datasets_api: Datasets,
):
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.job_uuid = job_uuid
def validate_dataset_format(self, rows: list[dict]) -> None:
"""Validate that the dataset has the required fields for DPO training."""
required_fields = ["prompt", "chosen", "rejected"]
if not rows:
logger.warning("Dataset is empty")
raise ValueError("Dataset is empty")
for i, row in enumerate(rows):
if not isinstance(row, dict):
logger.warning(f"Row {i} is not a dictionary")
raise ValueError(f"Row {i} is not a dictionary")
for field in required_fields:
if field not in row:
logger.warning(f"Row {i} missing required DPO field: {field}")
raise ValueError(f"Row {i} missing required DPO field: {field}")
# Handle both string and list formats
if field == "prompt":
# Prompt should be a string
if not isinstance(row[field], str):
logger.warning(f"Row {i} field '{field}' is not a string")
raise ValueError(f"Row {i} field '{field}' is not a string")
if not row[field].strip():
logger.warning(f"Row {i} field '{field}' is empty")
raise ValueError(f"Row {i} field '{field}' is empty")
else:
# chosen/rejected can be either strings or lists of messages
if isinstance(row[field], str):
if not row[field].strip():
logger.warning(f"Row {i} field '{field}' is empty")
raise ValueError(f"Row {i} field '{field}' is empty")
elif isinstance(row[field], list):
if not row[field]:
logger.warning(f"Row {i} field '{field}' is empty list")
raise ValueError(f"Row {i} field '{field}' is empty list")
else:
logger.warning(f"Row {i} field '{field}' is neither string nor list")
raise ValueError(f"Row {i} field '{field}' is neither string nor list")
logger.info(f"DPO dataset validation passed: {len(rows)} preference examples")
def _process_dpo_format(self, row: dict) -> tuple[str | None, str | None, str | None]:
"""Process a row in DPO format, handling both string and conversation list formats."""
if all(field in row for field in ["prompt", "chosen", "rejected"]):
prompt = row["prompt"]
# Handle chosen field - convert list to string if needed
if isinstance(row["chosen"], list):
# For conversation format, concatenate messages
chosen = "\n".join(
[msg.get("content", "") if isinstance(msg, dict) else str(msg) for msg in row["chosen"]]
)
else:
chosen = row["chosen"]
# Handle rejected field - convert list to string if needed
if isinstance(row["rejected"], list):
# For conversation format, concatenate messages
rejected = "\n".join(
[msg.get("content", "") if isinstance(msg, dict) else str(msg) for msg in row["rejected"]]
)
else:
rejected = row["rejected"]
return prompt, chosen, rejected
return None, None, None
def _format_text_for_dpo(self, prompt: str, response: str, provider_config: HuggingFacePostTrainingConfig) -> str:
"""Format prompt and response text based on model requirements."""
if hasattr(provider_config, "chat_template") and provider_config.chat_template:
# Use the chat template, supporting both {prompt}/{response} and {input}/{output}
template = provider_config.chat_template
# Try prompt/response first (DPO style)
if "{prompt}" in template and "{response}" in template:
return template.format(prompt=prompt, response=response)
# Fall back to input/output (SFT style)
elif "{input}" in template and "{output}" in template:
return template.format(input=prompt, output=response)
else:
# If template doesn't have expected placeholders, use default
return f"{prompt}\n{response}"
return f"{prompt}\n{response}"
def _create_dataset(
self, rows: list[dict], config: TrainingConfig, provider_config: HuggingFacePostTrainingConfig
) -> Dataset:
"""Create and preprocess the dataset for DPO."""
dpo_examples = []
for row in rows:
prompt, chosen, rejected = self._process_dpo_format(row)
if prompt and chosen and rejected:
# Format the texts
chosen_formatted = self._format_text_for_dpo(prompt, chosen, provider_config)
rejected_formatted = self._format_text_for_dpo(prompt, rejected, provider_config)
dpo_examples.append(
{
"prompt": prompt,
"chosen": chosen_formatted,
"rejected": rejected_formatted,
}
)
if not dpo_examples:
raise ValueError("No valid preference examples found in dataset")
logger.info(f"Created DPO dataset with {len(dpo_examples)} preference pairs")
return Dataset.from_list(dpo_examples)
def _preprocess_dataset(
self, ds: Dataset, tokenizer: AutoTokenizer, provider_config: HuggingFacePostTrainingConfig
) -> Dataset:
"""Preprocess the dataset with tokenizer for DPO."""
# DPOTrainer expects raw text, so we don't tokenize here
# Just return the dataset as is
return ds
def _run_training_sync(
self,
model: str,
provider_config: dict[str, Any],
dpo_config: dict[str, Any],
config: dict[str, Any],
output_dir_path: Path | None,
) -> None:
"""Synchronous wrapper for running DPO training process."""
import asyncio
logger.info("Starting DPO training process with async wrapper")
asyncio.run(
self._run_training(
model=model,
provider_config=provider_config,
dpo_config=dpo_config,
config=config,
output_dir_path=output_dir_path,
)
)
async def load_dataset(
self,
model: str,
config: TrainingConfig,
provider_config: HuggingFacePostTrainingConfig,
) -> tuple[Dataset, Dataset, AutoTokenizer]:
"""Load and prepare the dataset for DPO training."""
# Validate data config
if not config.data_config:
raise ValueError("DataConfig is required for DPO training")
# Load dataset
logger.info(f"Loading dataset: {config.data_config.dataset_id}")
rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id)
self.validate_dataset_format(rows)
logger.info(f"Loaded {len(rows)} rows from dataset")
# Initialize tokenizer
logger.info(f"Initializing tokenizer for model: {model}")
try:
tokenizer = AutoTokenizer.from_pretrained(model, **provider_config.model_specific_config)
# Set pad token to eos token if not present
if not tokenizer.pad_token:
tokenizer.pad_token = tokenizer.eos_token
# Set padding side to left for DPO
tokenizer.padding_side = "left"
# Set truncation side to right to keep the beginning of the sequence
tokenizer.truncation_side = "right"
# Set model max length to match provider config
tokenizer.model_max_length = provider_config.max_seq_length
logger.info("Tokenizer initialized successfully for DPO")
except Exception as e:
raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") from e
# Create and preprocess dataset
logger.info("Creating and preprocessing dataset for DPO")
try:
ds = self._create_dataset(rows, config, provider_config)
ds = self._preprocess_dataset(ds, tokenizer, provider_config)
logger.info(f"Dataset created with {len(ds)} examples")
except Exception as e:
raise ValueError(f"Failed to create dataset: {str(e)}") from e
# Split dataset
train_dataset, eval_dataset = split_dataset(ds)
return train_dataset, eval_dataset, tokenizer
def setup_training_args(
self,
config: TrainingConfig,
provider_config: HuggingFacePostTrainingConfig,
dpo_config: DPOAlignmentConfig,
device: torch.device,
output_dir_path: Path | None,
steps_per_epoch: int,
) -> DPOConfig:
"""Setup DPO training arguments."""
logger.info("Configuring DPO training arguments")
lr = 5e-7 # Lower learning rate for DPO
if config.optimizer_config:
lr = config.optimizer_config.lr
logger.info(f"Using custom learning rate: {lr}")
# Validate data config
if not config.data_config:
raise ValueError("DataConfig is required for training")
data_config = config.data_config
# Calculate steps and get save strategy
step_info = calculate_training_steps(steps_per_epoch, config)
save_strategy, eval_strategy = get_save_strategy(output_dir_path)
logger.info("DPO training configuration:")
logger.info(f"- DPO beta: {dpo_config.beta}")
logger.info(f"- DPO loss type: {provider_config.dpo_loss_type}")
# Calculate max prompt length as half of max sequence length
max_prompt_length = provider_config.max_seq_length // 2
return DPOConfig(
max_steps=step_info["max_steps"],
output_dir=str(output_dir_path) if output_dir_path is not None else None,
num_train_epochs=config.n_epochs,
per_device_train_batch_size=data_config.batch_size,
fp16=device.type == "cuda",
bf16=False, # Causes CPU issues.
eval_strategy=eval_strategy,
use_cpu=True if device.type == "cpu" and not torch.backends.mps.is_available() else False,
save_strategy=save_strategy,
report_to="none",
max_length=provider_config.max_seq_length,
max_prompt_length=max_prompt_length,
gradient_accumulation_steps=config.gradient_accumulation_steps,
gradient_checkpointing=provider_config.gradient_checkpointing,
learning_rate=lr,
warmup_ratio=provider_config.warmup_ratio,
weight_decay=provider_config.weight_decay,
remove_unused_columns=False,
dataloader_pin_memory=provider_config.dataloader_pin_memory,
dataloader_num_workers=provider_config.dataloader_num_workers,
load_best_model_at_end=True if output_dir_path else False,
metric_for_best_model="eval_loss",
greater_is_better=False,
logging_steps=step_info["logging_steps"],
save_total_limit=provider_config.save_total_limit,
# DPO specific parameters
beta=dpo_config.beta,
loss_type=provider_config.dpo_loss_type,
)
def save_model(
self,
trainer: DPOTrainer,
output_dir_path: Path,
) -> None:
"""Save the trained DPO model."""
logger.info("Saving final DPO model")
save_path = output_dir_path / "dpo_model"
logger.info(f"Saving model to {save_path}")
# Save model and tokenizer
trainer.save_model(str(save_path))
async def _run_training(
self,
model: str,
provider_config: dict[str, Any],
dpo_config: dict[str, Any],
config: dict[str, Any],
output_dir_path: Path | None,
) -> None:
"""Run the DPO training process with signal handling."""
# Setup environment variables
setup_environment()
# Setup signal handlers
setup_signal_handlers()
# Convert config dicts back to objects
logger.info("Initializing configuration objects")
provider_config_obj = HuggingFacePostTrainingConfig(**provider_config)
config_obj = TrainingConfig(**config)
dpo_config_obj = DPOAlignmentConfig(**dpo_config)
# Initialize and validate device
device = setup_torch_device(provider_config_obj.device)
logger.info(f"Using device '{device}'")
# Load dataset and tokenizer
train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj)
# Calculate steps per epoch
if not config_obj.data_config:
raise ValueError("DataConfig is required for training")
steps_per_epoch = len(train_dataset) // config_obj.data_config.batch_size
# Setup training arguments
training_args = self.setup_training_args(
config_obj,
provider_config_obj,
dpo_config_obj,
device,
output_dir_path,
steps_per_epoch,
)
# Load model and reference model
model_obj = load_model(model, device, provider_config_obj)
ref_model = None
if provider_config_obj.use_reference_model:
logger.info("Loading separate reference model for DPO")
ref_model = load_model(model, device, provider_config_obj)
else:
logger.info("Using shared reference model for DPO")
# Initialize DPO trainer
logger.info("Initializing DPOTrainer")
trainer = DPOTrainer(
model=model_obj,
ref_model=ref_model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=tokenizer,
)
try:
# Train
logger.info("Starting DPO training")
trainer.train()
logger.info("DPO training completed successfully")
# Save final model if output directory is provided
if output_dir_path:
logger.info(f"Saving model to output directory: {output_dir_path}")
self.save_model(trainer, output_dir_path)
logger.info("Model save completed")
finally:
# Clean up resources
logger.info("Cleaning up resources")
if hasattr(trainer, "model"):
evacuate_model_from_device(trainer.model, device.type)
if ref_model:
evacuate_model_from_device(ref_model, device.type)
del trainer
del ref_model
gc.collect()
logger.info("Cleanup completed")
logger.info("DPO training process finishing successfully")
async def train(
self,
model: str,
output_dir: str | None,
job_uuid: str,
dpo_config: DPOAlignmentConfig,
config: TrainingConfig,
provider_config: HuggingFacePostTrainingConfig,
) -> tuple[dict[str, Any], list[Checkpoint] | None]:
"""Train a model using HuggingFace's DPOTrainer"""
# Initialize and validate device
device = setup_torch_device(provider_config.device)
logger.info(f"Using device '{device}'")
output_dir_path = None
if output_dir:
output_dir_path = Path(output_dir)
# Track memory stats
memory_stats = {
"initial": get_memory_stats(device),
"after_training": None,
"final": None,
}
# Validate data config
if not config.data_config:
raise ValueError("DataConfig is required for training")
# Train in a separate process
logger.info("Starting DPO training in separate process")
try:
# Setup multiprocessing for device
if device.type in ["cuda", "mps"]:
multiprocessing.set_start_method("spawn", force=True)
process = multiprocessing.Process(
target=self._run_training_sync,
kwargs={
"model": model,
"provider_config": provider_config.model_dump(),
"dpo_config": dpo_config.model_dump(),
"config": config.model_dump(),
"output_dir_path": output_dir_path,
},
)
process.start()
# Monitor the process
while process.is_alive():
process.join(timeout=1) # Check every second
if not process.is_alive():
break
# Get the return code
if process.exitcode != 0:
raise RuntimeError(f"DPO training failed with exit code {process.exitcode}")
memory_stats["after_training"] = get_memory_stats(device)
checkpoints = []
if output_dir_path:
checkpoints = create_checkpoints(output_dir_path, job_uuid, model, config, "dpo_model")
return memory_stats, checkpoints if checkpoints else None
finally:
memory_stats["final"] = get_memory_stats(device)
gc.collect()
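For reference, preference rows that satisfy the recipe's validate_dataset_format above take either plain-string or conversation-list form; the rows below are illustrative, not from the PR:

rows = [
    # String form: prompt, chosen, and rejected are all non-empty strings.
    {
        "prompt": "What is the capital of France?",
        "chosen": "The capital of France is Paris.",
        "rejected": "France has no capital.",
    },
    # List form: chosen/rejected are message lists; _process_dpo_format joins
    # each message's "content" with newlines before formatting.
    {
        "prompt": "Summarize the meeting notes.",
        "chosen": [{"role": "assistant", "content": "A faithful summary."}],
        "rejected": [{"role": "assistant", "content": "An unrelated answer."}],
    },
]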


@@ -0,0 +1,269 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
import signal
import sys
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
import psutil
import torch
from datasets import Dataset
from transformers import AutoConfig, AutoModelForCausalLM
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.post_training import Checkpoint, TrainingConfig
from .config import HuggingFacePostTrainingConfig
logger = logging.getLogger(__name__)
def setup_environment():
"""Setup common environment variables for training."""
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["MKL_THREADING_LAYER"] = "GNU"
os.environ["MKL_SERVICE_FORCE_INTEL"] = "0"
os.environ["MKL_NUM_THREADS"] = "1"
def bytes_to_gb(to_convert: int) -> str:
"""Converts memory stats to GB and formats to 2 decimal places.
Args:
to_convert: Memory value in bytes
Returns:
str: Memory value in GB formatted to 2 decimal places
"""
return f"{(to_convert / (1024**3)):.2f}"
def get_memory_stats(device: torch.device) -> dict[str, Any]:
"""Get memory statistics for the given device."""
stats = {
"system_memory": {
"total": bytes_to_gb(psutil.virtual_memory().total),
"available": bytes_to_gb(psutil.virtual_memory().available),
"used": bytes_to_gb(psutil.virtual_memory().used),
"percent": psutil.virtual_memory().percent,
}
}
if device.type == "cuda":
stats["device_memory"] = {
"allocated": bytes_to_gb(torch.cuda.memory_allocated(device)),
"reserved": bytes_to_gb(torch.cuda.memory_reserved(device)),
"max_allocated": bytes_to_gb(torch.cuda.max_memory_allocated(device)),
}
elif device.type == "mps":
# MPS doesn't provide direct memory stats, but we can track system memory
stats["device_memory"] = {
"note": "MPS memory stats not directly available",
"system_memory_used": bytes_to_gb(psutil.virtual_memory().used),
}
elif device.type == "cpu":
# For CPU, we track process memory usage
process = psutil.Process()
stats["device_memory"] = {
"process_rss": bytes_to_gb(process.memory_info().rss),
"process_vms": bytes_to_gb(process.memory_info().vms),
"process_percent": process.memory_percent(),
}
return stats
def setup_torch_device(device_str: str) -> torch.device:
"""Initialize and validate a PyTorch device.
This function handles device initialization and validation for different device types:
- CUDA: Validates CUDA availability and handles device selection
- MPS: Validates MPS availability for Apple Silicon
- CPU: Basic validation
- HPU: Raises error as it's not supported
Args:
device_str: String specifying the device ('cuda', 'cpu', 'mps')
Returns:
torch.device: The initialized and validated device
Raises:
RuntimeError: If device initialization fails or device is not supported
"""
try:
device = torch.device(device_str)
except RuntimeError as e:
raise RuntimeError(f"Error getting Torch Device {str(e)}") from e
# Validate device capabilities
if device.type == "cuda":
if not torch.cuda.is_available():
raise RuntimeError(
f"{device.type}: Torch has no CUDA/ROCm support or could not detect a compatible device."
)
if device.index is None:
device = torch.device(device.type, torch.cuda.current_device())
elif device.type == "mps":
if not torch.backends.mps.is_available():
raise RuntimeError(f"{device.type}: Torch has no MPS support or could not detect a compatible device.")
elif device.type == "hpu":
raise RuntimeError(f"{device.type}: training does not support Intel Gaudi.")
return device
async def load_rows_from_dataset(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
"""Load dataset from llama stack dataset provider"""
try:
all_rows = await datasetio_api.iterrows(
dataset_id=dataset_id,
limit=-1,
)
if not isinstance(all_rows.data, list):
raise RuntimeError("Expected dataset data to be a list")
return all_rows.data
except Exception as e:
raise RuntimeError(f"Failed to load dataset: {str(e)}") from e
def load_model(
model: str,
device: torch.device,
provider_config: HuggingFacePostTrainingConfig,
) -> AutoModelForCausalLM:
"""Load and initialize the model for training.
Args:
model: The model identifier to load
device: The device to load the model onto
provider_config: Provider-specific configuration
Returns:
The loaded and initialized model
Raises:
RuntimeError: If model loading fails
"""
logger.info("Loading the base model")
try:
model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
model_obj = AutoModelForCausalLM.from_pretrained(
model,
torch_dtype="auto" if device.type != "cpu" else "float32",
quantization_config=None,
config=model_config,
**provider_config.model_specific_config,
)
# Always move model to specified device
model_obj = model_obj.to(device)
logger.info(f"Model loaded and moved to device: {model_obj.device}")
return model_obj
except Exception as e:
raise RuntimeError(f"Failed to load model: {str(e)}") from e
def split_dataset(ds: Dataset) -> tuple[Dataset, Dataset]:
"""Split dataset into train and validation sets.
Args:
ds: Dataset to split
Returns:
tuple: (train_dataset, eval_dataset)
"""
logger.info("Splitting dataset into train and validation sets")
train_val_split = ds.train_test_split(test_size=0.1, seed=42)
train_dataset = train_val_split["train"]
eval_dataset = train_val_split["test"]
logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples")
return train_dataset, eval_dataset
def setup_signal_handlers():
"""Setup signal handlers for graceful shutdown."""
def signal_handler(signum, frame):
logger.info(f"Received signal {signum}, initiating graceful shutdown")
sys.exit(0)
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
def calculate_training_steps(steps_per_epoch: int, config: TrainingConfig) -> dict[str, int]:
"""Calculate training steps and logging configuration.
Args:
steps_per_epoch: Number of training steps per epoch
config: Training configuration
Returns:
dict: Dictionary with calculated step values
"""
total_steps = steps_per_epoch * config.n_epochs
max_steps = min(config.max_steps_per_epoch, total_steps)
logging_steps = max(1, steps_per_epoch // 50) # Log 50 times per epoch
logger.info("Training configuration:")
logger.info(f"- Steps per epoch: {steps_per_epoch}")
logger.info(f"- Total steps: {total_steps}")
logger.info(f"- Max steps: {max_steps}")
logger.info(f"- Logging steps: {logging_steps}")
return {"total_steps": total_steps, "max_steps": max_steps, "logging_steps": logging_steps}
def get_save_strategy(output_dir_path: Path | None) -> tuple[str, str]:
"""Get save and evaluation strategy based on output directory.
Args:
output_dir_path: Optional path to save the model
Returns:
tuple: (save_strategy, eval_strategy)
"""
if output_dir_path:
logger.info(f"Will save checkpoints to {output_dir_path}")
return "epoch", "epoch"
return "no", "no"
def create_checkpoints(
output_dir_path: Path, job_uuid: str, model: str, config: TrainingConfig, final_model_name: str
) -> list[Checkpoint]:
"""Create checkpoint objects from training output.
Args:
output_dir_path: Path to the training output directory
job_uuid: Unique identifier for the training job
model: Model identifier
config: Training configuration
final_model_name: Name of the final model directory ("merged_model" for SFT, "dpo_model" for DPO)
Returns:
List of Checkpoint objects
"""
checkpoints = []
# Add checkpoint directories
checkpoint_dirs = sorted(
[d for d in output_dir_path.glob("checkpoint-*") if d.is_dir()],
key=lambda x: int(x.name.split("-")[1]),
)
for epoch_number, checkpoint_dir in enumerate(checkpoint_dirs, start=1):
created_time = datetime.fromtimestamp(os.path.getctime(checkpoint_dir), tz=UTC)
checkpoint = Checkpoint(
identifier=checkpoint_dir.name,
created_at=created_time,
epoch=epoch_number,
post_training_job_id=job_uuid,
path=str(checkpoint_dir),
)
checkpoints.append(checkpoint)
# Add final model
final_model_path = output_dir_path / final_model_name
if final_model_path.exists():
training_type = "sft" if final_model_name == "merged_model" else "dpo"
checkpoint = Checkpoint(
identifier=f"{model}-{training_type}-{config.n_epochs}",
created_at=datetime.now(UTC),
epoch=config.n_epochs,
post_training_job_id=job_uuid,
path=str(final_model_path),
)
checkpoints.append(checkpoint)
return checkpoints
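As a worked example of calculate_training_steps, with assumed values steps_per_epoch=100, config.n_epochs=3, and config.max_steps_per_epoch=120:

step_info = calculate_training_steps(100, config)
# total_steps   = 100 * 3           = 300
# max_steps     = min(120, 300)     = 120  -> training stops after 120 optimizer steps
# logging_steps = max(1, 100 // 50) = 2    -> ~50 log lines per epoch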


@@ -6,7 +6,7 @@
from typing import Any
from llama_stack.distribution.datatypes import Api
from llama_stack.core.datatypes import Api
from .config import TorchtunePostTrainingConfig


@@ -43,8 +43,8 @@ from llama_stack.apis.post_training import (
QATFinetuningConfig,
TrainingConfig,
)
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.core.utils.model_utils import model_local_dir
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from llama_stack.providers.inline.post_training.torchtune.common import utils


@@ -21,7 +21,7 @@ from llama_stack.apis.safety import (
ViolationLevel,
)
from llama_stack.apis.shields import Shield
from llama_stack.distribution.datatypes import Api
from llama_stack.core.datatypes import Api
from llama_stack.models.llama.datatypes import Role
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.datatypes import ShieldsProtocolPrivate


@@ -18,7 +18,7 @@ from llama_stack.apis.safety import (
ViolationLevel,
)
from llama_stack.apis.shields import Shield
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.core.utils.model_utils import model_local_dir
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,


@@ -5,7 +5,7 @@
# the root directory of this source tree.
from typing import Any
from llama_stack.distribution.datatypes import Api
from llama_stack.core.datatypes import Api
from .config import BasicScoringConfig


@@ -14,7 +14,7 @@ from llama_stack.apis.scoring import (
ScoringResult,
)
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.distribution.datatypes import Api
from llama_stack.core.datatypes import Api
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
get_valid_schemas,


@@ -7,7 +7,7 @@ from typing import Any
from pydantic import BaseModel
from llama_stack.distribution.datatypes import Api
from llama_stack.core.datatypes import Api
from .config import BraintrustScoringConfig


@@ -29,8 +29,8 @@ from llama_stack.apis.scoring import (
ScoringResultRow,
)
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.core.datatypes import Api
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
get_valid_schemas,


@@ -5,7 +5,7 @@
# the root directory of this source tree.
from typing import Any
from llama_stack.distribution.datatypes import Api
from llama_stack.core.datatypes import Api
from .config import LlmAsJudgeScoringConfig


@@ -15,7 +15,7 @@ from llama_stack.apis.scoring import (
ScoringResult,
)
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.distribution.datatypes import Api
from llama_stack.core.datatypes import Api
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
get_valid_schemas,


@@ -6,7 +6,7 @@
from typing import Any
from llama_stack.distribution.datatypes import Api
from llama_stack.core.datatypes import Api
from .config import TelemetryConfig, TelemetrySink


@@ -9,7 +9,7 @@ from typing import Any
from pydantic import BaseModel, Field, field_validator
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
class TelemetrySink(StrEnum):


@@ -36,7 +36,7 @@ from llama_stack.apis.telemetry import (
Trace,
UnstructuredLogEvent,
)
from llama_stack.distribution.datatypes import Api
from llama_stack.core.datatypes import Api
from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
ConsoleSpanProcessor,
)


@@ -34,7 +34,7 @@ os.environ["NVIDIA_API_KEY"] = "your-api-key"
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
os.environ["NVIDIA_DATASET_NAMESPACE"] = "default"
os.environ["NVIDIA_PROJECT_ID"] = "test-project"
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from llama_stack.core.library_client import LlamaStackAsLibraryClient
client = LlamaStackAsLibraryClient("nvidia")
client.initialize()


@@ -5,7 +5,7 @@
# the root directory of this source tree.
from typing import Any
from llama_stack.distribution.datatypes import Api
from llama_stack.core.datatypes import Api
from .config import NVIDIAEvalConfig


@@ -39,7 +39,7 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,


@@ -33,7 +33,7 @@ os.environ["NVIDIA_API_KEY"] = (
)
os.environ["NVIDIA_BASE_URL"] = "http://nim.test" # NIM URL
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from llama_stack.core.library_client import LlamaStackAsLibraryClient
client = LlamaStackAsLibraryClient("nvidia")
client.initialize()


@@ -34,7 +34,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.distribution.library_client import convert_pydantic_to_json_value, convert_to_pydantic
from llama_stack.core.library_client import convert_pydantic_to_json_value, convert_to_pydantic
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params


@@ -38,7 +38,7 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (


@@ -40,7 +40,7 @@ os.environ["NVIDIA_DATASET_NAMESPACE"] = "default"
os.environ["NVIDIA_PROJECT_ID"] = "test-project"
os.environ["NVIDIA_OUTPUT_MODEL_DIR"] = "test-example-model@v1"
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from llama_stack.core.library_client import LlamaStackAsLibraryClient
client = LlamaStackAsLibraryClient("nvidia")
client.initialize()


@@ -32,7 +32,7 @@ import os
os.environ["NVIDIA_API_KEY"] = "your-api-key"
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://guardrails.test"
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from llama_stack.core.library_client import LlamaStackAsLibraryClient
client = LlamaStackAsLibraryClient("nvidia")
client.initialize()


@@ -19,7 +19,7 @@ from llama_stack.apis.safety import (
ViolationLevel,
)
from llama_stack.apis.shields import Shield
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new


@@ -18,7 +18,7 @@ from llama_stack.apis.tools import (
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from .config import BingSearchToolConfig


@@ -17,7 +17,7 @@ from llama_stack.apis.tools import (
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.models.llama.datatypes import BuiltinTool
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate


@@ -15,7 +15,7 @@ from llama_stack.apis.tools import (
ToolInvocationResult,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools


@@ -18,7 +18,7 @@ from llama_stack.apis.tools import (
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from .config import TavilySearchToolConfig


@@ -18,7 +18,7 @@ from llama_stack.apis.tools import (
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from .config import WolframAlphaToolConfig


@@ -18,7 +18,7 @@ from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files.files import Files
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore


@@ -12,7 +12,7 @@ from llama_stack.apis.common.type_system import (
CompletionInputType,
StringType,
)
from llama_stack.distribution.datatypes import Api
from llama_stack.core.datatypes import Api
class ColumnName(Enum):


@@ -10,8 +10,8 @@ from llama_stack.apis.inference import (
OpenAIMessageParam,
Order,
)
from llama_stack.distribution.datatypes import AccessRule
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
from ..sqlstore.api import ColumnDefinition, ColumnType
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore


@@ -38,7 +38,7 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (


@@ -10,7 +10,7 @@ from typing import Annotated, Literal
from pydantic import BaseModel, Field, field_validator
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
class KVStoreType(Enum):


@@ -14,8 +14,8 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseObject,
OpenAIResponseObjectWithInput,
)
from llama_stack.distribution.datatypes import AccessRule
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
from ..sqlstore.api import ColumnDefinition, ColumnType
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore


@@ -7,11 +7,11 @@
from collections.abc import Mapping
from typing import Any, Literal
from llama_stack.distribution.access_control.access_control import default_policy, is_action_allowed
from llama_stack.distribution.access_control.conditions import ProtectedResource
from llama_stack.distribution.access_control.datatypes import AccessRule, Action, Scope
from llama_stack.distribution.datatypes import User
from llama_stack.distribution.request_headers import get_authenticated_user
from llama_stack.core.access_control.access_control import default_policy, is_action_allowed
from llama_stack.core.access_control.conditions import ProtectedResource
from llama_stack.core.access_control.datatypes import AccessRule, Action, Scope
from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.log import get_logger
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore


@@ -11,7 +11,7 @@ from typing import Annotated, Literal
from pydantic import BaseModel, Field
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
from .api import SqlStore


@@ -22,7 +22,7 @@ from llama_stack.apis.tools import (
ToolInvocationResult,
ToolParameter,
)
from llama_stack.distribution.datatypes import AuthenticationRequiredError
from llama_stack.core.datatypes import AuthenticationRequiredError
from llama_stack.log import get_logger
from llama_stack.providers.utils.tools.ttl_dict import TTLDict