Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-06 02:32:40 +00:00

fix provider data

parent 9e430c253d
commit f4ba0e997a

1 changed file with 15 additions and 26 deletions
@@ -117,7 +117,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
     ):
         super().__init__()
         self.async_client = AsyncLlamaStackAsLibraryClient(
-            config_path_or_template_name, custom_provider_registry
+            config_path_or_template_name, custom_provider_registry, provider_data
         )
         self.pool_executor = ThreadPoolExecutor(max_workers=4)
         self.skip_logger_removal = skip_logger_removal
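This hunk threads the new provider_data argument through the sync wrapper into the async client it delegates to. A minimal construction sketch, assuming the module path llama_stack.distribution.library_client and an illustrative template name and provider key (none of these specifics appear in the diff):

# Sketch: construct the sync library client with per-provider credentials.
# Import path, template name, and key are illustrative assumptions.
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient(
    "ollama",  # config path or template name
    provider_data={"tavily_search_api_key": "..."},  # forwarded to the async client
)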
@@ -143,18 +143,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
             root_logger.removeHandler(handler)
             print(f"Removed handler {handler.__class__.__name__} from root logger")
 
-    def _get_path(
-        self,
-        cast_to: Any,
-        options: Any,
-        *,
-        stream=False,
-        stream_cls=None,
-    ):
-        return options.url
-
     def request(self, *args, **kwargs):
-        path = self._get_path(*args, **kwargs)
         if kwargs.get("stream"):
             # NOTE: We are using AsyncLlamaStackClient under the hood
             # A new event loop is needed to convert the AsyncStream
@@ -180,18 +169,10 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
             return sync_generator()
         else:
 
-            async def _traced_request():
-                if self.provider_data:
-                    set_request_provider_data(
-                        {"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
-                    )
-                await start_trace(path, {"__location__": "library_client"})
-                try:
-                    return await self.async_client.request(*args, **kwargs)
-                finally:
-                    await end_trace()
+            async def _request():
+                return await self.async_client.request(*args, **kwargs)
 
-            return asyncio.run(_traced_request())
+            return asyncio.run(_request())
 
 
 class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
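With _get_path and _traced_request removed, the sync request method no longer handles tracing or provider data itself; it only bridges into the async client, which now owns that logic. A self-contained sketch of the sync-over-async bridge pattern used here (stand-in names, not the library's code):

# Sketch: wrap the awaitable in a local coroutine and drive it with
# asyncio.run(), which spins up a fresh event loop per call.
import asyncio

async def async_request() -> str:
    await asyncio.sleep(0)  # stand-in for the real awaitable request
    return "ok"

def request_sync() -> str:
    async def _request():
        return await async_request()
    return asyncio.run(_request())

print(request_sync())  # prints: ok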
@@ -199,9 +180,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         self,
         config_path_or_template_name: str,
         custom_provider_registry: Optional[ProviderRegistry] = None,
+        provider_data: Optional[dict[str, Any]] = None,
     ):
         super().__init__()
-
         # when using the library client, we should not log to console since many
         # of our logs are intended for server-side usage
         current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
@@ -222,6 +203,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         self.config_path_or_template_name = config_path_or_template_name
         self.config = config
         self.custom_provider_registry = custom_provider_registry
+        self.provider_data = provider_data
 
     async def initialize(self):
         try:
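These two hunks give the async client its own provider_data keyword and store it on the instance, so it no longer depends on the sync wrapper to inject it. A hypothetical direct construction, under the same module-path assumption as the sketch above (initialize is visible in the hunk context):

# Sketch: build the async client directly; provider_data is kept on the
# instance and applied per request (see the next hunk). Names are illustrative.
import asyncio
from llama_stack.distribution.library_client import AsyncLlamaStackAsLibraryClient

async def main():
    client = AsyncLlamaStackAsLibraryClient(
        "ollama",
        provider_data={"brave_search_api_key": "..."},
    )
    await client.initialize()

asyncio.run(main())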
@@ -278,17 +260,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         if not self.endpoint_impls:
             raise ValueError("Client not initialized")
 
+        if self.provider_data:
+            set_request_provider_data(
+                {"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
+            )
+        await start_trace(options.url, {"__location__": "library_client"})
         if stream:
-            return await self._call_streaming(
+            response = await self._call_streaming(
                 cast_to=cast_to,
                 options=options,
                 stream_cls=stream_cls,
             )
         else:
-            return await self._call_non_streaming(
+            response = await self._call_non_streaming(
                 cast_to=cast_to,
                 options=options,
             )
+        await end_trace()
+        return response
 
     async def _call_non_streaming(
         self,
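Per request, the async client now serializes the stored dict as JSON under the X-LlamaStack-Provider-Data key (the same shape a remote server would read from that HTTP header) and brackets both the streaming and non-streaming branches with start_trace/end_trace. A sketch of the payload construction only; the key name comes from this diff, the credential is made up:

# Sketch: the payload shape handed to set_request_provider_data().
import json

provider_data = {"together_api_key": "..."}  # hypothetical credential
payload = {"X-LlamaStack-Provider-Data": json.dumps(provider_data)}
# payload == {'X-LlamaStack-Provider-Data': '{"together_api_key": "..."}'}

One visible trade-off: the old sync _traced_request wrapped the request in try/finally around end_trace, while the new sequential await end_trace() is skipped if the request raises.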