mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-12 13:57:57 +00:00
fix(tests): handle TEST_CONTEXT not being set
This commit is contained in:
parent
dac1d7be1c
commit
f365961731
1 changed files with 10 additions and 3 deletions
|
@ -250,6 +250,8 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
||||||
await log_request_pre_validation(request)
|
await log_request_pre_validation(request)
|
||||||
|
|
||||||
test_context_token = None
|
test_context_token = None
|
||||||
|
test_context_var = None
|
||||||
|
reset_test_context_fn = None
|
||||||
|
|
||||||
# Use context manager with both provider data and auth attributes
|
# Use context manager with both provider data and auth attributes
|
||||||
with request_provider_data_context(request.headers, user):
|
with request_provider_data_context(request.headers, user):
|
||||||
|
@ -261,13 +263,18 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
||||||
)
|
)
|
||||||
|
|
||||||
test_context_token = sync_test_context_from_provider_data()
|
test_context_token = sync_test_context_from_provider_data()
|
||||||
|
test_context_var = TEST_CONTEXT
|
||||||
|
reset_test_context_fn = reset_test_context
|
||||||
|
|
||||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if is_streaming:
|
if is_streaming:
|
||||||
|
context_vars = [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
|
||||||
|
if test_context_var is not None:
|
||||||
|
context_vars.append(test_context_var)
|
||||||
gen = preserve_contexts_async_generator(
|
gen = preserve_contexts_async_generator(
|
||||||
sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR, TEST_CONTEXT]
|
sse_generator(func(**kwargs)), context_vars
|
||||||
)
|
)
|
||||||
return StreamingResponse(gen, media_type="text/event-stream")
|
return StreamingResponse(gen, media_type="text/event-stream")
|
||||||
else:
|
else:
|
||||||
|
@ -287,8 +294,8 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
||||||
logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
|
logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
|
||||||
raise translate_exception(e) from e
|
raise translate_exception(e) from e
|
||||||
finally:
|
finally:
|
||||||
if test_context_token is not None:
|
if test_context_token is not None and reset_test_context_fn is not None:
|
||||||
reset_test_context(test_context_token)
|
reset_test_context_fn(test_context_token)
|
||||||
|
|
||||||
sig = inspect.signature(func)
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue