From ace82836c14b4bd5380a14149047013332672bc3 Mon Sep 17 00:00:00 2001 From: Rashmi Pawar <168514198+raspawar@users.noreply.github.com> Date: Fri, 25 Apr 2025 05:43:33 +0530 Subject: [PATCH] feat: NVIDIA allow non-llama model registration (#1859) # What does this PR do? Adds custom model registration functionality to NVIDIAInferenceAdapter which lets inference happen on: - post-training model - non-llama models in API Catalogue (behind https://integrate.api.nvidia.com and endpoints compatible with AsyncOpenAI) ## Example Usage: ```python from llama_stack.apis.models import Model, ModelType from llama_stack.distribution.library_client import LlamaStackAsLibraryClient client = LlamaStackAsLibraryClient("nvidia") _ = client.initialize() client.models.register( model_id=model_name, model_type=ModelType.llm, provider_id="nvidia" ) response = client.inference.chat_completion( model_id=model_name, messages=[{"role":"system","content":"You are a helpful assistant."},{"role":"user","content":"Write a limerick about the wonders of GPU computing."}], ) ``` ## Test Plan ```bash pytest tests/unit/providers/nvidia/test_supervised_fine_tuning.py ========================================================== test session starts =========================================================== platform linux -- Python 3.10.0, pytest-8.3.5, pluggy-1.5.0 rootdir: /home/ubuntu/llama-stack configfile: pyproject.toml plugins: anyio-4.9.0 collected 6 items tests/unit/providers/nvidia/test_supervised_fine_tuning.py ...... [100%] ============================================================ warnings summary ============================================================ ../miniconda/envs/nvidia-1/lib/python3.10/site-packages/pydantic/fields.py:1076 /home/ubuntu/miniconda/envs/nvidia-1/lib/python3.10/site-packages/pydantic/fields.py:1076: PydanticDeprecatedSince20: Using extra keyword arguments on `Field` is deprecated and will be removed. Use `json_schema_extra` instead. 
(Extra keys: 'contentEncoding'). Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/ warn( -- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html ====================================================== 6 passed, 1 warning in 1.51s ====================================================== ``` [//]: # (## Documentation) Updated Readme.md cc: @dglogo, @sumitb, @mattf --- .../self_hosted_distro/nvidia.md | 3 +- .../remote/inference/nvidia/config.py | 5 ++ .../remote/inference/nvidia/nvidia.py | 52 +++++++++++++++++-- .../remote/post_training/nvidia/README.md | 16 +++++- llama_stack/templates/nvidia/nvidia.py | 12 ++--- .../templates/nvidia/run-with-safety.yaml | 1 + llama_stack/templates/nvidia/run.yaml | 1 + .../nvidia/test_supervised_fine_tuning.py | 41 +++++++++++++++ 8 files changed, 116 insertions(+), 15 deletions(-) diff --git a/docs/source/distributions/self_hosted_distro/nvidia.md b/docs/source/distributions/self_hosted_distro/nvidia.md index 147c5b2ae..4407de779 100644 --- a/docs/source/distributions/self_hosted_distro/nvidia.md +++ b/docs/source/distributions/self_hosted_distro/nvidia.md @@ -22,9 +22,8 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov The following environment variables can be configured: - `NVIDIA_API_KEY`: NVIDIA API Key (default: ``) -- `NVIDIA_USER_ID`: NVIDIA User ID (default: `llama-stack-user`) +- `NVIDIA_APPEND_API_VERSION`: Whether to append the API version to the base_url (default: `True`) - `NVIDIA_DATASET_NAMESPACE`: NVIDIA Dataset Namespace (default: `default`) -- `NVIDIA_ACCESS_POLICIES`: NVIDIA Access Policies (default: `{}`) - `NVIDIA_PROJECT_ID`: NVIDIA Project ID (default: `test-project`) - `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`) - `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`) diff --git 
a/llama_stack/providers/remote/inference/nvidia/config.py b/llama_stack/providers/remote/inference/nvidia/config.py index abd34b498..8f80408d4 100644 --- a/llama_stack/providers/remote/inference/nvidia/config.py +++ b/llama_stack/providers/remote/inference/nvidia/config.py @@ -47,10 +47,15 @@ class NVIDIAConfig(BaseModel): default=60, description="Timeout for the HTTP requests", ) + append_api_version: bool = Field( + default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false", + description="When set to false, the API version will not be appended to the base_url. By default, it is true.", + ) @classmethod def sample_run_config(cls, **kwargs) -> Dict[str, Any]: return { "url": "${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}", "api_key": "${env.NVIDIA_API_KEY:}", + "append_api_version": "${env.NVIDIA_APPEND_API_VERSION:True}", } diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index c91b4d768..4a62ad6cb 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -33,7 +33,6 @@ from llama_stack.apis.inference import ( TextTruncation, ToolChoice, ToolConfig, - ToolDefinition, ) from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, @@ -42,7 +41,11 @@ from llama_stack.apis.inference.inference import ( OpenAIMessageParam, OpenAIResponseFormatParam, ) -from llama_stack.models.llama.datatypes import ToolPromptFormat +from llama_stack.apis.models import Model, ModelType +from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat +from llama_stack.providers.utils.inference import ( + ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR, +) from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, ) @@ -120,10 +123,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): "meta/llama-3.2-90b-vision-instruct": 
"https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct", } - base_url = f"{self._config.url}/v1" + base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url + if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls: base_url = special_model_urls[provider_model_id] - return _get_client_for_base_url(base_url) async def _get_provider_model_id(self, model_id: str) -> str: @@ -387,3 +390,44 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): return await self._get_client(provider_model_id).chat.completions.create(**params) except APIConnectionError as e: raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e + + async def register_model(self, model: Model) -> Model: + """ + Allow non-llama model registration. + + Non-llama model registration: API Catalogue models, post-training models, etc. + client = LlamaStackAsLibraryClient("nvidia") + client.models.register( + model_id="mistralai/mixtral-8x7b-instruct-v0.1", + model_type=ModelType.llm, + provider_id="nvidia", + provider_model_id="mistralai/mixtral-8x7b-instruct-v0.1" + ) + + NOTE: Only supports models endpoints compatible with AsyncOpenAI base_url format. 
+ """ + if model.model_type == ModelType.embedding: + # embedding models are always registered by their provider model id and does not need to be mapped to a llama model + provider_resource_id = model.provider_resource_id + else: + provider_resource_id = self.get_provider_model_id(model.provider_resource_id) + + if provider_resource_id: + model.provider_resource_id = provider_resource_id + else: + llama_model = model.metadata.get("llama_model") + existing_llama_model = self.get_llama_model(model.provider_resource_id) + if existing_llama_model: + if existing_llama_model != llama_model: + raise ValueError( + f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'" + ) + else: + # not llama model + if llama_model in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR: + self.provider_id_to_llama_model_map[model.provider_resource_id] = ( + ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model] + ) + else: + self.alias_to_provider_id_map[model.provider_model_id] = model.provider_model_id + return model diff --git a/llama_stack/providers/remote/post_training/nvidia/README.md b/llama_stack/providers/remote/post_training/nvidia/README.md index 230587d66..3ef538d29 100644 --- a/llama_stack/providers/remote/post_training/nvidia/README.md +++ b/llama_stack/providers/remote/post_training/nvidia/README.md @@ -36,7 +36,6 @@ import os os.environ["NVIDIA_API_KEY"] = "your-api-key" os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" -os.environ["NVIDIA_USER_ID"] = "llama-stack-user" os.environ["NVIDIA_DATASET_NAMESPACE"] = "default" os.environ["NVIDIA_PROJECT_ID"] = "test-project" os.environ["NVIDIA_OUTPUT_MODEL_DIR"] = "test-example-model@v1" @@ -125,6 +124,21 @@ client.post_training.job.cancel(job_uuid="your-job-id") ### Inference with the fine-tuned model +#### 1. 
Register the model + +```python +from llama_stack.apis.models import Model, ModelType + +client.models.register( + model_id="test-example-model@v1", + provider_id="nvidia", + provider_model_id="test-example-model@v1", + model_type=ModelType.llm, +) +``` + +#### 2. Inference with the fine-tuned model + ```python response = client.inference.completion( content="Complete the sentence using one word: Roses are red, violets are ", diff --git a/llama_stack/templates/nvidia/nvidia.py b/llama_stack/templates/nvidia/nvidia.py index 32ddf78e3..463c13879 100644 --- a/llama_stack/templates/nvidia/nvidia.py +++ b/llama_stack/templates/nvidia/nvidia.py @@ -98,19 +98,15 @@ def get_distribution_template() -> DistributionTemplate: "", "NVIDIA API Key", ), - ## Nemo Customizer related variables - "NVIDIA_USER_ID": ( - "llama-stack-user", - "NVIDIA User ID", + "NVIDIA_APPEND_API_VERSION": ( + "True", + "Whether to append the API version to the base_url", ), + ## Nemo Customizer related variables "NVIDIA_DATASET_NAMESPACE": ( "default", "NVIDIA Dataset Namespace", ), - "NVIDIA_ACCESS_POLICIES": ( - "{}", - "NVIDIA Access Policies", - ), "NVIDIA_PROJECT_ID": ( "test-project", "NVIDIA Project ID", diff --git a/llama_stack/templates/nvidia/run-with-safety.yaml b/llama_stack/templates/nvidia/run-with-safety.yaml index 8483fb9bf..a3e5fefa4 100644 --- a/llama_stack/templates/nvidia/run-with-safety.yaml +++ b/llama_stack/templates/nvidia/run-with-safety.yaml @@ -18,6 +18,7 @@ providers: config: url: ${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com} api_key: ${env.NVIDIA_API_KEY:} + append_api_version: ${env.NVIDIA_APPEND_API_VERSION:True} - provider_id: nvidia provider_type: remote::nvidia config: diff --git a/llama_stack/templates/nvidia/run.yaml b/llama_stack/templates/nvidia/run.yaml index d7e2753ba..271ce1a16 100644 --- a/llama_stack/templates/nvidia/run.yaml +++ b/llama_stack/templates/nvidia/run.yaml @@ -18,6 +18,7 @@ providers: config: url: 
${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com} api_key: ${env.NVIDIA_API_KEY:} + append_api_version: ${env.NVIDIA_APPEND_API_VERSION:True} vector_io: - provider_id: faiss provider_type: inline::faiss diff --git a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py index 43e0ac11c..09f67e4e6 100644 --- a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py +++ b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py @@ -17,6 +17,8 @@ from llama_stack_client.types.post_training_supervised_fine_tune_params import ( TrainingConfigOptimizerConfig, ) +from llama_stack.apis.models import Model, ModelType +from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAConfig, NVIDIAInferenceAdapter from llama_stack.providers.remote.post_training.nvidia.post_training import ( ListNvidiaPostTrainingJobs, NvidiaPostTrainingAdapter, @@ -40,8 +42,22 @@ class TestNvidiaPostTraining(unittest.TestCase): ) self.mock_make_request = self.make_request_patcher.start() + # Mock the inference client + inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None) + self.inference_adapter = NVIDIAInferenceAdapter(inference_config) + + self.mock_client = unittest.mock.MagicMock() + self.mock_client.chat.completions.create = unittest.mock.AsyncMock() + self.inference_mock_make_request = self.mock_client.chat.completions.create + self.inference_make_request_patcher = patch( + "llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client", + return_value=self.mock_client, + ) + self.inference_make_request_patcher.start() + def tearDown(self): self.make_request_patcher.stop() + self.inference_make_request_patcher.stop() @pytest.fixture(autouse=True) def inject_fixtures(self, run_async): @@ -303,6 +319,31 @@ class TestNvidiaPostTraining(unittest.TestCase): expected_params={"job_id": job_id}, ) + def test_inference_register_model(self): + model_id = 
"default/job-1234" + model_type = ModelType.llm + model = Model( + identifier=model_id, + provider_id="nvidia", + provider_model_id=model_id, + provider_resource_id=model_id, + model_type=model_type, + ) + result = self.run_async(self.inference_adapter.register_model(model)) + assert result == model + assert len(self.inference_adapter.alias_to_provider_id_map) > 1 + assert self.inference_adapter.get_provider_model_id(model.provider_model_id) == model_id + + with patch.object(self.inference_adapter, "chat_completion") as mock_chat_completion: + self.run_async( + self.inference_adapter.chat_completion( + model_id=model_id, + messages=[{"role": "user", "content": "Hello, model"}], + ) + ) + + mock_chat_completion.assert_called() + if __name__ == "__main__": unittest.main()