mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-10 04:08:31 +00:00
fix tracing and add tests
This commit is contained in:
parent
714c09cd53
commit
518a5f898c
3 changed files with 170 additions and 16 deletions
|
@ -376,19 +376,18 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
|
||||
body = self._convert_body(path, options.method, body)
|
||||
|
||||
await start_trace(options.url, {"__location__": "library_client"})
|
||||
|
||||
async def gen():
|
||||
await start_trace(options.url, {"__location__": "library_client"})
|
||||
try:
|
||||
async for chunk in await func(**body):
|
||||
data = json.dumps(convert_pydantic_to_json_value(chunk))
|
||||
sse_event = f"data: {data}\n\n"
|
||||
yield sse_event.encode("utf-8")
|
||||
|
||||
try:
|
||||
wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR])
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR])
|
||||
|
||||
mock_response = httpx.Response(
|
||||
status_code=httpx.codes.OK,
|
||||
content=wrapped_gen,
|
||||
|
|
|
@ -14,19 +14,19 @@ def preserve_contexts_async_generator(
|
|||
gen: AsyncGenerator[T, None], context_vars: List[ContextVar]
|
||||
) -> AsyncGenerator[T, None]:
|
||||
"""
|
||||
Wraps an async generator to preserve both tracing and headers context variables across iterations.
|
||||
This is needed because we start a new asyncio event loop for each request, and we need to preserve the context
|
||||
across the event loop boundary.
|
||||
Wraps an async generator to preserve context variables across iterations.
|
||||
This is needed because we start a new asyncio event loop for each streaming request,
|
||||
and we need to preserve the context across the event loop boundary.
|
||||
"""
|
||||
context_values = [context_var.get() for context_var in context_vars]
|
||||
|
||||
async def wrapper():
|
||||
while True:
|
||||
for context_var, context_value in zip(context_vars, context_values, strict=False):
|
||||
_ = context_var.set(context_value)
|
||||
try:
|
||||
item = await gen.__anext__()
|
||||
context_values = {context_var.name: context_var.get() for context_var in context_vars}
|
||||
yield item
|
||||
for context_var in context_vars:
|
||||
_ = context_var.set(context_values[context_var.name])
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
|
|
155
llama_stack/distribution/utils/tests/test_context.py
Normal file
155
llama_stack/distribution/utils/tests/test_context.py
Normal file
|
@ -0,0 +1,155 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextvars import ContextVar
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserve_contexts_with_exception():
|
||||
# Create context variable
|
||||
context_var = ContextVar("exception_var", default="initial")
|
||||
token = context_var.set("start_value")
|
||||
|
||||
# Create an async generator that raises an exception
|
||||
async def exception_generator():
|
||||
yield context_var.get()
|
||||
context_var.set("modified")
|
||||
raise ValueError("Test exception")
|
||||
yield None # This will never be reached
|
||||
|
||||
# Wrap the generator
|
||||
wrapped_gen = preserve_contexts_async_generator(exception_generator(), [context_var])
|
||||
|
||||
# First iteration should work
|
||||
value = await wrapped_gen.__anext__()
|
||||
assert value == "start_value"
|
||||
|
||||
# Second iteration should raise the exception
|
||||
with pytest.raises(ValueError, match="Test exception"):
|
||||
await wrapped_gen.__anext__()
|
||||
|
||||
# Clean up
|
||||
context_var.reset(token)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserve_contexts_empty_generator():
|
||||
# Create context variable
|
||||
context_var = ContextVar("empty_var", default="initial")
|
||||
token = context_var.set("value")
|
||||
|
||||
# Create an empty async generator
|
||||
async def empty_generator():
|
||||
if False: # This condition ensures the generator yields nothing
|
||||
yield None
|
||||
|
||||
# Wrap the generator
|
||||
wrapped_gen = preserve_contexts_async_generator(empty_generator(), [context_var])
|
||||
|
||||
# The generator should raise StopAsyncIteration immediately
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await wrapped_gen.__anext__()
|
||||
|
||||
# Context variable should remain unchanged
|
||||
assert context_var.get() == "value"
|
||||
|
||||
# Clean up
|
||||
context_var.reset(token)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserve_contexts_across_event_loops():
|
||||
"""
|
||||
Test that context variables are preserved across event loop boundaries with nested generators.
|
||||
This simulates the real-world scenario where:
|
||||
1. A new event loop is created for each streaming request
|
||||
2. The async generator runs inside that loop
|
||||
3. There are multiple levels of nested generators
|
||||
4. Context needs to be preserved across these boundaries
|
||||
"""
|
||||
# Create context variables
|
||||
request_id = ContextVar("request_id", default=None)
|
||||
user_id = ContextVar("user_id", default=None)
|
||||
|
||||
# Set initial values
|
||||
|
||||
# Results container to verify values across thread boundaries
|
||||
results = []
|
||||
|
||||
# Inner-most generator (level 2)
|
||||
async def inner_generator():
|
||||
# Should have the context from the outer scope
|
||||
yield (1, request_id.get(), user_id.get())
|
||||
|
||||
# Modify one context variable
|
||||
user_id.set("user-modified")
|
||||
|
||||
# Should reflect the modification
|
||||
yield (2, request_id.get(), user_id.get())
|
||||
|
||||
# Middle generator (level 1)
|
||||
async def middle_generator():
|
||||
inner_gen = inner_generator()
|
||||
|
||||
# Forward the first yield from inner
|
||||
item = await inner_gen.__anext__()
|
||||
yield item
|
||||
|
||||
# Forward the second yield from inner
|
||||
item = await inner_gen.__anext__()
|
||||
yield item
|
||||
|
||||
request_id.set("req-modified")
|
||||
|
||||
# Add our own yield with both modified variables
|
||||
yield (3, request_id.get(), user_id.get())
|
||||
|
||||
# Function to run in a separate thread with a new event loop
|
||||
def run_in_new_loop():
|
||||
# Create a new event loop for this thread
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
# Outer generator (runs in the new loop)
|
||||
async def outer_generator():
|
||||
request_id.set("req-12345")
|
||||
user_id.set("user-6789")
|
||||
# Wrap the middle generator
|
||||
wrapped_gen = preserve_contexts_async_generator(middle_generator(), [request_id, user_id])
|
||||
|
||||
# Process all items from the middle generator
|
||||
async for item in wrapped_gen:
|
||||
# Store results for verification
|
||||
results.append(item)
|
||||
|
||||
# Run the outer generator in the new loop
|
||||
loop.run_until_complete(outer_generator())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# Run the generator chain in a separate thread with a new event loop
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(run_in_new_loop)
|
||||
future.result() # Wait for completion
|
||||
|
||||
# Verify the results
|
||||
assert len(results) == 3
|
||||
|
||||
# First yield should have original values
|
||||
assert results[0] == (1, "req-12345", "user-6789")
|
||||
|
||||
# Second yield should have modified user_id
|
||||
assert results[1] == (2, "req-12345", "user-modified")
|
||||
|
||||
# Third yield should have both modified values
|
||||
assert results[2] == (3, "req-modified", "user-modified")
|
Loading…
Add table
Add a link
Reference in a new issue