Add provider data passing for library client (#750)

# What does this PR do?

This PR adds provider data passing for the library client and changes the providers' API keys to be unique.
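
With this change, provider credentials can be supplied when constructing the library client and are forwarded to providers on each request. A minimal usage sketch (the import path and the `fireworks_api_key` field name are assumptions for illustration, not taken from this diff):

```python
# Sketch: construct the library client with provider data.
# The import path and the "fireworks_api_key" field name are assumptions.
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient(
    "fireworks",  # config path or template name
    provider_data={"fireworks_api_key": "YOUR_API_KEY"},
)
client.initialize()
# Subsequent client requests forward provider_data as the
# X-LlamaStack-Provider-Data header (see the diff below).
```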


## Test Plan

```bash
LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-fireworks/fireworks-run.yaml" \
  pytest -v tests/client-sdk/agents/test_agents.py
```

run.yaml:
https://gist.github.com/dineshyv/0c10b5c7d0a2fb7ba4f0ecc8dcf860d1
Author: Dinesh Yeduguru, 2025-01-13 15:12:10 -08:00 (committed by GitHub)
Commit: 314806cde3 (parent 6964510dc1)
10 changed files with 49 additions and 19 deletions

```diff
@@ -33,6 +33,7 @@ from termcolor import cprint
 from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
 from llama_stack.distribution.datatypes import Api
+from llama_stack.distribution.request_headers import set_request_provider_data
 from llama_stack.distribution.resolver import ProviderRegistry
 from llama_stack.distribution.server.endpoints import get_all_api_endpoints
 from llama_stack.distribution.stack import (
@@ -67,6 +68,7 @@ def stream_across_asyncio_run_boundary(
     async_gen_maker,
     pool_executor: ThreadPoolExecutor,
     path: Optional[str] = None,
+    provider_data: Optional[dict[str, Any]] = None,
 ) -> Generator[T, None, None]:
     result_queue = queue.Queue()
     stop_event = threading.Event()
@@ -75,6 +77,10 @@ def stream_across_asyncio_run_boundary(
         # make sure we make the generator in the event loop context
         gen = await async_gen_maker()
         await start_trace(path, {"__location__": "library_client"})
+        if provider_data:
+            set_request_provider_data(
+                {"X-LlamaStack-Provider-Data": json.dumps(provider_data)}
+            )
         try:
             async for item in await gen:
                 result_queue.put(item)
@@ -174,6 +180,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         config_path_or_template_name: str,
         skip_logger_removal: bool = False,
         custom_provider_registry: Optional[ProviderRegistry] = None,
+        provider_data: Optional[dict[str, Any]] = None,
     ):
         super().__init__()
         self.async_client = AsyncLlamaStackAsLibraryClient(
@@ -181,6 +188,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         )
         self.pool_executor = ThreadPoolExecutor(max_workers=4)
         self.skip_logger_removal = skip_logger_removal
+        self.provider_data = provider_data
 
     def initialize(self):
         if in_notebook():
@@ -219,10 +227,15 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
                 lambda: self.async_client.request(*args, **kwargs),
                 self.pool_executor,
                 path=path,
+                provider_data=self.provider_data,
             )
         else:
 
             async def _traced_request():
+                if self.provider_data:
+                    set_request_provider_data(
+                        {"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
+                    )
                 await start_trace(path, {"__location__": "library_client"})
                 try:
                     return await self.async_client.request(*args, **kwargs)
```
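
As the hunks above show, the client serializes `provider_data` to JSON and installs it as the `X-LlamaStack-Provider-Data` header via `set_request_provider_data` before each request, mirroring the header the HTTP server already accepts. A standalone sketch of that step (the payload field name is only an example):

```python
# Sketch: how provider_data becomes a request header (mirrors the diff above).
import json

from llama_stack.distribution.request_headers import set_request_provider_data

provider_data = {"fireworks_api_key": "YOUR_API_KEY"}  # example payload
set_request_provider_data(
    {"X-LlamaStack-Provider-Data": json.dumps(provider_data)}
)
```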