diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index a899ae811..50af2cdea 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -33,6 +33,7 @@ from termcolor import cprint from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.datatypes import Api +from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.resolver import ProviderRegistry from llama_stack.distribution.server.endpoints import get_all_api_endpoints from llama_stack.distribution.stack import ( @@ -67,6 +68,7 @@ def stream_across_asyncio_run_boundary( async_gen_maker, pool_executor: ThreadPoolExecutor, path: Optional[str] = None, + provider_data: Optional[dict[str, Any]] = None, ) -> Generator[T, None, None]: result_queue = queue.Queue() stop_event = threading.Event() @@ -75,6 +77,10 @@ def stream_across_asyncio_run_boundary( # make sure we make the generator in the event loop context gen = await async_gen_maker() await start_trace(path, {"__location__": "library_client"}) + if provider_data: + set_request_provider_data( + {"X-LlamaStack-Provider-Data": json.dumps(provider_data)} + ) try: async for item in await gen: result_queue.put(item) @@ -174,6 +180,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient): config_path_or_template_name: str, skip_logger_removal: bool = False, custom_provider_registry: Optional[ProviderRegistry] = None, + provider_data: Optional[dict[str, Any]] = None, ): super().__init__() self.async_client = AsyncLlamaStackAsLibraryClient( @@ -181,6 +188,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient): ) self.pool_executor = ThreadPoolExecutor(max_workers=4) self.skip_logger_removal = skip_logger_removal + self.provider_data = provider_data def initialize(self): if in_notebook(): @@ -219,10 +227,15 @@ class LlamaStackAsLibraryClient(LlamaStackClient): lambda: self.async_client.request(*args, **kwargs), self.pool_executor, path=path, + provider_data=self.provider_data, ) else: async def _traced_request(): + if self.provider_data: + set_request_provider_data( + {"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)} + ) await start_trace(path, {"__location__": "library_client"}) try: return await self.async_client.request(*args, **kwargs) diff --git a/llama_stack/providers/remote/tool_runtime/bing_search/__init__.py b/llama_stack/providers/remote/tool_runtime/bing_search/__init__.py index 8481737b5..30a883675 100644 --- a/llama_stack/providers/remote/tool_runtime/bing_search/__init__.py +++ b/llama_stack/providers/remote/tool_runtime/bing_search/__init__.py @@ -12,7 +12,7 @@ from pydantic import BaseModel class BingSearchToolProviderDataValidator(BaseModel): - api_key: str + bing_search_api_key: str async def get_adapter_impl(config: BingSearchToolConfig, _deps): diff --git a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py index b864620d8..5114e06aa 100644 --- a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py +++ b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py @@ -44,11 +44,11 @@ class BingSearchToolRuntimeImpl( return self.config.api_key provider_data = self.get_request_provider_data() - if provider_data is None or not provider_data.api_key: + if provider_data is None or not provider_data.bing_search_api_key: raise ValueError( - 'Pass Bing Search API Key in the header X-LlamaStack-Provider-Data as { "api_key": }' + 'Pass Bing Search API Key in the header X-LlamaStack-Provider-Data as { "bing_search_api_key": }' ) - return provider_data.api_key + return provider_data.bing_search_api_key async def list_runtime_tools( self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None diff --git a/llama_stack/providers/remote/tool_runtime/brave_search/__init__.py b/llama_stack/providers/remote/tool_runtime/brave_search/__init__.py index 0827e51d2..2bfa520b4 100644 --- a/llama_stack/providers/remote/tool_runtime/brave_search/__init__.py +++ b/llama_stack/providers/remote/tool_runtime/brave_search/__init__.py @@ -11,7 +11,7 @@ from .config import BraveSearchToolConfig class BraveSearchToolProviderDataValidator(BaseModel): - api_key: str + brave_search_api_key: str async def get_adapter_impl(config: BraveSearchToolConfig, _deps): diff --git a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py index 259d02f1b..016f746ea 100644 --- a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py +++ b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py @@ -43,11 +43,11 @@ class BraveSearchToolRuntimeImpl( return self.config.api_key provider_data = self.get_request_provider_data() - if provider_data is None or not provider_data.api_key: + if provider_data is None or not provider_data.brave_search_api_key: raise ValueError( - 'Pass Search provider\'s API Key in the header X-LlamaStack-Provider-Data as { "api_key": }' + 'Pass Search provider\'s API Key in the header X-LlamaStack-Provider-Data as { "brave_search_api_key": }' ) - return provider_data.api_key + return provider_data.brave_search_api_key async def list_runtime_tools( self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None diff --git a/llama_stack/providers/remote/tool_runtime/tavily_search/__init__.py b/llama_stack/providers/remote/tool_runtime/tavily_search/__init__.py index 379e99081..e90a142ec 100644 --- a/llama_stack/providers/remote/tool_runtime/tavily_search/__init__.py +++ b/llama_stack/providers/remote/tool_runtime/tavily_search/__init__.py @@ -11,7 +11,7 @@ from .tavily_search import TavilySearchToolRuntimeImpl class TavilySearchToolProviderDataValidator(BaseModel): - api_key: str + tavily_search_api_key: str async def get_adapter_impl(config: TavilySearchToolConfig, _deps): diff --git a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py index 1716f96e5..82077193e 100644 --- a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py +++ b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py @@ -43,11 +43,11 @@ class TavilySearchToolRuntimeImpl( return self.config.api_key provider_data = self.get_request_provider_data() - if provider_data is None or not provider_data.api_key: + if provider_data is None or not provider_data.tavily_search_api_key: raise ValueError( - 'Pass Search provider\'s API Key in the header X-LlamaStack-Provider-Data as { "api_key": }' + 'Pass Search provider\'s API Key in the header X-LlamaStack-Provider-Data as { "tavily_search_api_key": }' ) - return provider_data.api_key + return provider_data.tavily_search_api_key async def list_runtime_tools( self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None diff --git a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/__init__.py b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/__init__.py index aaa6e4e69..adeb094ab 100644 --- a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/__init__.py +++ b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/__init__.py @@ -13,7 +13,7 @@ __all__ = ["WolframAlphaToolConfig", "WolframAlphaToolRuntimeImpl"] class WolframAlphaToolProviderDataValidator(BaseModel): - api_key: str + wolfram_alpha_api_key: str async def get_adapter_impl(config: WolframAlphaToolConfig, _deps): diff --git a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py index 8d0792ca0..04ecfcc15 100644 --- a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py +++ b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py @@ -44,11 +44,11 @@ class WolframAlphaToolRuntimeImpl( return self.config.api_key provider_data = self.get_request_provider_data() - if provider_data is None or not provider_data.api_key: + if provider_data is None or not provider_data.wolfram_alpha_api_key: raise ValueError( - 'Pass WolframAlpha API Key in the header X-LlamaStack-Provider-Data as { "api_key": }' + 'Pass WolframAlpha API Key in the header X-LlamaStack-Provider-Data as { "wolfram_alpha_api_key": }' ) - return provider_data.api_key + return provider_data.wolfram_alpha_api_key async def list_runtime_tools( self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None diff --git a/tests/client-sdk/conftest.py b/tests/client-sdk/conftest.py index 28808ae4c..16e6d1bbd 100644 --- a/tests/client-sdk/conftest.py +++ b/tests/client-sdk/conftest.py @@ -13,12 +13,29 @@ from llama_stack_client import LlamaStackClient @pytest.fixture(scope="session") -def llama_stack_client(): +def provider_data(): + # check env for tavily secret, brave secret and inject all into provider data + provider_data = {} + if os.environ.get("TAVILY_SEARCH_API_KEY"): + provider_data["tavily_search_api_key"] = os.environ["TAVILY_SEARCH_API_KEY"] + if os.environ.get("BRAVE_SEARCH_API_KEY"): + provider_data["brave_search_api_key"] = os.environ["BRAVE_SEARCH_API_KEY"] + return provider_data if len(provider_data) > 0 else None + + +@pytest.fixture(scope="session") +def llama_stack_client(provider_data): if os.environ.get("LLAMA_STACK_CONFIG"): - client = LlamaStackAsLibraryClient(get_env_or_fail("LLAMA_STACK_CONFIG")) + client = LlamaStackAsLibraryClient( + get_env_or_fail("LLAMA_STACK_CONFIG"), + provider_data=provider_data, + ) client.initialize() elif os.environ.get("LLAMA_STACK_BASE_URL"): - client = LlamaStackClient(base_url=get_env_or_fail("LLAMA_STACK_BASE_URL")) + client = LlamaStackClient( + base_url=get_env_or_fail("LLAMA_STACK_BASE_URL"), + provider_data=provider_data, + ) else: raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set") return client