Add provider data passing for library client (#750)

# What does this PR do?

This PR adds the provider data passing for the library client and
changes the provider's api keys be unique


## Test Plan

LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-fireworks/fireworks-run.yaml"
pytest -v tests/client-sdk/agents/test_agents.py

run.yaml:
https://gist.github.com/dineshyv/0c10b5c7d0a2fb7ba4f0ecc8dcf860d1
This commit is contained in:
Dinesh Yeduguru 2025-01-13 15:12:10 -08:00 committed by GitHub
parent 6964510dc1
commit 314806cde3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 49 additions and 19 deletions

View file

@ -33,6 +33,7 @@ from termcolor import cprint
from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Api from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import ProviderRegistry from llama_stack.distribution.resolver import ProviderRegistry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.stack import ( from llama_stack.distribution.stack import (
@ -67,6 +68,7 @@ def stream_across_asyncio_run_boundary(
async_gen_maker, async_gen_maker,
pool_executor: ThreadPoolExecutor, pool_executor: ThreadPoolExecutor,
path: Optional[str] = None, path: Optional[str] = None,
provider_data: Optional[dict[str, Any]] = None,
) -> Generator[T, None, None]: ) -> Generator[T, None, None]:
result_queue = queue.Queue() result_queue = queue.Queue()
stop_event = threading.Event() stop_event = threading.Event()
@ -75,6 +77,10 @@ def stream_across_asyncio_run_boundary(
# make sure we make the generator in the event loop context # make sure we make the generator in the event loop context
gen = await async_gen_maker() gen = await async_gen_maker()
await start_trace(path, {"__location__": "library_client"}) await start_trace(path, {"__location__": "library_client"})
if provider_data:
set_request_provider_data(
{"X-LlamaStack-Provider-Data": json.dumps(provider_data)}
)
try: try:
async for item in await gen: async for item in await gen:
result_queue.put(item) result_queue.put(item)
@ -174,6 +180,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
config_path_or_template_name: str, config_path_or_template_name: str,
skip_logger_removal: bool = False, skip_logger_removal: bool = False,
custom_provider_registry: Optional[ProviderRegistry] = None, custom_provider_registry: Optional[ProviderRegistry] = None,
provider_data: Optional[dict[str, Any]] = None,
): ):
super().__init__() super().__init__()
self.async_client = AsyncLlamaStackAsLibraryClient( self.async_client = AsyncLlamaStackAsLibraryClient(
@ -181,6 +188,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
) )
self.pool_executor = ThreadPoolExecutor(max_workers=4) self.pool_executor = ThreadPoolExecutor(max_workers=4)
self.skip_logger_removal = skip_logger_removal self.skip_logger_removal = skip_logger_removal
self.provider_data = provider_data
def initialize(self): def initialize(self):
if in_notebook(): if in_notebook():
@ -219,10 +227,15 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
lambda: self.async_client.request(*args, **kwargs), lambda: self.async_client.request(*args, **kwargs),
self.pool_executor, self.pool_executor,
path=path, path=path,
provider_data=self.provider_data,
) )
else: else:
async def _traced_request(): async def _traced_request():
if self.provider_data:
set_request_provider_data(
{"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
)
await start_trace(path, {"__location__": "library_client"}) await start_trace(path, {"__location__": "library_client"})
try: try:
return await self.async_client.request(*args, **kwargs) return await self.async_client.request(*args, **kwargs)

View file

@ -12,7 +12,7 @@ from pydantic import BaseModel
class BingSearchToolProviderDataValidator(BaseModel): class BingSearchToolProviderDataValidator(BaseModel):
api_key: str bing_search_api_key: str
async def get_adapter_impl(config: BingSearchToolConfig, _deps): async def get_adapter_impl(config: BingSearchToolConfig, _deps):

View file

@ -44,11 +44,11 @@ class BingSearchToolRuntimeImpl(
return self.config.api_key return self.config.api_key
provider_data = self.get_request_provider_data() provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.api_key: if provider_data is None or not provider_data.bing_search_api_key:
raise ValueError( raise ValueError(
'Pass Bing Search API Key in the header X-LlamaStack-Provider-Data as { "api_key": <your api key>}' 'Pass Bing Search API Key in the header X-LlamaStack-Provider-Data as { "bing_search_api_key": <your api key>}'
) )
return provider_data.api_key return provider_data.bing_search_api_key
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None

View file

@ -11,7 +11,7 @@ from .config import BraveSearchToolConfig
class BraveSearchToolProviderDataValidator(BaseModel): class BraveSearchToolProviderDataValidator(BaseModel):
api_key: str brave_search_api_key: str
async def get_adapter_impl(config: BraveSearchToolConfig, _deps): async def get_adapter_impl(config: BraveSearchToolConfig, _deps):

View file

@ -43,11 +43,11 @@ class BraveSearchToolRuntimeImpl(
return self.config.api_key return self.config.api_key
provider_data = self.get_request_provider_data() provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.api_key: if provider_data is None or not provider_data.brave_search_api_key:
raise ValueError( raise ValueError(
'Pass Search provider\'s API Key in the header X-LlamaStack-Provider-Data as { "api_key": <your api key>}' 'Pass Search provider\'s API Key in the header X-LlamaStack-Provider-Data as { "brave_search_api_key": <your api key>}'
) )
return provider_data.api_key return provider_data.brave_search_api_key
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None

View file

@ -11,7 +11,7 @@ from .tavily_search import TavilySearchToolRuntimeImpl
class TavilySearchToolProviderDataValidator(BaseModel): class TavilySearchToolProviderDataValidator(BaseModel):
api_key: str tavily_search_api_key: str
async def get_adapter_impl(config: TavilySearchToolConfig, _deps): async def get_adapter_impl(config: TavilySearchToolConfig, _deps):

View file

@ -43,11 +43,11 @@ class TavilySearchToolRuntimeImpl(
return self.config.api_key return self.config.api_key
provider_data = self.get_request_provider_data() provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.api_key: if provider_data is None or not provider_data.tavily_search_api_key:
raise ValueError( raise ValueError(
'Pass Search provider\'s API Key in the header X-LlamaStack-Provider-Data as { "api_key": <your api key>}' 'Pass Search provider\'s API Key in the header X-LlamaStack-Provider-Data as { "tavily_search_api_key": <your api key>}'
) )
return provider_data.api_key return provider_data.tavily_search_api_key
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None

View file

@ -13,7 +13,7 @@ __all__ = ["WolframAlphaToolConfig", "WolframAlphaToolRuntimeImpl"]
class WolframAlphaToolProviderDataValidator(BaseModel): class WolframAlphaToolProviderDataValidator(BaseModel):
api_key: str wolfram_alpha_api_key: str
async def get_adapter_impl(config: WolframAlphaToolConfig, _deps): async def get_adapter_impl(config: WolframAlphaToolConfig, _deps):

View file

@ -44,11 +44,11 @@ class WolframAlphaToolRuntimeImpl(
return self.config.api_key return self.config.api_key
provider_data = self.get_request_provider_data() provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.api_key: if provider_data is None or not provider_data.wolfram_alpha_api_key:
raise ValueError( raise ValueError(
'Pass WolframAlpha API Key in the header X-LlamaStack-Provider-Data as { "api_key": <your api key>}' 'Pass WolframAlpha API Key in the header X-LlamaStack-Provider-Data as { "wolfram_alpha_api_key": <your api key>}'
) )
return provider_data.api_key return provider_data.wolfram_alpha_api_key
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None

View file

@ -13,12 +13,29 @@ from llama_stack_client import LlamaStackClient
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def llama_stack_client(): def provider_data():
# check env for tavily secret, brave secret and inject all into provider data
provider_data = {}
if os.environ.get("TAVILY_SEARCH_API_KEY"):
provider_data["tavily_search_api_key"] = os.environ["TAVILY_SEARCH_API_KEY"]
if os.environ.get("BRAVE_SEARCH_API_KEY"):
provider_data["brave_search_api_key"] = os.environ["BRAVE_SEARCH_API_KEY"]
return provider_data if len(provider_data) > 0 else None
@pytest.fixture(scope="session")
def llama_stack_client(provider_data):
if os.environ.get("LLAMA_STACK_CONFIG"): if os.environ.get("LLAMA_STACK_CONFIG"):
client = LlamaStackAsLibraryClient(get_env_or_fail("LLAMA_STACK_CONFIG")) client = LlamaStackAsLibraryClient(
get_env_or_fail("LLAMA_STACK_CONFIG"),
provider_data=provider_data,
)
client.initialize() client.initialize()
elif os.environ.get("LLAMA_STACK_BASE_URL"): elif os.environ.get("LLAMA_STACK_BASE_URL"):
client = LlamaStackClient(base_url=get_env_or_fail("LLAMA_STACK_BASE_URL")) client = LlamaStackClient(
base_url=get_env_or_fail("LLAMA_STACK_BASE_URL"),
provider_data=provider_data,
)
else: else:
raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set") raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set")
return client return client