Clean up instructions and implementation; reorganize notebooks

This commit is contained in:
Jash Gulabrai 2025-04-18 16:27:19 -04:00
parent 0d9d333a4e
commit 4131e8146f
29 changed files with 2756 additions and 89 deletions

View file

@ -47,10 +47,15 @@ class NVIDIAConfig(BaseModel):
default=60,
description="Timeout for the HTTP requests",
)
append_api_version: bool = Field(
default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"url": "${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}",
"api_key": "${env.NVIDIA_API_KEY:}",
"append_api_version": "${env.NVIDIA_APPEND_API_VERSION:True}",
}

View file

@ -42,10 +42,7 @@ from llama_stack.apis.inference.inference import (
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
from llama_stack.providers.utils.inference import (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
)
@ -126,15 +123,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
"meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
}
# add /v1 in case of hosted models
base_url = self._config.url
if _is_nvidia_hosted(self._config):
if provider_model_id in special_model_urls:
base_url = special_model_urls[provider_model_id]
else:
base_url = f"{self._config.url}/v1"
elif "nim.int.aire.nvidia.com" in base_url:
base_url = f"{base_url}/v1"
base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
base_url = special_model_urls[provider_model_id]
return _get_client_for_base_url(base_url)
async def completion(
@ -258,9 +250,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
# await check_health(self._config) # this raises errors
provider_model_id = self.get_provider_model_id(model_id)
print(f"provider_model_id: {provider_model_id}")
request = await convert_chat_completion_request(
request=ChatCompletionRequest(
model=provider_model_id,
model=self.get_provider_model_id(model_id),
messages=messages,
sampling_params=sampling_params,
response_format=response_format,