Add provider data passing for library client

This commit is contained in:
Dinesh Yeduguru 2025-01-13 13:25:36 -08:00
parent 78727aad26
commit 33b4e8df49
10 changed files with 49 additions and 19 deletions

View file

@ -33,6 +33,7 @@ from termcolor import cprint
from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import ProviderRegistry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.stack import (
@ -67,6 +68,7 @@ def stream_across_asyncio_run_boundary(
async_gen_maker,
pool_executor: ThreadPoolExecutor,
path: Optional[str] = None,
provider_data: Optional[dict[str, Any]] = None,
) -> Generator[T, None, None]:
result_queue = queue.Queue()
stop_event = threading.Event()
@ -75,6 +77,10 @@ def stream_across_asyncio_run_boundary(
# make sure we make the generator in the event loop context
gen = await async_gen_maker()
await start_trace(path, {"__location__": "library_client"})
if provider_data:
set_request_provider_data(
{"X-LlamaStack-Provider-Data": json.dumps(provider_data)}
)
try:
async for item in await gen:
result_queue.put(item)
@ -174,6 +180,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
config_path_or_template_name: str,
skip_logger_removal: bool = False,
custom_provider_registry: Optional[ProviderRegistry] = None,
provider_data: Optional[dict[str, Any]] = None,
):
super().__init__()
self.async_client = AsyncLlamaStackAsLibraryClient(
@ -181,6 +188,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
)
self.pool_executor = ThreadPoolExecutor(max_workers=4)
self.skip_logger_removal = skip_logger_removal
self.provider_data = provider_data
def initialize(self):
if in_notebook():
@ -219,10 +227,15 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
lambda: self.async_client.request(*args, **kwargs),
self.pool_executor,
path=path,
provider_data=self.provider_data,
)
else:
async def _traced_request():
if self.provider_data:
set_request_provider_data(
{"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
)
await start_trace(path, {"__location__": "library_client"})
try:
return await self.async_client.request(*args, **kwargs)