diff --git a/llama_stack/distribution/ui/modules/api.py b/llama_stack/distribution/ui/modules/api.py
index 9db87b280..aa1584b5e 100644
--- a/llama_stack/distribution/ui/modules/api.py
+++ b/llama_stack/distribution/ui/modules/api.py
@@ -5,21 +5,40 @@
 # the root directory of this source tree.
 
 import os
 
+import streamlit as st
 from llama_stack_client import LlamaStackClient
 
 
 class LlamaStackApi:
     def __init__(self):
+        # Initialize provider data from environment variables
+        self.provider_data = {
+            "fireworks_api_key": os.environ.get("FIREWORKS_API_KEY", ""),
+            "together_api_key": os.environ.get("TOGETHER_API_KEY", ""),
+            "sambanova_api_key": os.environ.get("SAMBANOVA_API_KEY", ""),
+            "openai_api_key": os.environ.get("OPENAI_API_KEY", ""),
+            "tavily_search_api_key": os.environ.get("TAVILY_SEARCH_API_KEY", ""),
+        }
+
+        # Prefer a Tavily Search API key stored in session state (set via the UI), if any
+        if st.session_state.get("tavily_search_api_key"):
+            self.provider_data["tavily_search_api_key"] = st.session_state.get("tavily_search_api_key")
+
+        # Initialize the client
         self.client = LlamaStackClient(
             base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"),
-            provider_data={
-                "fireworks_api_key": os.environ.get("FIREWORKS_API_KEY", ""),
-                "together_api_key": os.environ.get("TOGETHER_API_KEY", ""),
-                "sambanova_api_key": os.environ.get("SAMBANOVA_API_KEY", ""),
-                "openai_api_key": os.environ.get("OPENAI_API_KEY", ""),
-                "tavily_search_api_key": os.environ.get("TAVILY_SEARCH_API_KEY", ""),
-            },
+            provider_data=self.provider_data,
+        )
+
+    def update_provider_data(self, key, value):
+        """Update a single provider data key and rebuild the client so it takes effect."""
+        self.provider_data[key] = value
+
+        # Reinitialize the client with the updated provider data
+        self.client = LlamaStackClient(
+            base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:8321"),
+            provider_data=self.provider_data,
         )
 
     def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: dict | None):
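Reviewer note: `update_provider_data` rebuilds the `LlamaStackClient` rather than mutating the existing one, because the client appears to capture `provider_data` at construction time, so updating the dict alone would not affect requests already routed through the old client. A minimal usage sketch, assuming the module-level `llama_stack_api` singleton that the UI pages already import; the key value is a placeholder:

```python
# Sketch of how a UI page is expected to use the new hook.
from llama_stack.distribution.ui.modules.api import llama_stack_api

# Swap in a key entered at runtime; this rebuilds the underlying
# LlamaStackClient so the updated provider_data takes effect.
llama_stack_api.update_provider_data("tavily_search_api_key", "tvly-PLACEHOLDER")

# Requests made through the refreshed client now carry the new key.
print(llama_stack_api.client.providers.list())
```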
diff --git a/llama_stack/distribution/ui/page/distribution/providers.py b/llama_stack/distribution/ui/page/distribution/providers.py
index c660cb986..dae1fc90f 100644
--- a/llama_stack/distribution/ui/page/distribution/providers.py
+++ b/llama_stack/distribution/ui/page/distribution/providers.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import os
 import streamlit as st
 
 from llama_stack.distribution.ui.modules.api import llama_stack_api
@@ -11,6 +12,37 @@ from llama_stack.distribution.ui.modules.api import llama_stack_api
 
 def providers():
     st.header("🔍 API Providers")
+
+    # API Key Management Section
+    st.subheader("API Key Management")
+
+    # Create a form for API key input
+    with st.form("api_keys_form"):
+        # Get the current value from session state, falling back to the environment variable
+        tavily_key = st.session_state.get("tavily_search_api_key", os.environ.get("TAVILY_SEARCH_API_KEY", ""))
+
+        # Input field for the Tavily Search API key
+        tavily_search_api_key = st.text_input(
+            "Tavily Search API Key",
+            value=tavily_key,
+            type="password",
+            help="Enter your Tavily Search API key. This will be used for search operations.",
+        )
+
+        # Submit button
+        submit_button = st.form_submit_button("Save API Keys")
+
+        if submit_button:
+            # Store the API key in session state
+            st.session_state["tavily_search_api_key"] = tavily_search_api_key
+
+            # Update the client with the new API key
+            llama_stack_api.update_provider_data("tavily_search_api_key", tavily_search_api_key)
+
+            st.success("API keys saved successfully!")
+
+    # Display the available API providers
+    st.subheader("Available API Providers")
     apis_providers_lst = llama_stack_api.client.providers.list()
     api_to_providers = {}
     for api_provider in apis_providers_lst:
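Reviewer note: the form above hardcodes a single key. If more providers need runtime key entry, the same pattern generalizes; a sketch under the assumption that each managed key's environment variable is its upper-cased name (true for the five keys in `api.py`), with the `MANAGED_KEYS` mapping being purely illustrative:

```python
import os

import streamlit as st

from llama_stack.distribution.ui.modules.api import llama_stack_api

# Illustrative mapping; only tavily_search_api_key is wired up in this PR.
MANAGED_KEYS = {
    "tavily_search_api_key": "Tavily Search API Key",
    "openai_api_key": "OpenAI API Key",
}

with st.form("api_keys_form"):
    entered = {}
    for key, label in MANAGED_KEYS.items():
        # Session state wins over the environment, matching the precedence above.
        current = st.session_state.get(key, os.environ.get(key.upper(), ""))
        entered[key] = st.text_input(label, value=current, type="password")
    if st.form_submit_button("Save API Keys"):
        for key, value in entered.items():
            st.session_state[key] = value
            llama_stack_api.update_provider_data(key, value)
        st.success("API keys saved successfully!")
```

Note that `st.session_state` only lives for the browser session, so keys saved this way do not survive a server restart.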
diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index 7bc3fd0c9..d07df0eef 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -8,7 +8,7 @@ import logging
 import warnings
 from collections.abc import AsyncIterator
 
-from openai import APIConnectionError, BadRequestError
+from openai import APIConnectionError, AsyncOpenAI, BadRequestError
 
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -27,13 +27,20 @@ from llama_stack.apis.inference import (
     Inference,
     LogProbConfig,
     Message,
+    ModelStore,
     ResponseFormat,
     SamplingParams,
     TextTruncation,
     ToolChoice,
     ToolConfig,
 )
+from llama_stack.apis.models import Model, ModelType
 from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
+from llama_stack.providers.datatypes import (
+    HealthResponse,
+    HealthStatus,
+    ModelsProtocolPrivate,
+)
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
@@ -57,7 +64,7 @@ from .utils import _is_nvidia_hosted
 
 logger = logging.getLogger(__name__)
 
 
-class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
+class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper, ModelsProtocolPrivate):
     """
     NVIDIA Inference Adapter for Llama Stack.
@@ -71,6 +78,10 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
     - ModelRegistryHelper.check_model_availability() just returns False and shows a warning
     """
 
+    # automatically set by the resolver when instantiating the provider
+    __provider_id__: str
+    model_store: ModelStore | None = None
+
     def __init__(self, config: NVIDIAConfig) -> None:
         # TODO(mf): filter by available models
         ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
@@ -93,6 +104,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
         # )
 
         self._config = config
+        self._client = None
 
     def get_api_key(self) -> str:
         """
@@ -110,6 +122,149 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
         """
         return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
 
+    @property
+    def client(self):
+        """
+        Get the OpenAI client, initializing it lazily on first access.
+
+        :return: The AsyncOpenAI client
+        """
+        self._lazy_initialize_client()
+        return self._client
+
+    def _lazy_initialize_client(self):
+        """
+        Initialize the OpenAI client if it hasn't been initialized yet.
+        """
+        if self._client is not None:
+            return
+
+        logger.info(f"Initializing NVIDIA client with base_url={self.get_base_url()}")
+        self._client = AsyncOpenAI(
+            base_url=self.get_base_url(),
+            api_key=self.get_api_key(),
+        )
+
+    async def initialize(self) -> None:
+        """
+        Initialize the NVIDIA adapter.
+        """
+        if not self._config.url:
+            raise ValueError(
+                "You must provide a URL in run.yaml (or via the NVIDIA_BASE_URL environment variable) to use NVIDIA NIM."
+            )
+
+    async def should_refresh_models(self) -> bool:
+        """
+        Determine if models should be refreshed.
+
+        :return: True if models should be refreshed, False otherwise
+        """
+        # Always refresh so the registry tracks what the NIM endpoint is currently serving
+        return True
+
+    async def list_models(self) -> list[Model] | None:
+        """
+        List all models available from the NVIDIA API.
+
+        :return: A list of available models, or None if the listing fails
+        """
+        self._lazy_initialize_client()
+        models = []
+        try:
+            async for m in self.client.models.list():
+                # Determine the model type from the model ID.
+                # This is a simple heuristic and might need refinement.
+                model_type = ModelType.llm
+                if "embed" in m.id.lower():
+                    model_type = ModelType.embedding
+
+                models.append(
+                    Model(
+                        identifier=m.id,
+                        provider_resource_id=m.id,
+                        provider_id=self.__provider_id__,
+                        metadata={},
+                        model_type=model_type,
+                    )
+                )
+            return models
+        except Exception as e:
+            logger.warning(f"Failed to list models from NVIDIA API: {e}")
+            return None
+
+    async def register_model(self, model: Model) -> Model:
+        """
+        Register a model with the NVIDIA adapter.
+
+        :param model: The model to register
+        :return: The registered model
+        """
+        self._lazy_initialize_client()
+        try:
+            # First try to register using the static model entries
+            model = await ModelRegistryHelper.register_model(self, model)
+        except ValueError:
+            pass  # Not a statically known model; fall back to the live listing
+
+        try:
+            # Check that the model is actually being served by the NVIDIA endpoint
+            available_models = [m.id async for m in self.client.models.list()]
+            if model.provider_resource_id not in available_models:
+                raise ValueError(
+                    f"Model {model.provider_resource_id} is not being served by NVIDIA NIM. "
+                    f"Available models: {', '.join(available_models)}"
+                )
+        except APIConnectionError as e:
+            raise ValueError(
+                f"Failed to connect to NVIDIA NIM at {self._config.url}. Please check if NVIDIA NIM is running and accessible at that URL."
+            ) from e
+
+        return model
+
+    async def unregister_model(self, model_id: str) -> None:
+        """
+        Unregister a model from the NVIDIA adapter.
+
+        :param model_id: The ID of the model to unregister
+        """
+        pass
+
+    async def health(self) -> HealthResponse:
+        """
+        Perform a health check by verifying connectivity to the remote NVIDIA NIM server.
+        This method is used by the Provider API to verify
+        that the service is running correctly.
+
+        :return: A HealthResponse object indicating the health status
+        """
+        try:
+            client = AsyncOpenAI(
+                base_url=self.get_base_url(),
+                api_key=self.get_api_key(),
+            ) if self._client is None else self._client
+            _ = [m async for m in client.models.list()]  # listing models verifies connectivity and auth
+            return HealthResponse(status=HealthStatus.OK)
+        except Exception as e:
+            return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
+
+    async def _get_model(self, model_id: str) -> Model:
+        """
+        Get a model by ID from the model store.
+
+        :param model_id: The ID of the model to get
+        :return: The model
+        """
+        if not self.model_store:
+            raise ValueError("Model store not set")
+        return await self.model_store.get_model(model_id)
+
+    async def shutdown(self) -> None:
+        """
+        Shutdown the NVIDIA adapter.
+        """
+        pass
+
     async def completion(
         self,
         model_id: str,
@@ -128,6 +283,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
         # removing this health check as NeMo customizer endpoint health check is returning 404
         # await check_health(self._config)  # this raises errors
 
+        self._lazy_initialize_client()
         provider_model_id = await self._get_provider_model_id(model_id)
         request = convert_completion_request(
             request=CompletionRequest(
@@ -170,6 +326,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
         #
         #  we can ignore str and always pass list[str] to OpenAI
         #
+        self._lazy_initialize_client()
         flat_contents = [content.text if isinstance(content, TextContentItem) else content for content in contents]
         input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
         provider_model_id = await self._get_provider_model_id(model_id)
@@ -230,6 +387,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
 
         # await check_health(self._config)  # this raises errors
 
+        self._lazy_initialize_client()
         provider_model_id = await self._get_provider_model_id(model_id)
         request = await convert_chat_completion_request(
             request=ChatCompletionRequest(
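Reviewer note: to poke at the new adapter surface without standing up a full distribution, a minimal async sketch; the localhost URL and provider id are assumptions (the resolver normally injects `__provider_id__`), and `NVIDIAConfig` is the config class this module already imports:

```python
import asyncio

from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter


async def main() -> None:
    # Assumed local NIM endpoint; use the hosted URL and an API key otherwise.
    adapter = NVIDIAInferenceAdapter(NVIDIAConfig(url="http://localhost:8000"))
    adapter.__provider_id__ = "nvidia"  # normally set by the resolver

    await adapter.initialize()       # config validation only, no network I/O
    print(await adapter.health())    # first real use triggers the lazy client init
    print(await adapter.list_models())


asyncio.run(main())
```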