mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
# What does this PR do? Adds custom model registration functionality to NVIDIAInferenceAdapter which lets the inference happen on: - post-training model - non-llama models in API Catalogue (behind https://integrate.api.nvidia.com and endpoints compatible with AsyncOpenAI) ## Example Usage: ```python from llama_stack.apis.models import Model, ModelType from llama_stack.distribution.library_client import LlamaStackAsLibraryClient client = LlamaStackAsLibraryClient("nvidia") _ = client.initialize() client.models.register( model_id=model_name, model_type=ModelType.llm, provider_id="nvidia" ) response = client.inference.chat_completion( model_id=model_name, messages=[{"role":"system","content":"You are a helpful assistant."},{"role":"user","content":"Write a limerick about the wonders of GPU computing."}], ) ``` ## Test Plan ```bash pytest tests/unit/providers/nvidia/test_supervised_fine_tuning.py ========================================================== test session starts =========================================================== platform linux -- Python 3.10.0, pytest-8.3.5, pluggy-1.5.0 rootdir: /home/ubuntu/llama-stack configfile: pyproject.toml plugins: anyio-4.9.0 collected 6 items tests/unit/providers/nvidia/test_supervised_fine_tuning.py ...... [100%] ============================================================ warnings summary ============================================================ ../miniconda/envs/nvidia-1/lib/python3.10/site-packages/pydantic/fields.py:1076 /home/ubuntu/miniconda/envs/nvidia-1/lib/python3.10/site-packages/pydantic/fields.py:1076: PydanticDeprecatedSince20: Using extra keyword arguments on `Field` is deprecated and will be removed. Use `json_schema_extra` instead. (Extra keys: 'contentEncoding'). Deprecated in Pydantic V2.0 to be removed in V3.0. 
See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/ warn( -- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html ====================================================== 6 passed, 1 warning in 1.51s ====================================================== ``` [//]: # (## Documentation) Updated Readme.md cc: @dglogo, @sumitb, @mattf
61 lines
2.3 KiB
Python
61 lines
2.3 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import os
|
|
from typing import Any, Dict, Optional
|
|
|
|
from pydantic import BaseModel, Field, SecretStr
|
|
|
|
from llama_stack.schema_utils import json_schema_type
|
|
|
|
|
|
def _api_key_from_env() -> Optional[SecretStr]:
    """Return the NVIDIA_API_KEY environment variable wrapped in SecretStr, or None if unset."""
    # Pydantic does not validate values produced by default_factory unless
    # validate_default is enabled, so the raw string must be wrapped here for
    # the field to actually hold a SecretStr as its annotation promises.
    raw_key = os.getenv("NVIDIA_API_KEY")
    return SecretStr(raw_key) if raw_key is not None else None


@json_schema_type
class NVIDIAConfig(BaseModel):
    """
    Configuration for the NVIDIA NIM inference endpoint.

    Attributes:
        url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000
        api_key (Optional[SecretStr]): The access key for the hosted NIM endpoints
        timeout (int): Timeout for the HTTP requests
        append_api_version (bool): Whether the API version is appended to the base url

    There are two ways to access NVIDIA NIMs -
     0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com
     1. Self-hosted: You can run NVIDIA NIMs on your own infrastructure

    By default the configuration is set to use the hosted APIs. This requires
    an API key which can be obtained from https://ngc.nvidia.com/.

    By default the configuration will attempt to read the NVIDIA_API_KEY environment
    variable to set the api_key. Please do not put your API key in code.

    If you are using a self-hosted NVIDIA NIM, you can set the url to the
    URL of your running NVIDIA NIM and do not need to set the api_key.
    """

    url: str = Field(
        default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"),
        description="A base url for accessing the NVIDIA NIM",
    )
    api_key: Optional[SecretStr] = Field(
        default_factory=_api_key_from_env,
        description="The NVIDIA API key, only needed of using the hosted service",
    )
    timeout: int = Field(
        default=60,
        description="Timeout for the HTTP requests",
    )
    append_api_version: bool = Field(
        # Any value other than a case-insensitive "false" keeps the default of True.
        default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
        description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
    )

    @classmethod
    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
        """Return a sample run configuration whose values are environment-variable placeholders.

        Keyword arguments are accepted but not used by this provider.
        """
        return {
            "url": "${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}",
            "api_key": "${env.NVIDIA_API_KEY:}",
            "append_api_version": "${env.NVIDIA_APPEND_API_VERSION:True}",
        }
|