Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-03 01:03:59 +00:00
fix: Correctly parse algorithm_config when launching NVIDIA customization job
parent 7ed137e963
commit 26c10b5ab5

1 changed file with 21 additions and 12 deletions
@@ -67,13 +67,18 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
         self.timeout = aiohttp.ClientTimeout(total=config.timeout)
         # TODO: filter by available models based on /config endpoint
         ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
-        self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout)
+        self.session = None

         self.customizer_url = config.customizer_url
         if not self.customizer_url:
             warnings.warn("Customizer URL is not set, using default value: http://nemo.test", stacklevel=2)
             self.customizer_url = "http://nemo.test"

+    async def _get_session(self) -> aiohttp.ClientSession:
+        if self.session is None or self.session.closed:
+            self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout)
+        return self.session
+
     async def _make_request(
         self,
         method: str,
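Note on this hunk: creating the aiohttp.ClientSession eagerly in __init__ binds it to whatever event loop (if any) is running at construction time, and a session that later gets closed could never be replaced. Deferring creation to _get_session sidesteps both problems. Below is a minimal, self-contained sketch of the same lazy-session pattern; the LazySessionClient class, its method names, and the URL are illustrative, not code from this repo.

import asyncio

import aiohttp


class LazySessionClient:
    """Illustrative sketch: create the aiohttp session on first use, not in __init__."""

    def __init__(self, base_url: str, timeout_s: float = 30.0):
        self.base_url = base_url
        self.timeout = aiohttp.ClientTimeout(total=timeout_s)
        # No ClientSession yet: __init__ may run outside any event loop.
        self.session: aiohttp.ClientSession | None = None

    async def _get_session(self) -> aiohttp.ClientSession:
        # Recreate the session if it was never created or has been closed.
        if self.session is None or self.session.closed:
            self.session = aiohttp.ClientSession(timeout=self.timeout)
        return self.session

    async def get_json(self, path: str) -> dict:
        session = await self._get_session()
        async with session.get(self.base_url + path) as response:
            response.raise_for_status()
            return await response.json()


async def main():
    client = LazySessionClient("https://httpbin.org")
    print(await client.get_json("/get"))
    if client.session is not None:
        await client.session.close()  # clean up the lazily created session


asyncio.run(main())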
@@ -94,8 +99,9 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
         if json and "Content-Type" not in request_headers:
             request_headers["Content-Type"] = "application/json"

+        session = await self._get_session()
         for _ in range(self.config.max_retries):
-            async with self.session.request(method, url, params=params, json=json, **kwargs) as response:
+            async with session.request(method, url, params=params, json=json, **kwargs) as response:
                 if response.status >= 400:
                     error_data = await response.json()
                     raise Exception(f"API request failed: {error_data}")
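_make_request now resolves the session through _get_session() once, just before the retry loop, instead of reading self.session directly, so a session that was never created (or was closed in the meantime) no longer breaks requests. A rough standalone sketch of that request-and-retry shape follows; retrying only on aiohttp.ClientError is an illustrative choice here, not the repo's exact error handling, and the real method also builds headers and parses the successful response.

import aiohttp


async def make_request_with_retries(
    session: aiohttp.ClientSession,
    method: str,
    url: str,
    max_retries: int = 3,
    **kwargs,
) -> dict:
    # Illustrative sketch of the retry-loop shape used by _make_request.
    last_error: Exception | None = None
    for _ in range(max_retries):
        try:
            async with session.request(method, url, **kwargs) as response:
                if response.status >= 400:
                    error_data = await response.json()
                    raise RuntimeError(f"API request failed: {error_data}")
                return await response.json()
        except aiohttp.ClientError as e:  # retry only on transport-level failures
            last_error = e
    raise RuntimeError(f"Request failed after {max_retries} attempts") from last_error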
@@ -122,8 +128,8 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
         jobs = []
         for job in response.get("data", []):
             job_id = job.pop("id")
-            job_status = job.pop("status", "unknown").lower()
-            mapped_status = STATUS_MAPPING.get(job_status, "unknown")
+            job_status = job.pop("status", "scheduled").lower()
+            mapped_status = STATUS_MAPPING.get(job_status, "scheduled")

             # Convert string timestamps to datetime objects
             created_at = (
@@ -177,7 +183,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
         )

         api_status = response.pop("status").lower()
-        mapped_status = STATUS_MAPPING.get(api_status, "unknown")
+        mapped_status = STATUS_MAPPING.get(api_status, "scheduled")

         return NvidiaPostTrainingJobStatusResponse(
             status=JobStatus(mapped_status),
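The two status hunks make the same change: a job state that is missing from the API response, or present but absent from STATUS_MAPPING, now falls back to "scheduled" instead of "unknown". That matters because the mapped string is fed to JobStatus(mapped_status), and constructing an enum from a string that is not a member raises ValueError. A tiny sketch of the lookup behavior; this STATUS_MAPPING table is a stand-in, not the repo's actual mapping.

# Stand-in mapping from NeMo Customizer job states to JobStatus values.
STATUS_MAPPING = {
    "running": "in_progress",
    "completed": "completed",
    "failed": "failed",
    "pending": "scheduled",
}


def map_status(job: dict) -> str:
    # Missing or unrecognized states now fall back to "scheduled", not "unknown".
    job_status = job.pop("status", "scheduled").lower()
    return STATUS_MAPPING.get(job_status, "scheduled")


assert map_status({"status": "RUNNING"}) == "in_progress"
assert map_status({}) == "scheduled"                      # no status field at all
assert map_status({"status": "archived"}) == "scheduled"  # unmapped state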
@@ -297,7 +303,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
         User is informed about unsupported parameters via warnings.
         """
         # Map model to nvidia model name
-        # ToDo: only supports llama-3.1-8b-instruct now, need to update this to support other models
+        # See `_MODEL_ENTRIES` for supported models
         nvidia_model = self.get_provider_model_id(model)

         # Check for unsupported method parameters
@@ -389,14 +395,17 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):

         # Handle LoRA-specific configuration
         if algorithm_config:
-            if isinstance(algorithm_config, dict) and algorithm_config.get("type") == "LoRA":
-                warn_unsupported_params(algorithm_config, supported_params["lora_config"], "LoRA config")
+            algorithm_config_dict = (
+                algorithm_config.model_dump() if hasattr(algorithm_config, "model_dump") else algorithm_config
+            )
+            if isinstance(algorithm_config_dict, dict) and algorithm_config_dict.get("type") == "LoRA":
+                warn_unsupported_params(algorithm_config_dict, supported_params["lora_config"], "LoRA config")
                 job_config["hyperparameters"]["lora"] = {
                     k: v
                     for k, v in {
-                        "adapter_dim": algorithm_config.get("adapter_dim"),
-                        "alpha": algorithm_config.get("alpha"),
-                        "adapter_dropout": algorithm_config.get("adapter_dropout"),
+                        "adapter_dim": algorithm_config_dict.get("adapter_dim"),
+                        "alpha": algorithm_config_dict.get("alpha"),
+                        "adapter_dropout": algorithm_config_dict.get("adapter_dropout"),
                     }.items()
                     if v is not None
                 }
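This hunk is the fix named in the commit title. algorithm_config can arrive either as a plain dict or as a Pydantic model (e.g. a LoraFinetuningConfig), but the old code only handled the dict case: a Pydantic config failed the isinstance check, so the LoRA branch was skipped and the hyperparameters were silently dropped from the customization job. Normalizing through model_dump() first accepts both shapes. A hedged sketch of the normalization; the LoraConfig model and extract_lora_params helper below are illustrative stand-ins, not the repo's classes.

from pydantic import BaseModel


class LoraConfig(BaseModel):
    # Illustrative stand-in for a Pydantic algorithm config such as LoraFinetuningConfig.
    type: str = "LoRA"
    adapter_dim: int | None = None
    alpha: int | None = None
    adapter_dropout: float | None = None


def extract_lora_params(algorithm_config) -> dict:
    # Normalize: Pydantic models expose model_dump(); plain dicts pass through unchanged.
    algorithm_config_dict = (
        algorithm_config.model_dump() if hasattr(algorithm_config, "model_dump") else algorithm_config
    )
    if not (isinstance(algorithm_config_dict, dict) and algorithm_config_dict.get("type") == "LoRA"):
        return {}
    # Keep only the LoRA keys that were actually set.
    return {
        k: v
        for k, v in {
            "adapter_dim": algorithm_config_dict.get("adapter_dim"),
            "alpha": algorithm_config_dict.get("alpha"),
            "adapter_dropout": algorithm_config_dict.get("adapter_dropout"),
        }.items()
        if v is not None
    }


# Both input shapes now yield the same hyperparameter fragment:
assert extract_lora_params({"type": "LoRA", "adapter_dim": 16}) == {"adapter_dim": 16}
assert extract_lora_params(LoraConfig(adapter_dim=16, alpha=16)) == {"adapter_dim": 16, "alpha": 16}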