mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-18 19:02:30 +00:00
feat: allow dynamic model registration for nvidia inference provider (#2726)
# What does this PR do? let's users register models available at https://integrate.api.nvidia.com/v1/models that isn't already in llama_stack/providers/remote/inference/nvidia/models.py ## Test Plan 1. run the nvidia distro 2. register a model from https://integrate.api.nvidia.com/v1/models that isn't already know, as of this writing nvidia/llama-3.1-nemotron-ultra-253b-v1 is a good example 3. perform inference w/ the model
This commit is contained in:
parent
57745101be
commit
477bcd4d09
2 changed files with 23 additions and 48 deletions
|
@ -9,7 +9,7 @@ import warnings
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from openai import APIConnectionError, AsyncOpenAI, BadRequestError
|
from openai import APIConnectionError, AsyncOpenAI, BadRequestError, NotFoundError
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import (
|
from llama_stack.apis.common.content_types import (
|
||||||
InterleavedContent,
|
InterleavedContent,
|
||||||
|
@ -40,11 +40,7 @@ from llama_stack.apis.inference import (
|
||||||
ToolChoice,
|
ToolChoice,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.models import Model, ModelType
|
|
||||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
|
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
|
||||||
from llama_stack.providers.utils.inference import (
|
|
||||||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
|
|
||||||
)
|
|
||||||
from llama_stack.providers.utils.inference.model_registry import (
|
from llama_stack.providers.utils.inference.model_registry import (
|
||||||
ModelRegistryHelper,
|
ModelRegistryHelper,
|
||||||
)
|
)
|
||||||
|
@ -92,6 +88,22 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
||||||
|
|
||||||
self._config = config
|
self._config = config
|
||||||
|
|
||||||
|
async def check_model_availability(self, model: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a specific model is available.
|
||||||
|
|
||||||
|
:param model: The model identifier to check.
|
||||||
|
:return: True if the model is available dynamically, False otherwise.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await self._client.models.retrieve(model)
|
||||||
|
return True
|
||||||
|
except NotFoundError:
|
||||||
|
logger.error(f"Model {model} is not available")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to check model availability: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _client(self) -> AsyncOpenAI:
|
def _client(self) -> AsyncOpenAI:
|
||||||
"""
|
"""
|
||||||
|
@ -380,44 +392,3 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
||||||
return await self._client.chat.completions.create(**params)
|
return await self._client.chat.completions.create(**params)
|
||||||
except APIConnectionError as e:
|
except APIConnectionError as e:
|
||||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||||
|
|
||||||
async def register_model(self, model: Model) -> Model:
|
|
||||||
"""
|
|
||||||
Allow non-llama model registration.
|
|
||||||
|
|
||||||
Non-llama model registration: API Catalogue models, post-training models, etc.
|
|
||||||
client = LlamaStackAsLibraryClient("nvidia")
|
|
||||||
client.models.register(
|
|
||||||
model_id="mistralai/mixtral-8x7b-instruct-v0.1",
|
|
||||||
model_type=ModelType.llm,
|
|
||||||
provider_id="nvidia",
|
|
||||||
provider_model_id="mistralai/mixtral-8x7b-instruct-v0.1"
|
|
||||||
)
|
|
||||||
|
|
||||||
NOTE: Only supports models endpoints compatible with AsyncOpenAI base_url format.
|
|
||||||
"""
|
|
||||||
if model.model_type == ModelType.embedding:
|
|
||||||
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
|
|
||||||
provider_resource_id = model.provider_resource_id
|
|
||||||
else:
|
|
||||||
provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
|
|
||||||
|
|
||||||
if provider_resource_id:
|
|
||||||
model.provider_resource_id = provider_resource_id
|
|
||||||
else:
|
|
||||||
llama_model = model.metadata.get("llama_model")
|
|
||||||
existing_llama_model = self.get_llama_model(model.provider_resource_id)
|
|
||||||
if existing_llama_model:
|
|
||||||
if existing_llama_model != llama_model:
|
|
||||||
raise ValueError(
|
|
||||||
f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# not llama model
|
|
||||||
if llama_model in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
|
|
||||||
self.provider_id_to_llama_model_map[model.provider_resource_id] = (
|
|
||||||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.alias_to_provider_id_map[model.provider_model_id] = model.provider_model_id
|
|
||||||
return model
|
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
import warnings
|
import warnings
|
||||||
from unittest.mock import patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
@ -343,6 +343,10 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
||||||
provider_resource_id=model_id,
|
provider_resource_id=model_id,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# simulate a NIM where default/job-1234 is an available model
|
||||||
|
with patch.object(self.inference_adapter, "check_model_availability", new_callable=AsyncMock) as mock_check:
|
||||||
|
mock_check.return_value = True
|
||||||
result = self.run_async(self.inference_adapter.register_model(model))
|
result = self.run_async(self.inference_adapter.register_model(model))
|
||||||
assert result == model
|
assert result == model
|
||||||
assert len(self.inference_adapter.alias_to_provider_id_map) > 1
|
assert len(self.inference_adapter.alias_to_provider_id_map) > 1
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue