mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-21 03:59:42 +00:00
In progress: Add NVIDIA e2e notebook
This commit is contained in:
parent
66d6c2580e
commit
861962fa80
18 changed files with 4888 additions and 7 deletions
|
@ -243,7 +243,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
provider_model_id = self.get_provider_model_id(model_id)
|
||||
request = await convert_chat_completion_request(
|
||||
request=ChatCompletionRequest(
|
||||
model=self.get_provider_model_id(model_id),
|
||||
model=provider_model_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
|
|
|
@ -94,7 +94,8 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
request_headers["Content-Type"] = "application/json"
|
||||
|
||||
for _ in range(self.config.max_retries):
|
||||
async with self.session.request(method, url, params=params, json=json, **kwargs) as response:
|
||||
# TODO: Remove `verify_ssl=False`. Added for testing purposes to call NMP int environment from `docs/notebooks/nvidia/`
|
||||
async with self.session.request(method, url, params=params, json=json, verify_ssl=False, **kwargs) as response:
|
||||
if response.status >= 400:
|
||||
error_data = await response.json()
|
||||
raise Exception(f"API request failed: {error_data}")
|
||||
|
@ -392,19 +393,20 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
|
||||
# Handle LoRA-specific configuration
|
||||
if algorithm_config:
|
||||
if isinstance(algorithm_config, dict) and algorithm_config.get("type") == "LoRA":
|
||||
algorithm_config_dict = algorithm_config.model_dump()
|
||||
if algorithm_config_dict.get("type") == "LoRA":
|
||||
warn_unsupported_params(algorithm_config, supported_params["lora_config"], "LoRA config")
|
||||
job_config["hyperparameters"]["lora"] = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"adapter_dim": algorithm_config.get("adapter_dim"),
|
||||
"alpha": algorithm_config.get("alpha"),
|
||||
"adapter_dropout": algorithm_config.get("adapter_dropout"),
|
||||
"adapter_dim": algorithm_config_dict.get("adapter_dim"),
|
||||
"alpha": algorithm_config_dict.get("alpha"),
|
||||
"adapter_dropout": algorithm_config_dict.get("adapter_dropout"),
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
|
||||
raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
|
||||
|
||||
# Create the customization job
|
||||
response = await self._make_request(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue