Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)
feat: allow dynamic model registration for nvidia inference provider
implements query_available_models on NVIDIAInferenceAdapter
parent d035fe93c6
commit 6c16e2c0fd

1 changed file with 7 additions and 48 deletions
@@ -41,11 +41,7 @@ from llama_stack.apis.inference import (
     ToolChoice,
     ToolConfig,
 )
-from llama_stack.apis.models import Model, ModelType
 from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
-from llama_stack.providers.utils.inference import (
-    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
-)
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
@@ -93,8 +89,12 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
 
         self._config = config
 
+    async def query_available_models(self) -> list[str]:
+        """Query available models from the NVIDIA API."""
+        return [model.id async for model in self._get_client().models.list()]
+
     @lru_cache  # noqa: B019
-    def _get_client(self, provider_model_id: str) -> AsyncOpenAI:
+    def _get_client(self, provider_model_id: str | None = None) -> AsyncOpenAI:
         """
         For hosted models, https://integrate.api.nvidia.com/v1 is the primary base_url. However,
         some models are hosted on different URLs. This function returns the appropriate client
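For context, the new query_available_models simply lists whatever the OpenAI-compatible NVIDIA endpoint exposes and returns the model IDs. A standalone sketch of the same call pattern using only the OpenAI SDK; the hosted endpoint URL comes from the _get_client docstring, the API key is a placeholder, and the helper name here is illustrative rather than part of the adapter:

import asyncio

from openai import AsyncOpenAI


async def list_nim_models(base_url: str, api_key: str) -> list[str]:
    # Same pattern as query_available_models above: iterate the async
    # paginator returned by models.list() and collect the model IDs.
    client = AsyncOpenAI(base_url=base_url, api_key=api_key)
    return [model.id async for model in client.models.list()]


if __name__ == "__main__":
    # Hosted catalogue endpoint from the docstring; the key is a placeholder.
    ids = asyncio.run(list_nim_models("https://integrate.api.nvidia.com/v1", "nvapi-REPLACE-ME"))
    print(ids)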
@@ -103,7 +103,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         This relies on lru_cache and self._default_client to avoid creating a new client for each request
         or for each model that is hosted on https://integrate.api.nvidia.com/v1.
 
-        :param provider_model_id: The provider model ID
+        :param provider_model_id: The provider model ID (optional, defaults to primary endpoint)
         :return: An OpenAI client
         """
 
@@ -125,7 +125,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
 
         base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
 
-        if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
+        if provider_model_id and _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
             base_url = special_model_urls[provider_model_id]
         return _get_client_for_base_url(base_url)
 
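Making provider_model_id optional is what lets query_available_models call self._get_client() with no argument and land on the primary base_url, and the added "provider_model_id and ..." check keeps the None case out of the special-URL lookup. Since _get_client is wrapped in lru_cache, the no-argument call is cached just like the per-model calls; a toy illustration of that behaviour (hypothetical function and URL table, not the adapter's actual code):

from functools import lru_cache


@lru_cache
def pick_base_url(provider_model_id: str | None = None) -> str:
    # Hypothetical stand-in for the special_model_urls lookup in _get_client.
    special_model_urls = {"some/special-model": "https://ai.api.nvidia.com/v1/some/special-model"}
    if provider_model_id and provider_model_id in special_model_urls:
        return special_model_urls[provider_model_id]
    return "https://integrate.api.nvidia.com/v1"


pick_base_url()                      # no-argument call used by query_available_models
pick_base_url("some/special-model")  # per-model call; cached under a separate key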
@@ -401,44 +401,3 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
             return await self._get_client(provider_model_id).chat.completions.create(**params)
         except APIConnectionError as e:
             raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
-
-    async def register_model(self, model: Model) -> Model:
-        """
-        Allow non-llama model registration.
-
-        Non-llama model registration: API Catalogue models, post-training models, etc.
-            client = LlamaStackAsLibraryClient("nvidia")
-            client.models.register(
-                    model_id="mistralai/mixtral-8x7b-instruct-v0.1",
-                    model_type=ModelType.llm,
-                    provider_id="nvidia",
-                    provider_model_id="mistralai/mixtral-8x7b-instruct-v0.1"
-            )
-
-        NOTE: Only supports models endpoints compatible with AsyncOpenAI base_url format.
-        """
-        if model.model_type == ModelType.embedding:
-            # embedding models are always registered by their provider model id and does not need to be mapped to a llama model
-            provider_resource_id = model.provider_resource_id
-        else:
-            provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
-
-        if provider_resource_id:
-            model.provider_resource_id = provider_resource_id
-        else:
-            llama_model = model.metadata.get("llama_model")
-            existing_llama_model = self.get_llama_model(model.provider_resource_id)
-            if existing_llama_model:
-                if existing_llama_model != llama_model:
-                    raise ValueError(
-                        f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
-                    )
-            else:
-                # not llama model
-                if llama_model in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
-                    self.provider_id_to_llama_model_map[model.provider_resource_id] = (
-                        ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
-                    )
-                else:
-                    self.alias_to_provider_id_map[model.provider_model_id] = model.provider_model_id
-        return model
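With the adapter-level register_model removed, registration for this provider presumably falls through to the inherited ModelRegistryHelper behaviour together with the new query_available_models hook, and the usage pattern from the removed docstring should still apply. A minimal sketch based on that docstring; the library-client import path and the initialize() step are assumptions not shown in the diff:

from llama_stack.apis.models import ModelType
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient  # import path assumed

client = LlamaStackAsLibraryClient("nvidia")
client.initialize()  # assumed setup step; the removed docstring omits it

# Register a non-llama model (API Catalogue, post-training output, etc.),
# as the removed register_model docstring showed.
client.models.register(
    model_id="mistralai/mixtral-8x7b-instruct-v0.1",
    model_type=ModelType.llm,
    provider_id="nvidia",
    provider_model_id="mistralai/mixtral-8x7b-instruct-v0.1",
)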