forked from phoenix-oss/llama-stack-mirror
Add provider data passing for library client (#750)
# What does this PR do? This PR adds provider data passing for the library client and changes the providers' API keys to be unique. ## Test Plan LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-fireworks/fireworks-run.yaml" pytest -v tests/client-sdk/agents/test_agents.py run.yaml: https://gist.github.com/dineshyv/0c10b5c7d0a2fb7ba4f0ecc8dcf860d1
This commit is contained in:
parent
6964510dc1
commit
314806cde3
10 changed files with 49 additions and 19 deletions
|
@ -33,6 +33,7 @@ from termcolor import cprint
|
|||
from llama_stack.distribution.build import print_pip_install_help
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.resolver import ProviderRegistry
|
||||
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
||||
from llama_stack.distribution.stack import (
|
||||
|
@ -67,6 +68,7 @@ def stream_across_asyncio_run_boundary(
|
|||
async_gen_maker,
|
||||
pool_executor: ThreadPoolExecutor,
|
||||
path: Optional[str] = None,
|
||||
provider_data: Optional[dict[str, Any]] = None,
|
||||
) -> Generator[T, None, None]:
|
||||
result_queue = queue.Queue()
|
||||
stop_event = threading.Event()
|
||||
|
@ -75,6 +77,10 @@ def stream_across_asyncio_run_boundary(
|
|||
# make sure we make the generator in the event loop context
|
||||
gen = await async_gen_maker()
|
||||
await start_trace(path, {"__location__": "library_client"})
|
||||
if provider_data:
|
||||
set_request_provider_data(
|
||||
{"X-LlamaStack-Provider-Data": json.dumps(provider_data)}
|
||||
)
|
||||
try:
|
||||
async for item in await gen:
|
||||
result_queue.put(item)
|
||||
|
@ -174,6 +180,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
|||
config_path_or_template_name: str,
|
||||
skip_logger_removal: bool = False,
|
||||
custom_provider_registry: Optional[ProviderRegistry] = None,
|
||||
provider_data: Optional[dict[str, Any]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.async_client = AsyncLlamaStackAsLibraryClient(
|
||||
|
@ -181,6 +188,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
|||
)
|
||||
self.pool_executor = ThreadPoolExecutor(max_workers=4)
|
||||
self.skip_logger_removal = skip_logger_removal
|
||||
self.provider_data = provider_data
|
||||
|
||||
def initialize(self):
|
||||
if in_notebook():
|
||||
|
@ -219,10 +227,15 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
|||
lambda: self.async_client.request(*args, **kwargs),
|
||||
self.pool_executor,
|
||||
path=path,
|
||||
provider_data=self.provider_data,
|
||||
)
|
||||
else:
|
||||
|
||||
async def _traced_request():
|
||||
if self.provider_data:
|
||||
set_request_provider_data(
|
||||
{"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
|
||||
)
|
||||
await start_trace(path, {"__location__": "library_client"})
|
||||
try:
|
||||
return await self.async_client.request(*args, **kwargs)
|
||||
|
|
|
@ -12,7 +12,7 @@ from pydantic import BaseModel
|
|||
|
||||
|
||||
class BingSearchToolProviderDataValidator(BaseModel):
|
||||
api_key: str
|
||||
bing_search_api_key: str
|
||||
|
||||
|
||||
async def get_adapter_impl(config: BingSearchToolConfig, _deps):
|
||||
|
|
|
@ -44,11 +44,11 @@ class BingSearchToolRuntimeImpl(
|
|||
return self.config.api_key
|
||||
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.api_key:
|
||||
if provider_data is None or not provider_data.bing_search_api_key:
|
||||
raise ValueError(
|
||||
'Pass Bing Search API Key in the header X-LlamaStack-Provider-Data as { "api_key": <your api key>}'
|
||||
'Pass Bing Search API Key in the header X-LlamaStack-Provider-Data as { "bing_search_api_key": <your api key>}'
|
||||
)
|
||||
return provider_data.api_key
|
||||
return provider_data.bing_search_api_key
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
|
|
|
@ -11,7 +11,7 @@ from .config import BraveSearchToolConfig
|
|||
|
||||
|
||||
class BraveSearchToolProviderDataValidator(BaseModel):
|
||||
api_key: str
|
||||
brave_search_api_key: str
|
||||
|
||||
|
||||
async def get_adapter_impl(config: BraveSearchToolConfig, _deps):
|
||||
|
|
|
@ -43,11 +43,11 @@ class BraveSearchToolRuntimeImpl(
|
|||
return self.config.api_key
|
||||
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.api_key:
|
||||
if provider_data is None or not provider_data.brave_search_api_key:
|
||||
raise ValueError(
|
||||
'Pass Search provider\'s API Key in the header X-LlamaStack-Provider-Data as { "api_key": <your api key>}'
|
||||
'Pass Search provider\'s API Key in the header X-LlamaStack-Provider-Data as { "brave_search_api_key": <your api key>}'
|
||||
)
|
||||
return provider_data.api_key
|
||||
return provider_data.brave_search_api_key
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
|
|
|
@ -11,7 +11,7 @@ from .tavily_search import TavilySearchToolRuntimeImpl
|
|||
|
||||
|
||||
class TavilySearchToolProviderDataValidator(BaseModel):
|
||||
api_key: str
|
||||
tavily_search_api_key: str
|
||||
|
||||
|
||||
async def get_adapter_impl(config: TavilySearchToolConfig, _deps):
|
||||
|
|
|
@ -43,11 +43,11 @@ class TavilySearchToolRuntimeImpl(
|
|||
return self.config.api_key
|
||||
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.api_key:
|
||||
if provider_data is None or not provider_data.tavily_search_api_key:
|
||||
raise ValueError(
|
||||
'Pass Search provider\'s API Key in the header X-LlamaStack-Provider-Data as { "api_key": <your api key>}'
|
||||
'Pass Search provider\'s API Key in the header X-LlamaStack-Provider-Data as { "tavily_search_api_key": <your api key>}'
|
||||
)
|
||||
return provider_data.api_key
|
||||
return provider_data.tavily_search_api_key
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
|
|
|
@ -13,7 +13,7 @@ __all__ = ["WolframAlphaToolConfig", "WolframAlphaToolRuntimeImpl"]
|
|||
|
||||
|
||||
class WolframAlphaToolProviderDataValidator(BaseModel):
|
||||
api_key: str
|
||||
wolfram_alpha_api_key: str
|
||||
|
||||
|
||||
async def get_adapter_impl(config: WolframAlphaToolConfig, _deps):
|
||||
|
|
|
@ -44,11 +44,11 @@ class WolframAlphaToolRuntimeImpl(
|
|||
return self.config.api_key
|
||||
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.api_key:
|
||||
if provider_data is None or not provider_data.wolfram_alpha_api_key:
|
||||
raise ValueError(
|
||||
'Pass WolframAlpha API Key in the header X-LlamaStack-Provider-Data as { "api_key": <your api key>}'
|
||||
'Pass WolframAlpha API Key in the header X-LlamaStack-Provider-Data as { "wolfram_alpha_api_key": <your api key>}'
|
||||
)
|
||||
return provider_data.api_key
|
||||
return provider_data.wolfram_alpha_api_key
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
|
|
|
@ -13,12 +13,29 @@ from llama_stack_client import LlamaStackClient
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def llama_stack_client():
|
||||
def provider_data():
|
||||
# check env for tavily secret, brave secret and inject all into provider data
|
||||
provider_data = {}
|
||||
if os.environ.get("TAVILY_SEARCH_API_KEY"):
|
||||
provider_data["tavily_search_api_key"] = os.environ["TAVILY_SEARCH_API_KEY"]
|
||||
if os.environ.get("BRAVE_SEARCH_API_KEY"):
|
||||
provider_data["brave_search_api_key"] = os.environ["BRAVE_SEARCH_API_KEY"]
|
||||
return provider_data if len(provider_data) > 0 else None
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def llama_stack_client(provider_data):
|
||||
if os.environ.get("LLAMA_STACK_CONFIG"):
|
||||
client = LlamaStackAsLibraryClient(get_env_or_fail("LLAMA_STACK_CONFIG"))
|
||||
client = LlamaStackAsLibraryClient(
|
||||
get_env_or_fail("LLAMA_STACK_CONFIG"),
|
||||
provider_data=provider_data,
|
||||
)
|
||||
client.initialize()
|
||||
elif os.environ.get("LLAMA_STACK_BASE_URL"):
|
||||
client = LlamaStackClient(base_url=get_env_or_fail("LLAMA_STACK_BASE_URL"))
|
||||
client = LlamaStackClient(
|
||||
base_url=get_env_or_fail("LLAMA_STACK_BASE_URL"),
|
||||
provider_data=provider_data,
|
||||
)
|
||||
else:
|
||||
raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set")
|
||||
return client
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue