NIM not working yet
Some checks failed:
Installer CI / smoke-test-on-dev (push): failing after 5s
Installer CI / lint (push): failing after 9s

Kai Wu 2025-07-29 14:26:58 -07:00
parent 7065b0fb4d
commit 31a15332c4
3 changed files with 218 additions and 9 deletions


@@ -5,21 +5,40 @@
# the root directory of this source tree.

import os

import streamlit as st

from llama_stack_client import LlamaStackClient


class LlamaStackApi:
    def __init__(self):
        # Initialize provider data from environment variables
        self.provider_data = {
            "fireworks_api_key": os.environ.get("FIREWORKS_API_KEY", ""),
            "together_api_key": os.environ.get("TOGETHER_API_KEY", ""),
            "sambanova_api_key": os.environ.get("SAMBANOVA_API_KEY", ""),
            "openai_api_key": os.environ.get("OPENAI_API_KEY", ""),
            "tavily_search_api_key": os.environ.get("TAVILY_SEARCH_API_KEY", ""),
        }

        # Check if we have any API keys stored in session state
        if st.session_state.get("tavily_search_api_key"):
            self.provider_data["tavily_search_api_key"] = st.session_state.get("tavily_search_api_key")

        # Initialize the client
        self.client = LlamaStackClient(
            base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"),
            provider_data=self.provider_data,
        )

    def update_provider_data(self, key, value):
        """Update a specific provider data key and reinitialize the client."""
        self.provider_data[key] = value
        # Reinitialize the client with updated provider data
        self.client = LlamaStackClient(
            base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"),
            provider_data=self.provider_data,
        )

    def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: dict | None):

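A minimal usage sketch (hypothetical driver code, not part of the diff): the UI pages import a shared llama_stack_api singleton, so keys read from the environment work immediately, and a key entered at runtime can be swapped in without restarting the app.

# Assumes the module exposes a singleton named `llama_stack_api`, which is
# what the providers page below imports.
from llama_stack.distribution.ui.modules.api import llama_stack_api

# Works out of the box with keys read from environment variables
models = llama_stack_api.client.models.list()

# Swapping in a key entered in the UI rebuilds the underlying
# LlamaStackClient with the updated provider_data
llama_stack_api.update_provider_data("tavily_search_api_key", "tvly-example")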

@@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os

import streamlit as st

from llama_stack.distribution.ui.modules.api import llama_stack_api

@@ -11,6 +12,37 @@ from llama_stack.distribution.ui.modules.api import llama_stack_api

def providers():
    st.header("🔍 API Providers")

    # API Key Management Section
    st.subheader("API Key Management")

    # Create a form for API key input
    with st.form("api_keys_form"):
        # Get the current value from session state or environment variable
        tavily_key = st.session_state.get("tavily_search_api_key", os.environ.get("TAVILY_SEARCH_API_KEY", ""))

        # Input field for Tavily Search API key
        tavily_search_api_key = st.text_input(
            "Tavily Search API Key",
            value=tavily_key,
            type="password",
            help="Enter your Tavily Search API key. This will be used for search operations.",
        )

        # Submit button
        submit_button = st.form_submit_button("Save API Keys")

        if submit_button:
            # Store the API key in session state
            st.session_state["tavily_search_api_key"] = tavily_search_api_key

            # Update the client with the new API key
            llama_stack_api.update_provider_data("tavily_search_api_key", tavily_search_api_key)

            st.success("API keys saved successfully!")

    # Display API Providers
    st.subheader("Available API Providers")
    apis_providers_lst = llama_stack_api.client.providers.list()
    api_to_providers = {}
    for api_provider in apis_providers_lst:

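Worth noting about how these keys travel: provider_data is not persisted server-side; the client forwards it with every request. Assuming llama-stack's per-request provider-data convention, where the dict is JSON-encoded into an X-LlamaStack-Provider-Data header, saving a key in this form is roughly equivalent to:

import json

import httpx  # any HTTP client works; httpx here is an assumption

# Hypothetical raw request; the header name follows llama-stack's
# per-request provider-data convention and should be verified against
# the client version in use.
resp = httpx.get(
    "http://localhost:8321/v1/providers",
    headers={
        "X-LlamaStack-Provider-Data": json.dumps(
            {"tavily_search_api_key": "tvly-example"}
        )
    },
)
print(resp.status_code)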

@@ -8,7 +8,7 @@ import logging
import warnings
from collections.abc import AsyncIterator

from openai import APIConnectionError, BadRequestError, AsyncOpenAI

from llama_stack.apis.common.content_types import (
    InterleavedContent,
@@ -27,13 +27,20 @@ from llama_stack.apis.inference import (
    Inference,
    LogProbConfig,
    Message,
    ModelStore,
    ResponseFormat,
    SamplingParams,
    TextTruncation,
    ToolChoice,
    ToolConfig,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
from llama_stack.providers.datatypes import (
    HealthResponse,
    HealthStatus,
    ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
)
@@ -57,7 +64,7 @@ from .utils import _is_nvidia_hosted

logger = logging.getLogger(__name__)


class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper, ModelsProtocolPrivate):
    """
    NVIDIA Inference Adapter for Llama Stack.

@@ -71,6 +78,10 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
    - ModelRegistryHelper.check_model_availability() just returns False and shows a warning
    """

    # automatically set by the resolver when instantiating the provider
    __provider_id__: str
    model_store: ModelStore | None = None

    def __init__(self, config: NVIDIAConfig) -> None:
        # TODO(mf): filter by available models
        ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
@@ -93,6 +104,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
        # )
        self._config = config
        self._client = None

    def get_api_key(self) -> str:
        """
@@ -110,6 +122,149 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
        """
        return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
    @property
    def client(self):
        """
        Get the OpenAI client.

        :return: The OpenAI client
        """
        self._lazy_initialize_client()
        return self._client

    def _lazy_initialize_client(self):
        """
        Initialize the OpenAI client if it hasn't been initialized yet.
        """
        if self._client is not None:
            return

        logger.info(f"Initializing NVIDIA client with base_url={self.get_base_url()}")
        self._client = AsyncOpenAI(
            base_url=self.get_base_url(),
            api_key=self.get_api_key(),
        )

    async def initialize(self) -> None:
        """
        Initialize the NVIDIA adapter.
        """
        if not self._config.url:
            raise ValueError(
                "You must provide a URL in run.yaml (or via the NVIDIA_BASE_URL environment variable) to use NVIDIA NIM."
            )

    async def should_refresh_models(self) -> bool:
        """
        Determine if models should be refreshed.

        :return: True if models should be refreshed, False otherwise
        """
        # Always refresh models to ensure we have the latest available models
        return True

    async def list_models(self) -> list[Model] | None:
        """
        List all models available from the NVIDIA API.

        :return: A list of available models, or None if the listing fails
        """
        self._lazy_initialize_client()
        models = []
        try:
            async for m in self.client.models.list():
                # Determine model type based on model ID or capabilities.
                # This is a simple heuristic and might need refinement.
                model_type = ModelType.llm
                if "embed" in m.id.lower():
                    model_type = ModelType.embedding
                models.append(
                    Model(
                        identifier=m.id,
                        provider_resource_id=m.id,
                        provider_id=self.__provider_id__,
                        metadata={},
                        model_type=model_type,
                    )
                )
            return models
        except Exception as e:
            logger.warning(f"Failed to list models from NVIDIA API: {e}")
            return None
    async def register_model(self, model: Model) -> Model:
        """
        Register a model with the NVIDIA adapter.

        :param model: The model to register
        :return: The registered model
        """
        self._lazy_initialize_client()
        try:
            # First try to register using the static model entries
            model = await ModelRegistryHelper.register_model(self, model)
        except ValueError:
            pass  # Ignore statically unknown models; fall back to the live listing

        try:
            # Check if the model is actually being served by the NVIDIA NIM
            available_models = [m.id async for m in self.client.models.list()]
            if model.provider_resource_id not in available_models:
                raise ValueError(
                    f"Model {model.provider_resource_id} is not being served by NVIDIA NIM. "
                    f"Available models: {', '.join(available_models)}"
                )
        except APIConnectionError as e:
            raise ValueError(
                f"Failed to connect to NVIDIA NIM at {self._config.url}. Please check if NVIDIA NIM is running and accessible at that URL."
            ) from e
        return model

    async def unregister_model(self, model_id: str) -> None:
        """
        Unregister a model from the NVIDIA adapter.

        :param model_id: The ID of the model to unregister
        """
        pass

    async def health(self) -> HealthResponse:
        """
        Performs a health check by verifying connectivity to the remote NVIDIA NIM server.

        This method is used by the Provider API to verify that the service is
        running correctly.

        :return: A HealthResponse object indicating the health status
        """
        try:
            # Reuse the lazily created client if it exists; otherwise build a
            # short-lived one so the health check does not mutate adapter state
            client = (
                AsyncOpenAI(
                    base_url=self.get_base_url(),
                    api_key=self.get_api_key(),
                )
                if self._client is None
                else self._client
            )
            _ = [m async for m in client.models.list()]  # verify the server is reachable
            return HealthResponse(status=HealthStatus.OK)
        except Exception as e:
            return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
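    # Illustration only (not part of this commit): how a caller might consume
    # the health check above. `adapter` is a hypothetical, already-constructed
    # NVIDIAInferenceAdapter; HealthResponse is treated as dict-like here.
    #
    #     resp = await adapter.health()
    #     if resp["status"] != HealthStatus.OK:
    #         logger.error(f"NIM health check failed: {resp.get('message')}")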
    async def _get_model(self, model_id: str) -> Model:
        """
        Get a model by ID.

        :param model_id: The ID of the model to get
        :return: The model
        """
        if not self.model_store:
            raise ValueError("Model store not set")
        return await self.model_store.get_model(model_id)

    async def shutdown(self) -> None:
        """
        Shutdown the NVIDIA adapter.
        """
        pass

    async def completion(
        self,
        model_id: str,
@@ -128,6 +283,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
        # removing this health check as NeMo customizer endpoint health check is returning 404
        # await check_health(self._config)  # this raises errors

        self._lazy_initialize_client()
        provider_model_id = await self._get_provider_model_id(model_id)
        request = convert_completion_request(
            request=CompletionRequest(
@@ -170,6 +326,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
        #
        # we can ignore str and always pass list[str] to OpenAI
        #
        self._lazy_initialize_client()
        flat_contents = [content.text if isinstance(content, TextContentItem) else content for content in contents]
        input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
        provider_model_id = await self._get_provider_model_id(model_id)
@@ -230,6 +387,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
        # await check_health(self._config)  # this raises errors

        self._lazy_initialize_client()
        provider_model_id = await self._get_provider_model_id(model_id)
        request = await convert_chat_completion_request(
            request=ChatCompletionRequest(
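
Taken together, the lazy-initialization changes let the adapter be constructed with no network access; the AsyncOpenAI client is only built on first use. A minimal sketch of exercising that path against a locally running NIM (hypothetical driver code, not part of the commit; the import paths and the localhost:8000 NIM address are assumptions):

import asyncio

from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig  # assumed path
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter  # assumed path


async def main() -> None:
    # Constructing the adapter makes no network calls; the client is lazy
    adapter = NVIDIAInferenceAdapter(NVIDIAConfig(url="http://localhost:8000"))
    adapter.__provider_id__ = "nvidia"  # normally injected by the resolver
    await adapter.initialize()

    # First real network call: the connectivity probe
    print(await adapter.health())

    # Triggers _lazy_initialize_client() and lists the models the NIM serves
    models = await adapter.list_models()
    print([m.identifier for m in (models or [])])


asyncio.run(main())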