feat: NVIDIA allow non-llama model registration (#1859)

# What does this PR do?
Adds custom model registration functionality to NVIDIAInferenceAdapter
which let's the inference happen on:
- post-training model
- non-llama models in API Catalogue(behind
https://integrate.api.nvidia.com and endpoints compatible with
AyncOpenAI)

## Example Usage:
```python
from llama_stack.apis.models import Model, ModelType
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
client = LlamaStackAsLibraryClient("nvidia")
_ = client.initialize()

client.models.register(
        model_id=model_name,
        model_type=ModelType.llm,
        provider_id="nvidia"
)

response = client.inference.chat_completion(
    model_id=model_name,
    messages=[{"role":"system","content":"You are a helpful assistant."},{"role":"user","content":"Write a limerick about the wonders of GPU computing."}],
)
```

## Test Plan
```bash
pytest tests/unit/providers/nvidia/test_supervised_fine_tuning.py 
========================================================== test session starts ===========================================================
platform linux -- Python 3.10.0, pytest-8.3.5, pluggy-1.5.0
rootdir: /home/ubuntu/llama-stack
configfile: pyproject.toml
plugins: anyio-4.9.0
collected 6 items                                                                                                                        

tests/unit/providers/nvidia/test_supervised_fine_tuning.py ......                                                                  [100%]

============================================================ warnings summary ============================================================
../miniconda/envs/nvidia-1/lib/python3.10/site-packages/pydantic/fields.py:1076
  /home/ubuntu/miniconda/envs/nvidia-1/lib/python3.10/site-packages/pydantic/fields.py:1076: PydanticDeprecatedSince20: Using extra keyword arguments on `Field` is deprecated and will be removed. Use `json_schema_extra` instead. (Extra keys: 'contentEncoding'). Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/
    warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
====================================================== 6 passed, 1 warning in 1.51s ======================================================
```

[//]: # (## Documentation)
Updated Readme.md

cc: @dglogo, @sumitb, @mattf
This commit is contained in:
Rashmi Pawar 2025-04-25 05:43:33 +05:30 committed by GitHub
parent cc77f79f55
commit ace82836c1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 116 additions and 15 deletions

View file

@ -33,7 +33,6 @@ from llama_stack.apis.inference import (
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
@ -42,7 +41,11 @@ from llama_stack.apis.inference.inference import (
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.models.llama.datatypes import ToolPromptFormat
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
from llama_stack.providers.utils.inference import (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
)
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
@ -120,10 +123,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
"meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
}
base_url = f"{self._config.url}/v1"
base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
base_url = special_model_urls[provider_model_id]
return _get_client_for_base_url(base_url)
async def _get_provider_model_id(self, model_id: str) -> str:
@ -387,3 +390,44 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
return await self._get_client(provider_model_id).chat.completions.create(**params)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
async def register_model(self, model: Model) -> Model:
"""
Allow non-llama model registration.
Non-llama model registration: API Catalogue models, post-training models, etc.
client = LlamaStackAsLibraryClient("nvidia")
client.models.register(
model_id="mistralai/mixtral-8x7b-instruct-v0.1",
model_type=ModelType.llm,
provider_id="nvidia",
provider_model_id="mistralai/mixtral-8x7b-instruct-v0.1"
)
NOTE: Only supports models endpoints compatible with AsyncOpenAI base_url format.
"""
if model.model_type == ModelType.embedding:
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
provider_resource_id = model.provider_resource_id
else:
provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
if provider_resource_id:
model.provider_resource_id = provider_resource_id
else:
llama_model = model.metadata.get("llama_model")
existing_llama_model = self.get_llama_model(model.provider_resource_id)
if existing_llama_model:
if existing_llama_model != llama_model:
raise ValueError(
f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
)
else:
# not llama model
if llama_model in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
self.provider_id_to_llama_model_map[model.provider_resource_id] = (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
)
else:
self.alias_to_provider_id_map[model.provider_model_id] = model.provider_model_id
return model