Update unit tests; add warnings for unsupported parameters

This commit is contained in:
Ubuntu 2025-03-12 14:17:26 +00:00
parent 2f619278a6
commit 0a5ca98198
5 changed files with 413 additions and 429 deletions

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import os
import warnings
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
@ -15,27 +16,27 @@ class NvidiaPostTrainingConfig(BaseModel):
api_key: Optional[str] = Field(
default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
description="The NVIDIA API key, only needed of using the hosted service",
description="The NVIDIA API key.",
)
user_id: Optional[str] = Field(
default_factory=lambda: os.getenv("NVIDIA_USER_ID", "llama-stack-user"),
description="The NVIDIA user ID, only needed of using the hosted service",
description="The NVIDIA user ID.",
)
dataset_namespace: Optional[str] = Field(
default_factory=lambda: os.getenv("NVIDIA_DATASET_NAMESPACE", "default"),
description="The NVIDIA dataset namespace, only needed of using the hosted service",
description="The NVIDIA dataset namespace.",
)
access_policies: Optional[dict] = Field(
default_factory=lambda: os.getenv("NVIDIA_ACCESS_POLICIES", {}),
description="The NVIDIA access policies, only needed of using the hosted service",
description="The NVIDIA access policies.",
)
project_id: Optional[str] = Field(
default_factory=lambda: os.getenv("NVIDIA_PROJECT_ID", "test-project"),
description="The NVIDIA project ID, only needed of using the hosted service",
description="The NVIDIA project ID.",
)
# ToDO: validate this, add default value
@ -54,11 +55,35 @@ class NvidiaPostTrainingConfig(BaseModel):
description="Maximum number of retries for the NVIDIA Post Training API",
)
# ToDo: validate this, add default value
output_model_dir: str = Field(
default_factory=lambda: os.getenv("NVIDIA_OUTPUT_MODEL_DIR", "test-example-model@v1"),
description="Directory to save the output model",
)
# warning for default values
def __post_init__(self):
default_values = []
if os.getenv("NVIDIA_OUTPUT_MODEL_DIR") is None:
default_values.append("output_model_dir='test-example-model@v1'")
if os.getenv("NVIDIA_PROJECT_ID") is None:
default_values.append("project_id='test-project'")
if os.getenv("NVIDIA_USER_ID") is None:
default_values.append("user_id='llama-stack-user'")
if os.getenv("NVIDIA_DATASET_NAMESPACE") is None:
default_values.append("dataset_namespace='default'")
if os.getenv("NVIDIA_ACCESS_POLICIES") is None:
default_values.append("access_policies='{}'")
if os.getenv("NVIDIA_CUSTOMIZER_URL") is None:
default_values.append("customizer_url='http://nemo.test'")
if default_values:
warnings.warn(
f"Using default values: {', '.join(default_values)}. \
Please set the environment variables to avoid this default behavior.",
stacklevel=2,
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {

View file

@ -3,6 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import warnings
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional
@ -190,6 +191,13 @@ class NvidiaPostTrainingAdapter:
job_uuid: str - Unique identifier for the job
hyperparam_search_config: Dict[str, Any] - Configuration for hyperparameter search
logger_config: Dict[str, Any] - Configuration for logging
Environment Variables:
- NVIDIA_PROJECT_ID: ID of the project
- NVIDIA_USER_ID: ID of the user
- NVIDIA_ACCESS_POLICIES: Access policies for the project
- NVIDIA_DATASET_NAMESPACE: Namespace of the dataset
- NVIDIA_OUTPUT_MODEL_DIR: Directory to save the output model
"""
# map model to nvidia model name
model_mapping = {
@ -198,9 +206,30 @@ class NvidiaPostTrainingAdapter:
}
nvidia_model = model_mapping.get(model, model)
# Get output model directory from config
# Check for unsupported parameters
if checkpoint_dir or hyperparam_search_config or logger_config:
warnings.warn(
"Parameters: {} not supported atm, will be ignored".format(
checkpoint_dir,
)
)
def warn_unsupported_params(config_dict: Dict[str, Any], supported_keys: List[str], config_name: str) -> None:
    """Emit a warning naming every key of *config_dict* that is not in *supported_keys*."""
    extras = []
    for key in config_dict:
        if key not in supported_keys:
            extras.append(key)
    if not extras:
        return
    warnings.warn(f"Parameters: {extras} in {config_name} not supported and will be ignored.")
# Check for unsupported parameters
warn_unsupported_params(training_config, ["n_epochs", "data_config", "optimizer_config"], "TrainingConfig")
warn_unsupported_params(training_config["data_config"], ["dataset_id", "batch_size"], "DataConfig")
warn_unsupported_params(training_config["optimizer_config"], ["lr"], "OptimizerConfig")
output_model = self.config.output_model_dir
if output_model == "default":
warnings.warn("output_model_dir set via default value, will be ignored")
# Prepare base job configuration
job_config = {
"config": nvidia_model,
@ -226,6 +255,7 @@ class NvidiaPostTrainingAdapter:
# Extract LoRA-specific parameters
lora_config = {k: v for k, v in algorithm_config.items() if k != "type"}
job_config["hyperparameters"]["lora"] = lora_config
warn_unsupported_params(lora_config, ["adapter_dim", "adapter_dropout"], "LoRA config")
else:
raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")

View file

@ -13,47 +13,13 @@
import logging
from typing import Tuple
import httpx
from .config import NvidiaPostTrainingConfig
logger = logging.getLogger(__name__)
async def _get_health(url: str) -> Tuple[bool, bool]:
    """
    Probe {url}/v1/health/live and {url}/v1/health/ready on the server.

    Args:
        url (str): base URL of the server

    Returns:
        Tuple[bool, bool]: (is_live, is_ready) — True where the endpoint answered 200
    """
    async with httpx.AsyncClient() as client:
        # Probe liveness first, then readiness, matching the tuple order.
        live_resp = await client.get(f"{url}/v1/health/live")
        ready_resp = await client.get(f"{url}/v1/health/ready")
    is_live = live_resp.status_code == 200
    is_ready = ready_resp.status_code == 200
    return is_live, is_ready
# TODO: implement health checks once they are enabled for the customizer
# (NOTE(review): the original comment here was garbled — intent presumed, confirm)
async def _get_health(url: str) -> Tuple[bool, bool]: ...  # stub: should return (is_live, is_ready)
async def check_health(config: NvidiaPostTrainingConfig) -> None:
    """
    Check that the NVIDIA NIM server is live and ready.

    Args:
        config (NvidiaPostTrainingConfig): provider configuration; ``config.url`` is probed

    Raises:
        ConnectionError: if the server cannot be reached, is not live, or is not ready
    """
    if _is_nvidia_hosted(config):
        # Hosted deployments skip the local health probe entirely.
        return
    logger.info("Checking NVIDIA NIM health...")
    try:
        is_live, is_ready = await _get_health(config.url)
    except httpx.ConnectError as e:
        raise ConnectionError(f"Failed to connect to NVIDIA NIM: {e}") from e
    if not is_live:
        raise ConnectionError("NVIDIA NIM is not running")
    if not is_ready:
        raise ConnectionError("NVIDIA NIM is not ready")
    # TODO(mf): should we wait for the server to be ready?
async def check_health(config: NvidiaPostTrainingConfig) -> None: ...  # NOTE(review): stubbed to a no-op — callers get no health validation until implemented