mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-23 02:42:24 +00:00
Replace query_available_models() -> list[str] with check_model_availability(model) -> bool
This commit is contained in:
parent
c2ab8988e6
commit
f69ae45127
2 changed files with 10 additions and 5 deletions
|
|
@ -89,9 +89,14 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
||||||
|
|
||||||
self._config = config
|
self._config = config
|
||||||
|
|
||||||
async def query_available_models(self) -> list[str]:
|
async def check_model_availability(self, model: str) -> bool:
|
||||||
"""Query available models from the NVIDIA API."""
|
"""Check if a specific model is available from the NVIDIA API."""
|
||||||
return [model.id async for model in self._get_client().models.list()]
|
try:
|
||||||
|
await self._get_client().models.retrieve(model)
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
# If we can't retrieve the model, it's not available
|
||||||
|
return False
|
||||||
|
|
||||||
@lru_cache # noqa: B019
|
@lru_cache # noqa: B019
|
||||||
def _get_client(self, provider_model_id: str | None = None) -> AsyncOpenAI:
|
def _get_client(self, provider_model_id: str | None = None) -> AsyncOpenAI:
|
||||||
|
|
|
||||||
|
|
@ -344,8 +344,8 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# simulate a NIM where default/job-1234 is an available model
|
# simulate a NIM where default/job-1234 is an available model
|
||||||
with patch.object(self.inference_adapter, "query_available_models", new_callable=AsyncMock) as mock_query:
|
with patch.object(self.inference_adapter, "check_model_availability", new_callable=AsyncMock) as mock_check:
|
||||||
mock_query.return_value = [model_id]
|
mock_check.return_value = True
|
||||||
result = self.run_async(self.inference_adapter.register_model(model))
|
result = self.run_async(self.inference_adapter.register_model(model))
|
||||||
assert result == model
|
assert result == model
|
||||||
assert len(self.inference_adapter.alias_to_provider_id_map) > 1
|
assert len(self.inference_adapter.alias_to_provider_id_map) > 1
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue