pre-commit fixes

This commit is contained in:
Chantal D Gama Rose 2025-03-14 13:56:05 -07:00
parent 967dd0aa08
commit 7e211f8553
314 changed files with 5574 additions and 11369 deletions

View file

@ -32,7 +32,10 @@ from termcolor import cprint
from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.request_headers import (
PROVIDER_DATA_VAR,
request_provider_data_context,
)
from llama_stack.distribution.resolver import ProviderRegistry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.stack import (
@ -41,8 +44,10 @@ from llama_stack.distribution.stack import (
redact_sensitive_fields,
replace_env_vars,
)
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
from llama_stack.distribution.utils.exec import in_notebook
from llama_stack.providers.utils.telemetry.tracing import (
CURRENT_TRACE_CONTEXT,
end_trace,
setup_logger,
start_trace,
@ -160,6 +165,9 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
except StopAsyncIteration:
pass
finally:
pending = asyncio.all_tasks(loop)
if pending:
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
loop.close()
return sync_generator()
@ -262,21 +270,25 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
if not self.endpoint_impls:
raise ValueError("Client not initialized")
# Create headers with provider data if available
headers = {}
if self.provider_data:
set_request_provider_data({"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)})
headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)
if stream:
response = await self._call_streaming(
cast_to=cast_to,
options=options,
stream_cls=stream_cls,
)
else:
response = await self._call_non_streaming(
cast_to=cast_to,
options=options,
)
return response
# Use context manager for provider data
with request_provider_data_context(headers):
if stream:
response = await self._call_streaming(
cast_to=cast_to,
options=options,
stream_cls=stream_cls,
)
else:
response = await self._call_non_streaming(
cast_to=cast_to,
options=options,
)
return response
def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]:
"""Find the matching endpoint implementation for a given method and path.
@ -374,9 +386,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
finally:
await end_trace()
wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR])
mock_response = httpx.Response(
status_code=httpx.codes.OK,
content=gen(),
content=wrapped_gen,
headers={
"Content-Type": "application/json",
},