From 3d2b374ee7823cf0222a403c10f5a89d80d3e686 Mon Sep 17 00:00:00 2001
From: raspawar
Date: Wed, 2 Apr 2025 09:51:55 +0000
Subject: [PATCH] add register_model method

---
 .../remote/inference/nvidia/nvidia.py         | 56 +++++++++++++++++--
 .../nvidia/test_supervised_fine_tuning.py     | 53 ++++++++++++++++--
 2 files changed, 99 insertions(+), 10 deletions(-)

diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index 5caf19fda..91028203f 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -33,11 +33,15 @@ from llama_stack.apis.inference import (
     ToolChoice,
     ToolConfig,
 )
+from llama_stack.apis.models import Model, ModelType
 from llama_stack.models.llama.datatypes import (
     SamplingParams,
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.providers.utils.inference import (
+    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
+)
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
@@ -114,10 +118,13 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
             "meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
         }
 
-        base_url = f"{self._config.url}/v1"
-        if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
-            base_url = special_model_urls[provider_model_id]
-
+        # append /v1 only for NVIDIA-hosted endpoints; a self-hosted URL is used as configured
+        base_url = self._config.url
+        if _is_nvidia_hosted(self._config):
+            if provider_model_id in special_model_urls:
+                base_url = special_model_urls[provider_model_id]
+            else:
+                base_url = f"{self._config.url}/v1"
         return _get_client_for_base_url(base_url)
 
     async def completion(
@@ -265,3 +272,44 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         else:
             # we pass n=1 to get only one completion
             return convert_openai_chat_completion_choice(response.choices[0])
+
+    async def register_model(self, model: Model) -> Model:
+        """
+        Allow non-llama model registration.
+
+        Non-llama model registration: API Catalog models, post-training models, etc.
+            client = LlamaStackAsLibraryClient("nvidia")
+            client.models.register(
+                model_id="mistralai/mixtral-8x7b-instruct-v0.1",
+                model_type=ModelType.llm,
+                provider_id="nvidia",
+                provider_model_id="mistralai/mixtral-8x7b-instruct-v0.1"
+            )
+
+        NOTE: Only supports model endpoints compatible with the AsyncOpenAI base_url format.
+        """
+        if model.model_type == ModelType.embedding:
+            # embedding models are always registered by their provider model id and do not need to be mapped to a llama model
+            provider_resource_id = model.provider_resource_id
+        else:
+            provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
+
+        if provider_resource_id:
+            model.provider_resource_id = provider_resource_id
+        else:
+            llama_model = model.metadata.get("llama_model")
+            existing_llama_model = self.get_llama_model(model.provider_resource_id)
+            if existing_llama_model:
+                if existing_llama_model != llama_model:
+                    raise ValueError(
+                        f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
+                    )
+            else:
+                # not a llama model: map a known HuggingFace repo to its descriptor, otherwise register a passthrough alias
+                if llama_model in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
+                    self.provider_id_to_llama_model_map[model.provider_resource_id] = (
+                        ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
+                    )
+                else:
+                    self.alias_to_provider_id_map[model.provider_model_id] = model.provider_model_id
+        return model
diff --git a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py
index 7ce89144b..db83ab2de 100644
--- a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py
+++ b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py
@@ -10,13 +10,9 @@ import warnings
 from unittest.mock import patch
 
 import pytest
-from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig, QatFinetuningConfig
-from llama_stack_client.types.post_training_supervised_fine_tune_params import (
-    TrainingConfig,
-    TrainingConfigDataConfig,
-    TrainingConfigOptimizerConfig,
-)
 
+from llama_stack.apis.models import Model, ModelType
+from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAConfig, NVIDIAInferenceAdapter
 from llama_stack.providers.remote.post_training.nvidia.post_training import (
     ListNvidiaPostTrainingJobs,
     NvidiaPostTrainingAdapter,
@@ -24,6 +20,12 @@ from llama_stack.providers.remote.post_training.nvidia.post_training import (
     NvidiaPostTrainingJob,
     NvidiaPostTrainingJobStatusResponse,
 )
+from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig, QatFinetuningConfig
+from llama_stack_client.types.post_training_supervised_fine_tune_params import (
+    TrainingConfig,
+    TrainingConfigDataConfig,
+    TrainingConfigOptimizerConfig,
+)
 
 
 class TestNvidiaPostTraining(unittest.TestCase):
@@ -40,8 +42,22 @@
         )
         self.mock_make_request = self.make_request_patcher.start()
 
+        # Mock the inference client
+        inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None)
+        self.inference_adapter = NVIDIAInferenceAdapter(inference_config)
+
+        self.mock_client = unittest.mock.MagicMock()
+        self.mock_client.chat.completions.create = unittest.mock.AsyncMock()
+        self.inference_mock_make_request = self.mock_client.chat.completions.create
+        self.inference_make_request_patcher = patch(
+            "llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client",
+            return_value=self.mock_client,
+        )
+        self.inference_make_request_patcher.start()
+
     def tearDown(self):
         self.make_request_patcher.stop()
+        self.inference_make_request_patcher.stop()
 
     @pytest.fixture(autouse=True)
     def inject_fixtures(self, run_async):
@@ -290,6 +306,31 @@
             expected_params={"job_id": job_id},
         )
 
+    def test_inference_register_model(self):
+        model_id = "default/job-1234"
"default/job-1234" + model_type = ModelType.llm + model = Model( + identifier=model_id, + provider_id="nvidia", + provider_model_id=model_id, + provider_resource_id=model_id, + model_type=model_type, + ) + result = self.run_async(self.inference_adapter.register_model(model)) + assert result == model + assert len(self.inference_adapter.alias_to_provider_id_map) > 1 + assert self.inference_adapter.get_provider_model_id(model.provider_model_id) == model_id + + with patch.object(self.inference_adapter, "chat_completion") as mock_chat_completion: + self.run_async( + self.inference_adapter.chat_completion( + model_id=model_id, + messages=[{"role": "user", "content": "Hello, model"}], + ) + ) + + mock_chat_completion.assert_called() + if __name__ == "__main__": unittest.main()