mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
Add provider data passing for library client (#750)
# What does this PR do? This PR adds the provider data passing for the library client and changes the provider's api keys be unique ## Test Plan LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-fireworks/fireworks-run.yaml" pytest -v tests/client-sdk/agents/test_agents.py run.yaml: https://gist.github.com/dineshyv/0c10b5c7d0a2fb7ba4f0ecc8dcf860d1
This commit is contained in:
parent
6964510dc1
commit
314806cde3
10 changed files with 49 additions and 19 deletions
|
@ -33,6 +33,7 @@ from termcolor import cprint
|
|||
from llama_stack.distribution.build import print_pip_install_help
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.resolver import ProviderRegistry
|
||||
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
||||
from llama_stack.distribution.stack import (
|
||||
|
@ -67,6 +68,7 @@ def stream_across_asyncio_run_boundary(
|
|||
async_gen_maker,
|
||||
pool_executor: ThreadPoolExecutor,
|
||||
path: Optional[str] = None,
|
||||
provider_data: Optional[dict[str, Any]] = None,
|
||||
) -> Generator[T, None, None]:
|
||||
result_queue = queue.Queue()
|
||||
stop_event = threading.Event()
|
||||
|
@ -75,6 +77,10 @@ def stream_across_asyncio_run_boundary(
|
|||
# make sure we make the generator in the event loop context
|
||||
gen = await async_gen_maker()
|
||||
await start_trace(path, {"__location__": "library_client"})
|
||||
if provider_data:
|
||||
set_request_provider_data(
|
||||
{"X-LlamaStack-Provider-Data": json.dumps(provider_data)}
|
||||
)
|
||||
try:
|
||||
async for item in await gen:
|
||||
result_queue.put(item)
|
||||
|
@ -174,6 +180,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
|||
config_path_or_template_name: str,
|
||||
skip_logger_removal: bool = False,
|
||||
custom_provider_registry: Optional[ProviderRegistry] = None,
|
||||
provider_data: Optional[dict[str, Any]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.async_client = AsyncLlamaStackAsLibraryClient(
|
||||
|
@ -181,6 +188,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
|||
)
|
||||
self.pool_executor = ThreadPoolExecutor(max_workers=4)
|
||||
self.skip_logger_removal = skip_logger_removal
|
||||
self.provider_data = provider_data
|
||||
|
||||
def initialize(self):
|
||||
if in_notebook():
|
||||
|
@ -219,10 +227,15 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
|||
lambda: self.async_client.request(*args, **kwargs),
|
||||
self.pool_executor,
|
||||
path=path,
|
||||
provider_data=self.provider_data,
|
||||
)
|
||||
else:
|
||||
|
||||
async def _traced_request():
|
||||
if self.provider_data:
|
||||
set_request_provider_data(
|
||||
{"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
|
||||
)
|
||||
await start_trace(path, {"__location__": "library_client"})
|
||||
try:
|
||||
return await self.async_client.request(*args, **kwargs)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue