diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index 7a2697327..e5ccec861 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -5,15 +5,15 @@
 # the root directory of this source tree.
 
-from openai import NOT_GIVEN
-
 from llama_stack.apis.inference import (
     OpenAIEmbeddingData,
     OpenAIEmbeddingsResponse,
     OpenAIEmbeddingUsage,
 )
 from llama_stack.log import get_logger
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
+from openai import NOT_GIVEN
 
 from . import NVIDIAConfig
 from .utils import _is_nvidia_hosted
 
@@ -21,9 +21,7 @@ from .utils import _is_nvidia_hosted
 logger = get_logger(name=__name__, category="inference::nvidia")
 
 
-class NVIDIAInferenceAdapter(OpenAIMixin):
-    config: NVIDIAConfig
-
+class NVIDIAInferenceAdapter(OpenAIMixin, ModelRegistryHelper):
     """
     NVIDIA Inference Adapter for Llama Stack.
 
@@ -37,12 +35,29 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
     - ModelRegistryHelper.check_model_availability() just returns False and shows a warning
     """
 
+    def __init__(self, config: NVIDIAConfig) -> None:
+        """Initialize the NVIDIA inference adapter with configuration."""
+        # Initialize ModelRegistryHelper with empty model entries since NVIDIA uses dynamic model discovery
+        ModelRegistryHelper.__init__(
+            self, model_entries=[], allowed_models=config.allowed_models
+        )
+        self.config = config
+
     # source: https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
     embedding_model_metadata: dict[str, dict[str, int]] = {
-        "nvidia/llama-3.2-nv-embedqa-1b-v2": {"embedding_dimension": 2048, "context_length": 8192},
+        "nvidia/llama-3.2-nv-embedqa-1b-v2": {
+            "embedding_dimension": 2048,
+            "context_length": 8192,
+        },
         "nvidia/nv-embedqa-e5-v5": {"embedding_dimension": 512, "context_length": 1024},
-        "nvidia/nv-embedqa-mistral-7b-v2": {"embedding_dimension": 512, "context_length": 4096},
-        "snowflake/arctic-embed-l": {"embedding_dimension": 512, "context_length": 1024},
+        "nvidia/nv-embedqa-mistral-7b-v2": {
+            "embedding_dimension": 512,
+            "context_length": 4096,
+        },
+        "snowflake/arctic-embed-l": {
+            "embedding_dimension": 512,
+            "context_length": 1024,
+        },
     }
 
     async def initialize(self) -> None:
@@ -60,7 +75,9 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
 
         :return: The NVIDIA API key
         """
-        return self.config.api_key.get_secret_value() if self.config.api_key else "NO KEY"
+        return (
+            self.config.api_key.get_secret_value() if self.config.api_key else "NO KEY"
+        )
 
     def get_base_url(self) -> str:
         """
@@ -68,7 +85,11 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
 
         :return: The NVIDIA API base URL
         """
-        return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
+        return (
+            f"{self.config.url}/v1"
+            if self.config.append_api_version
+            else self.config.url
+        )
 
     async def openai_embeddings(
         self,
@@ -95,7 +116,9 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
         response = await self.client.embeddings.create(
             model=await self._get_provider_model_id(model),
             input=input,
-            encoding_format=encoding_format if encoding_format is not None else NOT_GIVEN,
+            encoding_format=(
+                encoding_format if encoding_format is not None else NOT_GIVEN
+            ),
             dimensions=dimensions if dimensions is not None else NOT_GIVEN,
             user=user if user is not None else NOT_GIVEN,
             extra_body=extra_body,
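
A note on the inheritance change: the adapter now derives from both `OpenAIMixin` and `ModelRegistryHelper`, and its new `__init__` calls `ModelRegistryHelper.__init__` directly with an empty model table, since NVIDIA models are discovered dynamically at runtime rather than declared statically. The sketch below illustrates that explicit base-class initialization pattern with stub classes; `RegistryBase`, `ClientMixin`, and `Adapter` are illustrative stand-ins, not the real llama-stack implementations:

```python
from types import SimpleNamespace


class RegistryBase:
    """Stub standing in for ModelRegistryHelper: holds a static model table."""

    def __init__(self, model_entries: list, allowed_models: list[str] | None = None) -> None:
        self.model_entries = model_entries
        self.allowed_models = allowed_models


class ClientMixin:
    """Stub standing in for OpenAIMixin: contributes behavior, defines no __init__ state."""


class Adapter(ClientMixin, RegistryBase):
    def __init__(self, config) -> None:
        # Call the registry base explicitly rather than via super().__init__(),
        # mirroring the diff: the registry starts empty because models are
        # discovered from the endpoint at runtime, not from a static list.
        RegistryBase.__init__(self, model_entries=[], allowed_models=config.allowed_models)
        self.config = config


adapter = Adapter(SimpleNamespace(allowed_models=None))
assert adapter.model_entries == []
```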
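On the `NOT_GIVEN` handling that the final hunk reformats but keeps: the `openai` Python client uses `NOT_GIVEN` as a sentinel to distinguish "omit this field from the request" from an explicit `None`, which is why each optional embedding parameter is mapped with `x if x is not None else NOT_GIVEN`. A minimal sketch of that mapping, assuming only that `NOT_GIVEN` is importable from the installed `openai` package (the `omit_if_none` helper is hypothetical, not part of this PR):

```python
from openai import NOT_GIVEN  # sentinel exported by the openai Python client


def omit_if_none(value):
    # Hypothetical helper: NOT_GIVEN tells the client to leave the field out
    # of the request entirely so the server applies its own default, whereas
    # passing None through would be treated as an explicit value.
    return value if value is not None else NOT_GIVEN
```

With such a helper the call site could read `encoding_format=omit_if_none(encoding_format)`; the diff instead keeps the inline conditionals, only wrapping the longest one in parentheses for line length.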