mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-31 07:39:59 +00:00
unit test update, warnings for unsupported parameters
This commit is contained in:
parent
152261a249
commit
bd9b6a6e00
5 changed files with 413 additions and 429 deletions
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
|
@ -15,27 +16,27 @@ class NvidiaPostTrainingConfig(BaseModel):
|
|||
|
||||
api_key: Optional[str] = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
|
||||
description="The NVIDIA API key, only needed of using the hosted service",
|
||||
description="The NVIDIA API key.",
|
||||
)
|
||||
|
||||
user_id: Optional[str] = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_USER_ID", "llama-stack-user"),
|
||||
description="The NVIDIA user ID, only needed of using the hosted service",
|
||||
description="The NVIDIA user ID.",
|
||||
)
|
||||
|
||||
dataset_namespace: Optional[str] = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_DATASET_NAMESPACE", "default"),
|
||||
description="The NVIDIA dataset namespace, only needed of using the hosted service",
|
||||
description="The NVIDIA dataset namespace.",
|
||||
)
|
||||
|
||||
access_policies: Optional[dict] = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_ACCESS_POLICIES", {}),
|
||||
description="The NVIDIA access policies, only needed of using the hosted service",
|
||||
description="The NVIDIA access policies.",
|
||||
)
|
||||
|
||||
project_id: Optional[str] = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_PROJECT_ID", "test-project"),
|
||||
description="The NVIDIA project ID, only needed of using the hosted service",
|
||||
description="The NVIDIA project ID.",
|
||||
)
|
||||
|
||||
# ToDO: validate this, add default value
|
||||
|
|
@ -54,11 +55,35 @@ class NvidiaPostTrainingConfig(BaseModel):
|
|||
description="Maximum number of retries for the NVIDIA Post Training API",
|
||||
)
|
||||
|
||||
# ToDo: validate this, add default value
|
||||
output_model_dir: str = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_OUTPUT_MODEL_DIR", "test-example-model@v1"),
|
||||
description="Directory to save the output model",
|
||||
)
|
||||
|
||||
# warning for default values
|
||||
def __post_init__(self):
|
||||
default_values = []
|
||||
if os.getenv("NVIDIA_OUTPUT_MODEL_DIR") is None:
|
||||
default_values.append("output_model_dir='test-example-model@v1'")
|
||||
if os.getenv("NVIDIA_PROJECT_ID") is None:
|
||||
default_values.append("project_id='test-project'")
|
||||
if os.getenv("NVIDIA_USER_ID") is None:
|
||||
default_values.append("user_id='llama-stack-user'")
|
||||
if os.getenv("NVIDIA_DATASET_NAMESPACE") is None:
|
||||
default_values.append("dataset_namespace='default'")
|
||||
if os.getenv("NVIDIA_ACCESS_POLICIES") is None:
|
||||
default_values.append("access_policies='{}'")
|
||||
if os.getenv("NVIDIA_CUSTOMIZER_URL") is None:
|
||||
default_values.append("customizer_url='http://nemo.test'")
|
||||
|
||||
if default_values:
|
||||
warnings.warn(
|
||||
f"Using default values: {', '.join(default_values)}. \
|
||||
Please set the environment variables to avoid this default behavior.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
|
|
@ -190,6 +191,13 @@ class NvidiaPostTrainingAdapter:
|
|||
job_uuid: str - Unique identifier for the job
|
||||
hyperparam_search_config: Dict[str, Any] - Configuration for hyperparameter search
|
||||
logger_config: Dict[str, Any] - Configuration for logging
|
||||
|
||||
Environment Variables:
|
||||
- NVIDIA_PROJECT_ID: ID of the project
|
||||
- NVIDIA_USER_ID: ID of the user
|
||||
- NVIDIA_ACCESS_POLICIES: Access policies for the project
|
||||
- NVIDIA_DATASET_NAMESPACE: Namespace of the dataset
|
||||
- NVIDIA_OUTPUT_MODEL_DIR: Directory to save the output model
|
||||
"""
|
||||
# map model to nvidia model name
|
||||
model_mapping = {
|
||||
|
|
@ -198,9 +206,30 @@ class NvidiaPostTrainingAdapter:
|
|||
}
|
||||
nvidia_model = model_mapping.get(model, model)
|
||||
|
||||
# Get output model directory from config
|
||||
# Check for unsupported parameters
|
||||
if checkpoint_dir or hyperparam_search_config or logger_config:
|
||||
warnings.warn(
|
||||
"Parameters: {} not supported atm, will be ignored".format(
|
||||
checkpoint_dir,
|
||||
)
|
||||
)
|
||||
|
||||
def warn_unsupported_params(config_dict: Dict[str, Any], supported_keys: List[str], config_name: str) -> None:
|
||||
"""Helper function to warn about unsupported parameters in a config dictionary."""
|
||||
unsupported_params = [k for k in config_dict.keys() if k not in supported_keys]
|
||||
if unsupported_params:
|
||||
warnings.warn(f"Parameters: {unsupported_params} in {config_name} not supported and will be ignored.")
|
||||
|
||||
# Check for unsupported parameters
|
||||
warn_unsupported_params(training_config, ["n_epochs", "data_config", "optimizer_config"], "TrainingConfig")
|
||||
warn_unsupported_params(training_config["data_config"], ["dataset_id", "batch_size"], "DataConfig")
|
||||
warn_unsupported_params(training_config["optimizer_config"], ["lr"], "OptimizerConfig")
|
||||
|
||||
output_model = self.config.output_model_dir
|
||||
|
||||
if output_model == "default":
|
||||
warnings.warn("output_model_dir set via default value, will be ignored")
|
||||
|
||||
# Prepare base job configuration
|
||||
job_config = {
|
||||
"config": nvidia_model,
|
||||
|
|
@ -226,6 +255,7 @@ class NvidiaPostTrainingAdapter:
|
|||
# Extract LoRA-specific parameters
|
||||
lora_config = {k: v for k, v in algorithm_config.items() if k != "type"}
|
||||
job_config["hyperparameters"]["lora"] = lora_config
|
||||
warn_unsupported_params(lora_config, ["adapter_dim", "adapter_dropout"], "LoRA config")
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
|
||||
|
||||
|
|
|
|||
|
|
@ -13,47 +13,13 @@
|
|||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
import httpx
|
||||
|
||||
from .config import NvidiaPostTrainingConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _get_health(url: str) -> Tuple[bool, bool]:
|
||||
"""
|
||||
Query {url}/v1/health/{live,ready} to check if the server is running and ready
|
||||
|
||||
Args:
|
||||
url (str): URL of the server
|
||||
|
||||
Returns:
|
||||
Tuple[bool, bool]: (is_live, is_ready)
|
||||
"""
|
||||
async with httpx.AsyncClient() as client:
|
||||
live = await client.get(f"{url}/v1/health/live")
|
||||
ready = await client.get(f"{url}/v1/health/ready")
|
||||
return live.status_code == 200, ready.status_code == 200
|
||||
# ToDo: implement post health checks for customizer are enabled
|
||||
async def _get_health(url: str) -> Tuple[bool, bool]: ...
|
||||
|
||||
|
||||
async def check_health(config: NvidiaPostTrainingConfig) -> None:
|
||||
"""
|
||||
Check if the server is running and ready
|
||||
|
||||
Args:
|
||||
url (str): URL of the server
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the server is not running or ready
|
||||
"""
|
||||
if not _is_nvidia_hosted(config):
|
||||
logger.info("Checking NVIDIA NIM health...")
|
||||
try:
|
||||
is_live, is_ready = await _get_health(config.url)
|
||||
if not is_live:
|
||||
raise ConnectionError("NVIDIA NIM is not running")
|
||||
if not is_ready:
|
||||
raise ConnectionError("NVIDIA NIM is not ready")
|
||||
# TODO(mf): should we wait for the server to be ready?
|
||||
except httpx.ConnectError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM: {e}") from e
|
||||
async def check_health(config: NvidiaPostTrainingConfig) -> None: ...
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue