NIM not working yet

Kai Wu 2025-07-29 14:26:58 -07:00
parent 7065b0fb4d
commit 31a15332c4
3 changed files with 218 additions and 9 deletions


@@ -8,7 +8,7 @@ import logging
import warnings
from collections.abc import AsyncIterator
-from openai import APIConnectionError, BadRequestError
+from openai import APIConnectionError, BadRequestError, AsyncOpenAI
from llama_stack.apis.common.content_types import (
    InterleavedContent,
@@ -27,13 +27,20 @@ from llama_stack.apis.inference import (
    Inference,
    LogProbConfig,
    Message,
    ModelStore,
    ResponseFormat,
    SamplingParams,
    TextTruncation,
    ToolChoice,
    ToolConfig,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
from llama_stack.providers.datatypes import (
    HealthResponse,
    HealthStatus,
    ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
)
@@ -57,7 +64,7 @@ from .utils import _is_nvidia_hosted

logger = logging.getLogger(__name__)

-class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
+class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper, ModelsProtocolPrivate):
    """
    NVIDIA Inference Adapter for Llama Stack.
@@ -71,6 +78,10 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
    - ModelRegistryHelper.check_model_availability() just returns False and shows a warning
    """

    # automatically set by the resolver when instantiating the provider
    __provider_id__: str
    model_store: ModelStore | None = None

    def __init__(self, config: NVIDIAConfig) -> None:
        # TODO(mf): filter by available models
        ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
@@ -93,6 +104,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
        # )

        self._config = config
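        # the AsyncOpenAI client is created lazily on first use; see _lazy_initialize_client()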
        self._client = None

    def get_api_key(self) -> str:
        """
@@ -110,6 +122,149 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
        """
        return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url

    @property
    def client(self):
        """
        Get the OpenAI client.

        :return: The OpenAI client
        """
        self._lazy_initialize_client()
        return self._client

    def _lazy_initialize_client(self):
        """
        Initialize the OpenAI client if it hasn't been initialized yet.
        """
        if self._client is not None:
            return

        logger.info(f"Initializing NVIDIA client with base_url={self.get_base_url()}")
        self._client = AsyncOpenAI(
            base_url=self.get_base_url(),
            api_key=self.get_api_key(),
        )

    async def initialize(self) -> None:
        """
        Initialize the NVIDIA adapter.
        """
        if not self._config.url:
            raise ValueError(
                "You must provide a URL in run.yaml (or via the NVIDIA_BASE_URL environment variable) to use NVIDIA NIM."
            )

    async def should_refresh_models(self) -> bool:
        """
        Determine if models should be refreshed.

        :return: True if models should be refreshed, False otherwise
        """
        # Always refresh models to ensure we have the latest available models
        return True

    async def list_models(self) -> list[Model] | None:
        """
        List all models available from the NVIDIA API.

        :return: A list of available models, or None if the listing fails
        """
        self._lazy_initialize_client()
        models = []
        try:
            async for m in self.client.models.list():
                # Determine the model type from the model ID.
                # This is a simple heuristic and might need refinement.
                model_type = ModelType.llm
                if "embed" in m.id.lower():
                    model_type = ModelType.embedding
                models.append(
                    Model(
                        identifier=m.id,
                        provider_resource_id=m.id,
                        provider_id=self.__provider_id__,
                        metadata={},
                        model_type=model_type,
                    )
                )
            return models
        except Exception as e:
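            # returning None (rather than an empty list) tells the caller the listing itself failed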
logger.warning(f"Failed to list models from NVIDIA API: {e}")
return None

    async def register_model(self, model: Model) -> Model:
        """
        Register a model with the NVIDIA adapter.

        :param model: The model to register
        :return: The registered model
        """
        self._lazy_initialize_client()
        try:
            # First try to register using the static model entries
            model = await ModelRegistryHelper.register_model(self, model)
        except ValueError:
            pass  # Ignore statically unknown model, will check live listing

        try:
            # Check if the model is available on the NVIDIA server
            available_models = [m.id async for m in self.client.models.list()]
            if model.provider_resource_id not in available_models:
                raise ValueError(
                    f"Model {model.provider_resource_id} is not being served by NVIDIA NIM. "
                    f"Available models: {', '.join(available_models)}"
                )
        except APIConnectionError as e:
            raise ValueError(
                f"Failed to connect to NVIDIA NIM at {self._config.url}. Please check if NVIDIA NIM is running and accessible at that URL."
            ) from e

        return model

    async def unregister_model(self, model_id: str) -> None:
        """
        Unregister a model from the NVIDIA adapter.

        :param model_id: The ID of the model to unregister
        """
        pass

    async def health(self) -> HealthResponse:
        """
        Performs a health check by verifying connectivity to the remote NVIDIA NIM server.
        This method is used by the Provider API to verify
        that the service is running correctly.

        :return: A HealthResponse object indicating the health status
        """
        try:
            client = (
                self._client
                if self._client is not None
                else AsyncOpenAI(base_url=self.get_base_url(), api_key=self.get_api_key())
            )
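            # listing models exercises both the connection and the credentials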
            _ = [m async for m in client.models.list()]
            return HealthResponse(status=HealthStatus.OK)
        except Exception as e:
            return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")

    async def _get_model(self, model_id: str) -> Model:
        """
        Get a model by ID.

        :param model_id: The ID of the model to get
        :return: The model
        """
        if not self.model_store:
            raise ValueError("Model store not set")
        return await self.model_store.get_model(model_id)

    async def shutdown(self) -> None:
        """
        Shutdown the NVIDIA adapter.
        """
        pass

    async def completion(
        self,
        model_id: str,
@@ -128,6 +283,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
        # removing this health check as NeMo customizer endpoint health check is returning 404
        # await check_health(self._config)  # this raises errors
        self._lazy_initialize_client()
        provider_model_id = await self._get_provider_model_id(model_id)

        request = convert_completion_request(
            request=CompletionRequest(
@ -170,6 +326,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
#
# we can ignore str and always pass list[str] to OpenAI
#
self._lazy_initialize_client()
flat_contents = [content.text if isinstance(content, TextContentItem) else content for content in contents]
input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
provider_model_id = await self._get_provider_model_id(model_id)
@ -230,6 +387,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
# await check_health(self._config) # this raises errors
self._lazy_initialize_client()
provider_model_id = await self._get_provider_model_id(model_id)
request = await convert_chat_completion_request(
request=ChatCompletionRequest(