fix provider data

This commit is contained in:
Xi Yan 2025-01-13 20:39:28 -08:00
parent 9e430c253d
commit f4ba0e997a

View file

@ -117,7 +117,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
): ):
super().__init__() super().__init__()
self.async_client = AsyncLlamaStackAsLibraryClient( self.async_client = AsyncLlamaStackAsLibraryClient(
config_path_or_template_name, custom_provider_registry config_path_or_template_name, custom_provider_registry, provider_data
) )
self.pool_executor = ThreadPoolExecutor(max_workers=4) self.pool_executor = ThreadPoolExecutor(max_workers=4)
self.skip_logger_removal = skip_logger_removal self.skip_logger_removal = skip_logger_removal
@ -143,18 +143,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
root_logger.removeHandler(handler) root_logger.removeHandler(handler)
print(f"Removed handler {handler.__class__.__name__} from root logger") print(f"Removed handler {handler.__class__.__name__} from root logger")
def _get_path(
self,
cast_to: Any,
options: Any,
*,
stream=False,
stream_cls=None,
):
return options.url
def request(self, *args, **kwargs): def request(self, *args, **kwargs):
path = self._get_path(*args, **kwargs)
if kwargs.get("stream"): if kwargs.get("stream"):
# NOTE: We are using AsyncLlamaStackClient under the hood # NOTE: We are using AsyncLlamaStackClient under the hood
# A new event loop is needed to convert the AsyncStream # A new event loop is needed to convert the AsyncStream
@ -180,18 +169,10 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
return sync_generator() return sync_generator()
else: else:
async def _traced_request(): async def _request():
if self.provider_data: return await self.async_client.request(*args, **kwargs)
set_request_provider_data(
{"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
)
await start_trace(path, {"__location__": "library_client"})
try:
return await self.async_client.request(*args, **kwargs)
finally:
await end_trace()
return asyncio.run(_traced_request()) return asyncio.run(_request())
class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
@ -199,9 +180,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
self, self,
config_path_or_template_name: str, config_path_or_template_name: str,
custom_provider_registry: Optional[ProviderRegistry] = None, custom_provider_registry: Optional[ProviderRegistry] = None,
provider_data: Optional[dict[str, Any]] = None,
): ):
super().__init__() super().__init__()
# when using the library client, we should not log to console since many # when using the library client, we should not log to console since many
# of our logs are intended for server-side usage # of our logs are intended for server-side usage
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",") current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
@ -222,6 +203,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
self.config_path_or_template_name = config_path_or_template_name self.config_path_or_template_name = config_path_or_template_name
self.config = config self.config = config
self.custom_provider_registry = custom_provider_registry self.custom_provider_registry = custom_provider_registry
self.provider_data = provider_data
async def initialize(self): async def initialize(self):
try: try:
@ -278,17 +260,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
if not self.endpoint_impls: if not self.endpoint_impls:
raise ValueError("Client not initialized") raise ValueError("Client not initialized")
if self.provider_data:
set_request_provider_data(
{"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
)
await start_trace(options.url, {"__location__": "library_client"})
if stream: if stream:
return await self._call_streaming( response = await self._call_streaming(
cast_to=cast_to, cast_to=cast_to,
options=options, options=options,
stream_cls=stream_cls, stream_cls=stream_cls,
) )
else: else:
return await self._call_non_streaming( response = await self._call_non_streaming(
cast_to=cast_to, cast_to=cast_to,
options=options, options=options,
) )
await end_trace()
return response
async def _call_non_streaming( async def _call_non_streaming(
self, self,