In progress: Add NVIDIA e2e notebook

This commit is contained in:
Jash Gulabrai 2025-04-03 11:19:43 -04:00
parent 66d6c2580e
commit 861962fa80
18 changed files with 4888 additions and 7 deletions

View file

@@ -243,7 +243,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
provider_model_id = self.get_provider_model_id(model_id)
request = await convert_chat_completion_request(
request=ChatCompletionRequest(
model=self.get_provider_model_id(model_id),
model=provider_model_id,
messages=messages,
sampling_params=sampling_params,
response_format=response_format,

View file

@@ -94,7 +94,8 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
request_headers["Content-Type"] = "application/json"
for _ in range(self.config.max_retries):
async with self.session.request(method, url, params=params, json=json, **kwargs) as response:
# TODO: Remove `verify_ssl=False`. Added for testing purposes to call NMP int environment from `docs/notebooks/nvidia/`
async with self.session.request(method, url, params=params, json=json, verify_ssl=False, **kwargs) as response:
if response.status >= 400:
error_data = await response.json()
raise Exception(f"API request failed: {error_data}")
@@ -392,19 +393,20 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
# Handle LoRA-specific configuration
if algorithm_config:
if isinstance(algorithm_config, dict) and algorithm_config.get("type") == "LoRA":
algortihm_config_dict = algorithm_config.model_dump()
if algortihm_config_dict.get("type") == "LoRA":
warn_unsupported_params(algorithm_config, supported_params["lora_config"], "LoRA config")
job_config["hyperparameters"]["lora"] = {
k: v
for k, v in {
"adapter_dim": algorithm_config.get("adapter_dim"),
"alpha": algorithm_config.get("alpha"),
"adapter_dropout": algorithm_config.get("adapter_dropout"),
"adapter_dim": algortihm_config_dict.get("adapter_dim"),
"alpha": algortihm_config_dict.get("alpha"),
"adapter_dropout": algortihm_config_dict.get("adapter_dropout"),
}.items()
if v is not None
}
else:
raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
raise NotImplementedError(f"JASH was here Unsupported algorithm config: {algorithm_config}")
# Create the customization job
response = await self._make_request(