diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index 4be76380b..158a2738e 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -89,9 +89,14 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
 
         self._config = config
 
-    async def query_available_models(self) -> list[str]:
-        """Query available models from the NVIDIA API."""
-        return [model.id async for model in self._get_client().models.list()]
+    async def check_model_availability(self, model: str) -> bool:
+        """Check if a specific model is available from the NVIDIA API."""
+        try:
+            await self._get_client().models.retrieve(model)
+            return True
+        except Exception:
+            # If we can't retrieve the model, it's not available
+            return False
 
     @lru_cache  # noqa: B019
     def _get_client(self, provider_model_id: str | None = None) -> AsyncOpenAI:
diff --git a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py
index 0b793a165..400705fe1 100644
--- a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py
+++ b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py
@@ -344,8 +344,8 @@ class TestNvidiaPostTraining(unittest.TestCase):
         )
 
         # simulate a NIM where default/job-1234 is an available model
-        with patch.object(self.inference_adapter, "query_available_models", new_callable=AsyncMock) as mock_query:
-            mock_query.return_value = [model_id]
+        with patch.object(self.inference_adapter, "check_model_availability", new_callable=AsyncMock) as mock_check:
+            mock_check.return_value = True
             result = self.run_async(self.inference_adapter.register_model(model))
             assert result == model
             assert len(self.inference_adapter.alias_to_provider_id_map) > 1
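
For readers skimming the patch, here is a minimal standalone sketch of the new probe outside the adapter class, assuming any OpenAI-compatible endpoint. The base URL, API key, and model name below are illustrative placeholders, not values taken from this change:

```python
import asyncio

from openai import AsyncOpenAI


async def check_model_availability(client: AsyncOpenAI, model: str) -> bool:
    """Return True if the endpoint can retrieve `model`, False otherwise."""
    try:
        # models.retrieve() asks the server about one specific model instead
        # of paging through models.list() and searching the whole result.
        await client.models.retrieve(model)
        return True
    except Exception:
        # Any failure to retrieve the model is treated as "not available".
        return False


async def main() -> None:
    # Placeholder endpoint and credentials; point these at a real NIM.
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="not-used")
    print(await check_model_availability(client, "default/job-1234"))


asyncio.run(main())
```

One trade-off worth noting: the bare `except Exception` means transient failures (network errors, bad credentials) also read as "model unavailable", not just a missing model, so a caller that needs to distinguish those cases would have to narrow the except clause. The change also swaps a single `models.list()` call covering all models for one `models.retrieve()` call per model checked.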