llama-stack-mirror/llama_stack/providers/remote/inference/nvidia/config.py
Rashmi Pawar ace82836c1
feat: NVIDIA allow non-llama model registration (#1859)
# What does this PR do?
Adds custom model registration functionality to NVIDIAInferenceAdapter,
which lets inference run on:
- post-training models
- non-llama models in the API Catalogue (behind
https://integrate.api.nvidia.com and endpoints compatible with
AsyncOpenAI)

## Example Usage:
```python
from llama_stack.apis.models import Model, ModelType
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
client = LlamaStackAsLibraryClient("nvidia")
_ = client.initialize()

# Placeholder: set this to the ID of the model you want to register.
model_name = "<model-id>"

client.models.register(
    model_id=model_name,
    model_type=ModelType.llm,
    provider_id="nvidia",
)

response = client.inference.chat_completion(
    model_id=model_name,
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Write a limerick about the wonders of GPU computing."},
    ],
)
```

## Test Plan
```bash
pytest tests/unit/providers/nvidia/test_supervised_fine_tuning.py 
========================================================== test session starts ===========================================================
platform linux -- Python 3.10.0, pytest-8.3.5, pluggy-1.5.0
rootdir: /home/ubuntu/llama-stack
configfile: pyproject.toml
plugins: anyio-4.9.0
collected 6 items                                                                                                                        

tests/unit/providers/nvidia/test_supervised_fine_tuning.py ......                                                                  [100%]

============================================================ warnings summary ============================================================
../miniconda/envs/nvidia-1/lib/python3.10/site-packages/pydantic/fields.py:1076
  /home/ubuntu/miniconda/envs/nvidia-1/lib/python3.10/site-packages/pydantic/fields.py:1076: PydanticDeprecatedSince20: Using extra keyword arguments on `Field` is deprecated and will be removed. Use `json_schema_extra` instead. (Extra keys: 'contentEncoding'). Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/
    warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
====================================================== 6 passed, 1 warning in 1.51s ======================================================
```

[//]: # (## Documentation)
Updated Readme.md

cc: @dglogo, @sumitb, @mattf
2025-04-24 17:13:33 -07:00


# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field, SecretStr

from llama_stack.schema_utils import json_schema_type


@json_schema_type
class NVIDIAConfig(BaseModel):
    """
    Configuration for the NVIDIA NIM inference endpoint.

    Attributes:
        url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000
        api_key (str): The access key for the hosted NIM endpoints

    There are two ways to access NVIDIA NIMs -
     0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com
     1. Self-hosted: You can run NVIDIA NIMs on your own infrastructure

    By default the configuration is set to use the hosted APIs. This requires
    an API key which can be obtained from https://ngc.nvidia.com/.

    By default the configuration will attempt to read the NVIDIA_API_KEY environment
    variable to set the api_key. Please do not put your API key in code.

    If you are using a self-hosted NVIDIA NIM, you can set the url to the
    URL of your running NVIDIA NIM and do not need to set the api_key.
    """

    url: str = Field(
        default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"),
        description="A base url for accessing the NVIDIA NIM",
    )
    api_key: Optional[SecretStr] = Field(
        default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
        description="The NVIDIA API key, only needed if using the hosted service",
    )
    timeout: int = Field(
        default=60,
        description="Timeout for the HTTP requests",
    )
    append_api_version: bool = Field(
        default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
        description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
    )

    @classmethod
    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
        return {
            "url": "${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}",
            "api_key": "${env.NVIDIA_API_KEY:}",
            "append_api_version": "${env.NVIDIA_APPEND_API_VERSION:True}",
        }
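
The environment-driven defaults in `NVIDIAConfig` can be illustrated with a small standalone sketch. It uses only the standard library and is not part of the provider: `resolve_settings` and `effective_base_url` are hypothetical helper names, and the `"v1"` version segment is an assumption about what "the API version" refers to. The parsing expressions mirror the `default_factory` lambdas, but read from a plain mapping so the behavior is easy to inspect.

```python
from typing import Any, Dict, Mapping


def resolve_settings(env: Mapping[str, str]) -> Dict[str, Any]:
    """Hypothetical helper: apply the same env-var defaults as NVIDIAConfig."""
    return {
        # Falls back to the hosted endpoint when NVIDIA_BASE_URL is unset.
        "url": env.get("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"),
        # None when NVIDIA_API_KEY is unset (acceptable for self-hosted NIMs).
        "api_key": env.get("NVIDIA_API_KEY"),
        # Only the literal string "false" (any case) turns the flag off.
        "append_api_version": env.get("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
    }


def effective_base_url(url: str, append_api_version: bool, version: str = "v1") -> str:
    """Hypothetical helper: append the version segment unless disabled.

    The "v1" default is an assumption made for illustration.
    """
    base = url.rstrip("/")
    return f"{base}/{version}" if append_api_version else base
```

Because the flag is computed as `!= "false"`, values such as `"0"` or `"no"` still leave `append_api_version` enabled; only `false`, `False`, `FALSE`, and similar spellings disable it.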