diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index f790c2312..cb7554523 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -9,7 +9,7 @@ import warnings
 from collections.abc import AsyncIterator
 from typing import Any
 
-from openai import APIConnectionError, AsyncOpenAI, BadRequestError
+from openai import APIConnectionError, AsyncOpenAI, BadRequestError, NotFoundError
 
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -40,11 +40,7 @@ from llama_stack.apis.inference import (
     ToolChoice,
     ToolConfig,
 )
-from llama_stack.apis.models import Model, ModelType
 from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
-from llama_stack.providers.utils.inference import (
-    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
-)
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
@@ -92,6 +88,22 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
 
         self._config = config
 
+    async def check_model_availability(self, model: str) -> bool:
+        """
+        Check if a specific model is available.
+
+        :param model: The model identifier to check.
+        :return: True if the model is available dynamically, False otherwise.
+        """
+        try:
+            await self._client.models.retrieve(model)
+            return True
+        except NotFoundError:
+            logger.error(f"Model {model} is not available")
+        except Exception as e:
+            logger.error(f"Failed to check model availability: {e}")
+        return False
+
     @property
     def _client(self) -> AsyncOpenAI:
         """
@@ -380,44 +392,3 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
             return await self._client.chat.completions.create(**params)
         except APIConnectionError as e:
             raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
-
-    async def register_model(self, model: Model) -> Model:
-        """
-        Allow non-llama model registration.
-
-        Non-llama model registration: API Catalogue models, post-training models, etc.
-            client = LlamaStackAsLibraryClient("nvidia")
-            client.models.register(
-                model_id="mistralai/mixtral-8x7b-instruct-v0.1",
-                model_type=ModelType.llm,
-                provider_id="nvidia",
-                provider_model_id="mistralai/mixtral-8x7b-instruct-v0.1"
-            )
-
-        NOTE: Only supports models endpoints compatible with AsyncOpenAI base_url format.
-        """
-        if model.model_type == ModelType.embedding:
-            # embedding models are always registered by their provider model id and does not need to be mapped to a llama model
-            provider_resource_id = model.provider_resource_id
-        else:
-            provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
-
-        if provider_resource_id:
-            model.provider_resource_id = provider_resource_id
-        else:
-            llama_model = model.metadata.get("llama_model")
-            existing_llama_model = self.get_llama_model(model.provider_resource_id)
-            if existing_llama_model:
-                if existing_llama_model != llama_model:
-                    raise ValueError(
-                        f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
-                    )
-            else:
-                # not llama model
-                if llama_model in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
-                    self.provider_id_to_llama_model_map[model.provider_resource_id] = (
-                        ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
-                    )
-                else:
-                    self.alias_to_provider_id_map[model.provider_model_id] = model.provider_model_id
-        return model
diff --git a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py
index f75b0add9..bbbb60a30 100644
--- a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py
+++ b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py
@@ -7,7 +7,7 @@
 import os
 import unittest
 import warnings
-from unittest.mock import patch
+from unittest.mock import AsyncMock, patch
 
 import pytest
 
@@ -343,7 +343,11 @@ class TestNvidiaPostTraining(unittest.TestCase):
             provider_resource_id=model_id,
             model_type=model_type,
         )
-        result = self.run_async(self.inference_adapter.register_model(model))
+
+        # simulate a NIM where default/job-1234 is an available model
+        with patch.object(self.inference_adapter, "check_model_availability", new_callable=AsyncMock) as mock_check:
+            mock_check.return_value = True
+            result = self.run_async(self.inference_adapter.register_model(model))
         assert result == model
         assert len(self.inference_adapter.alias_to_provider_id_map) > 1
         assert self.inference_adapter.get_provider_model_id(model.provider_model_id) == model_id