forked from phoenix-oss/llama-stack-mirror
Add provider data passing for library client (#750)
# What does this PR do? This PR adds the provider data passing for the library client and changes the provider's api keys be unique ## Test Plan LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-fireworks/fireworks-run.yaml" pytest -v tests/client-sdk/agents/test_agents.py run.yaml: https://gist.github.com/dineshyv/0c10b5c7d0a2fb7ba4f0ecc8dcf860d1
This commit is contained in:
parent
6964510dc1
commit
314806cde3
10 changed files with 49 additions and 19 deletions
|
@ -33,6 +33,7 @@ from termcolor import cprint
|
||||||
from llama_stack.distribution.build import print_pip_install_help
|
from llama_stack.distribution.build import print_pip_install_help
|
||||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||||
from llama_stack.distribution.datatypes import Api
|
from llama_stack.distribution.datatypes import Api
|
||||||
|
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||||
from llama_stack.distribution.resolver import ProviderRegistry
|
from llama_stack.distribution.resolver import ProviderRegistry
|
||||||
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
||||||
from llama_stack.distribution.stack import (
|
from llama_stack.distribution.stack import (
|
||||||
|
@ -67,6 +68,7 @@ def stream_across_asyncio_run_boundary(
|
||||||
async_gen_maker,
|
async_gen_maker,
|
||||||
pool_executor: ThreadPoolExecutor,
|
pool_executor: ThreadPoolExecutor,
|
||||||
path: Optional[str] = None,
|
path: Optional[str] = None,
|
||||||
|
provider_data: Optional[dict[str, Any]] = None,
|
||||||
) -> Generator[T, None, None]:
|
) -> Generator[T, None, None]:
|
||||||
result_queue = queue.Queue()
|
result_queue = queue.Queue()
|
||||||
stop_event = threading.Event()
|
stop_event = threading.Event()
|
||||||
|
@ -75,6 +77,10 @@ def stream_across_asyncio_run_boundary(
|
||||||
# make sure we make the generator in the event loop context
|
# make sure we make the generator in the event loop context
|
||||||
gen = await async_gen_maker()
|
gen = await async_gen_maker()
|
||||||
await start_trace(path, {"__location__": "library_client"})
|
await start_trace(path, {"__location__": "library_client"})
|
||||||
|
if provider_data:
|
||||||
|
set_request_provider_data(
|
||||||
|
{"X-LlamaStack-Provider-Data": json.dumps(provider_data)}
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
async for item in await gen:
|
async for item in await gen:
|
||||||
result_queue.put(item)
|
result_queue.put(item)
|
||||||
|
@ -174,6 +180,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
||||||
config_path_or_template_name: str,
|
config_path_or_template_name: str,
|
||||||
skip_logger_removal: bool = False,
|
skip_logger_removal: bool = False,
|
||||||
custom_provider_registry: Optional[ProviderRegistry] = None,
|
custom_provider_registry: Optional[ProviderRegistry] = None,
|
||||||
|
provider_data: Optional[dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.async_client = AsyncLlamaStackAsLibraryClient(
|
self.async_client = AsyncLlamaStackAsLibraryClient(
|
||||||
|
@ -181,6 +188,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
||||||
)
|
)
|
||||||
self.pool_executor = ThreadPoolExecutor(max_workers=4)
|
self.pool_executor = ThreadPoolExecutor(max_workers=4)
|
||||||
self.skip_logger_removal = skip_logger_removal
|
self.skip_logger_removal = skip_logger_removal
|
||||||
|
self.provider_data = provider_data
|
||||||
|
|
||||||
def initialize(self):
|
def initialize(self):
|
||||||
if in_notebook():
|
if in_notebook():
|
||||||
|
@ -219,10 +227,15 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
||||||
lambda: self.async_client.request(*args, **kwargs),
|
lambda: self.async_client.request(*args, **kwargs),
|
||||||
self.pool_executor,
|
self.pool_executor,
|
||||||
path=path,
|
path=path,
|
||||||
|
provider_data=self.provider_data,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
async def _traced_request():
|
async def _traced_request():
|
||||||
|
if self.provider_data:
|
||||||
|
set_request_provider_data(
|
||||||
|
{"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
|
||||||
|
)
|
||||||
await start_trace(path, {"__location__": "library_client"})
|
await start_trace(path, {"__location__": "library_client"})
|
||||||
try:
|
try:
|
||||||
return await self.async_client.request(*args, **kwargs)
|
return await self.async_client.request(*args, **kwargs)
|
||||||
|
|
|
@ -12,7 +12,7 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class BingSearchToolProviderDataValidator(BaseModel):
|
class BingSearchToolProviderDataValidator(BaseModel):
|
||||||
api_key: str
|
bing_search_api_key: str
|
||||||
|
|
||||||
|
|
||||||
async def get_adapter_impl(config: BingSearchToolConfig, _deps):
|
async def get_adapter_impl(config: BingSearchToolConfig, _deps):
|
||||||
|
|
|
@ -44,11 +44,11 @@ class BingSearchToolRuntimeImpl(
|
||||||
return self.config.api_key
|
return self.config.api_key
|
||||||
|
|
||||||
provider_data = self.get_request_provider_data()
|
provider_data = self.get_request_provider_data()
|
||||||
if provider_data is None or not provider_data.api_key:
|
if provider_data is None or not provider_data.bing_search_api_key:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Pass Bing Search API Key in the header X-LlamaStack-Provider-Data as { "api_key": <your api key>}'
|
'Pass Bing Search API Key in the header X-LlamaStack-Provider-Data as { "bing_search_api_key": <your api key>}'
|
||||||
)
|
)
|
||||||
return provider_data.api_key
|
return provider_data.bing_search_api_key
|
||||||
|
|
||||||
async def list_runtime_tools(
|
async def list_runtime_tools(
|
||||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||||
|
|
|
@ -11,7 +11,7 @@ from .config import BraveSearchToolConfig
|
||||||
|
|
||||||
|
|
||||||
class BraveSearchToolProviderDataValidator(BaseModel):
|
class BraveSearchToolProviderDataValidator(BaseModel):
|
||||||
api_key: str
|
brave_search_api_key: str
|
||||||
|
|
||||||
|
|
||||||
async def get_adapter_impl(config: BraveSearchToolConfig, _deps):
|
async def get_adapter_impl(config: BraveSearchToolConfig, _deps):
|
||||||
|
|
|
@ -43,11 +43,11 @@ class BraveSearchToolRuntimeImpl(
|
||||||
return self.config.api_key
|
return self.config.api_key
|
||||||
|
|
||||||
provider_data = self.get_request_provider_data()
|
provider_data = self.get_request_provider_data()
|
||||||
if provider_data is None or not provider_data.api_key:
|
if provider_data is None or not provider_data.brave_search_api_key:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Pass Search provider\'s API Key in the header X-LlamaStack-Provider-Data as { "api_key": <your api key>}'
|
'Pass Search provider\'s API Key in the header X-LlamaStack-Provider-Data as { "brave_search_api_key": <your api key>}'
|
||||||
)
|
)
|
||||||
return provider_data.api_key
|
return provider_data.brave_search_api_key
|
||||||
|
|
||||||
async def list_runtime_tools(
|
async def list_runtime_tools(
|
||||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||||
|
|
|
@ -11,7 +11,7 @@ from .tavily_search import TavilySearchToolRuntimeImpl
|
||||||
|
|
||||||
|
|
||||||
class TavilySearchToolProviderDataValidator(BaseModel):
|
class TavilySearchToolProviderDataValidator(BaseModel):
|
||||||
api_key: str
|
tavily_search_api_key: str
|
||||||
|
|
||||||
|
|
||||||
async def get_adapter_impl(config: TavilySearchToolConfig, _deps):
|
async def get_adapter_impl(config: TavilySearchToolConfig, _deps):
|
||||||
|
|
|
@ -43,11 +43,11 @@ class TavilySearchToolRuntimeImpl(
|
||||||
return self.config.api_key
|
return self.config.api_key
|
||||||
|
|
||||||
provider_data = self.get_request_provider_data()
|
provider_data = self.get_request_provider_data()
|
||||||
if provider_data is None or not provider_data.api_key:
|
if provider_data is None or not provider_data.tavily_search_api_key:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Pass Search provider\'s API Key in the header X-LlamaStack-Provider-Data as { "api_key": <your api key>}'
|
'Pass Search provider\'s API Key in the header X-LlamaStack-Provider-Data as { "tavily_search_api_key": <your api key>}'
|
||||||
)
|
)
|
||||||
return provider_data.api_key
|
return provider_data.tavily_search_api_key
|
||||||
|
|
||||||
async def list_runtime_tools(
|
async def list_runtime_tools(
|
||||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||||
|
|
|
@ -13,7 +13,7 @@ __all__ = ["WolframAlphaToolConfig", "WolframAlphaToolRuntimeImpl"]
|
||||||
|
|
||||||
|
|
||||||
class WolframAlphaToolProviderDataValidator(BaseModel):
|
class WolframAlphaToolProviderDataValidator(BaseModel):
|
||||||
api_key: str
|
wolfram_alpha_api_key: str
|
||||||
|
|
||||||
|
|
||||||
async def get_adapter_impl(config: WolframAlphaToolConfig, _deps):
|
async def get_adapter_impl(config: WolframAlphaToolConfig, _deps):
|
||||||
|
|
|
@ -44,11 +44,11 @@ class WolframAlphaToolRuntimeImpl(
|
||||||
return self.config.api_key
|
return self.config.api_key
|
||||||
|
|
||||||
provider_data = self.get_request_provider_data()
|
provider_data = self.get_request_provider_data()
|
||||||
if provider_data is None or not provider_data.api_key:
|
if provider_data is None or not provider_data.wolfram_alpha_api_key:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Pass WolframAlpha API Key in the header X-LlamaStack-Provider-Data as { "api_key": <your api key>}'
|
'Pass WolframAlpha API Key in the header X-LlamaStack-Provider-Data as { "wolfram_alpha_api_key": <your api key>}'
|
||||||
)
|
)
|
||||||
return provider_data.api_key
|
return provider_data.wolfram_alpha_api_key
|
||||||
|
|
||||||
async def list_runtime_tools(
|
async def list_runtime_tools(
|
||||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||||
|
|
|
@ -13,12 +13,29 @@ from llama_stack_client import LlamaStackClient
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def llama_stack_client():
|
def provider_data():
|
||||||
|
# check env for tavily secret, brave secret and inject all into provider data
|
||||||
|
provider_data = {}
|
||||||
|
if os.environ.get("TAVILY_SEARCH_API_KEY"):
|
||||||
|
provider_data["tavily_search_api_key"] = os.environ["TAVILY_SEARCH_API_KEY"]
|
||||||
|
if os.environ.get("BRAVE_SEARCH_API_KEY"):
|
||||||
|
provider_data["brave_search_api_key"] = os.environ["BRAVE_SEARCH_API_KEY"]
|
||||||
|
return provider_data if len(provider_data) > 0 else None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def llama_stack_client(provider_data):
|
||||||
if os.environ.get("LLAMA_STACK_CONFIG"):
|
if os.environ.get("LLAMA_STACK_CONFIG"):
|
||||||
client = LlamaStackAsLibraryClient(get_env_or_fail("LLAMA_STACK_CONFIG"))
|
client = LlamaStackAsLibraryClient(
|
||||||
|
get_env_or_fail("LLAMA_STACK_CONFIG"),
|
||||||
|
provider_data=provider_data,
|
||||||
|
)
|
||||||
client.initialize()
|
client.initialize()
|
||||||
elif os.environ.get("LLAMA_STACK_BASE_URL"):
|
elif os.environ.get("LLAMA_STACK_BASE_URL"):
|
||||||
client = LlamaStackClient(base_url=get_env_or_fail("LLAMA_STACK_BASE_URL"))
|
client = LlamaStackClient(
|
||||||
|
base_url=get_env_or_fail("LLAMA_STACK_BASE_URL"),
|
||||||
|
provider_data=provider_data,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set")
|
raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set")
|
||||||
return client
|
return client
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue