fix tracing and add tests

mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-10 04:08:31 +00:00)
commit 518a5f898c, parent 714c09cd53
3 changed files with 170 additions and 16 deletions
@@ -376,18 +376,17 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         body = self._convert_body(path, options.method, body)
 
-        await start_trace(options.url, {"__location__": "library_client"})
-
         async def gen():
-            async for chunk in await func(**body):
-                data = json.dumps(convert_pydantic_to_json_value(chunk))
-                sse_event = f"data: {data}\n\n"
-                yield sse_event.encode("utf-8")
+            await start_trace(options.url, {"__location__": "library_client"})
+            try:
+                async for chunk in await func(**body):
+                    data = json.dumps(convert_pydantic_to_json_value(chunk))
+                    sse_event = f"data: {data}\n\n"
+                    yield sse_event.encode("utf-8")
+            finally:
+                await end_trace()
 
-        try:
-            wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR])
-        finally:
-            await end_trace()
+        wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR])
 
         mock_response = httpx.Response(
             status_code=httpx.codes.OK,
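The change moves start_trace and end_trace inside gen(), so the trace is opened and closed on the same event loop that actually consumes the SSE stream, instead of end_trace firing right after the wrapped generator is created. Below is a minimal sketch of the underlying failure mode this works around; it is not from the commit, and the names (trace_ctx, produce, drive_on_fresh_loop) are illustrative: a ContextVar set on the request thread is not visible to a generator driven from a fresh event loop in another thread.

# Minimal sketch (illustrative only): context does not cross the loop/thread boundary.
import asyncio
from concurrent.futures import ThreadPoolExecutor
from contextvars import ContextVar

trace_ctx: ContextVar = ContextVar("trace_ctx", default=None)

async def produce():
    # Runs in the context of whichever loop/thread drives it.
    yield trace_ctx.get()

def drive_on_fresh_loop(gen):
    # A brand-new event loop in a worker thread starts from that thread's
    # (empty) context, so values set on the request thread are not visible here.
    loop = asyncio.new_event_loop()
    try:
        value = loop.run_until_complete(gen.__anext__())
        loop.run_until_complete(gen.aclose())
        return value
    finally:
        loop.close()

trace_ctx.set("trace-for-this-request")  # set on the "request" thread
with ThreadPoolExecutor(max_workers=1) as pool:
    # The worker thread never sees the value set above.
    assert pool.submit(drive_on_fresh_loop, produce()).result() is None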
llama_stack/distribution/utils/context.py

@@ -14,19 +14,19 @@ def preserve_contexts_async_generator(
     gen: AsyncGenerator[T, None], context_vars: List[ContextVar]
 ) -> AsyncGenerator[T, None]:
     """
-    Wraps an async generator to preserve both tracing and headers context variables across iterations.
-    This is needed because we start a new asyncio event loop for each request, and we need to preserve the context
-    across the event loop boundary.
+    Wraps an async generator to preserve context variables across iterations.
+    This is needed because we start a new asyncio event loop for each streaming request,
+    and we need to preserve the context across the event loop boundary.
     """
-    context_values = [context_var.get() for context_var in context_vars]
 
     async def wrapper():
         while True:
-            for context_var, context_value in zip(context_vars, context_values, strict=False):
-                _ = context_var.set(context_value)
             try:
                 item = await gen.__anext__()
+                context_values = {context_var.name: context_var.get() for context_var in context_vars}
                 yield item
+                for context_var in context_vars:
+                    _ = context_var.set(context_values[context_var.name])
             except StopAsyncIteration:
                 break
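With this change the wrapper snapshots the variables' values right after each item is produced and re-applies them when the caller resumes the generator, so values set inside the wrapped generator survive the hop back across the event loop boundary. A hypothetical usage sketch, assuming only the import path shown in the new test file below; the generator and variable names (sse_events, REQUEST_ID) are illustrative, not from the repository:

from contextvars import ContextVar

from llama_stack.distribution.utils.context import preserve_contexts_async_generator

REQUEST_ID: ContextVar = ContextVar("request_id", default=None)

async def sse_events():
    # Stand-in for the library client's chunk generator.
    for i in range(3):
        yield f"data: chunk-{i}\n\n".encode("utf-8")

async def stream_response():
    REQUEST_ID.set("req-42")
    # Values of the listed ContextVars are carried across iterations,
    # even when the wrapped generator is driven from another event loop.
    wrapped = preserve_contexts_async_generator(sse_events(), [REQUEST_ID])
    async for event in wrapped:
        print(event)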
llama_stack/distribution/utils/tests/test_context.py (new file, 155 lines)

@@ -0,0 +1,155 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
from concurrent.futures import ThreadPoolExecutor
from contextvars import ContextVar

import pytest

from llama_stack.distribution.utils.context import preserve_contexts_async_generator


@pytest.mark.asyncio
async def test_preserve_contexts_with_exception():
    # Create context variable
    context_var = ContextVar("exception_var", default="initial")
    token = context_var.set("start_value")

    # Create an async generator that raises an exception
    async def exception_generator():
        yield context_var.get()
        context_var.set("modified")
        raise ValueError("Test exception")
        yield None  # This will never be reached

    # Wrap the generator
    wrapped_gen = preserve_contexts_async_generator(exception_generator(), [context_var])

    # First iteration should work
    value = await wrapped_gen.__anext__()
    assert value == "start_value"

    # Second iteration should raise the exception
    with pytest.raises(ValueError, match="Test exception"):
        await wrapped_gen.__anext__()

    # Clean up
    context_var.reset(token)


@pytest.mark.asyncio
async def test_preserve_contexts_empty_generator():
    # Create context variable
    context_var = ContextVar("empty_var", default="initial")
    token = context_var.set("value")

    # Create an empty async generator
    async def empty_generator():
        if False:  # This condition ensures the generator yields nothing
            yield None

    # Wrap the generator
    wrapped_gen = preserve_contexts_async_generator(empty_generator(), [context_var])

    # The generator should raise StopAsyncIteration immediately
    with pytest.raises(StopAsyncIteration):
        await wrapped_gen.__anext__()

    # Context variable should remain unchanged
    assert context_var.get() == "value"

    # Clean up
    context_var.reset(token)


@pytest.mark.asyncio
async def test_preserve_contexts_across_event_loops():
    """
    Test that context variables are preserved across event loop boundaries with nested generators.
    This simulates the real-world scenario where:
    1. A new event loop is created for each streaming request
    2. The async generator runs inside that loop
    3. There are multiple levels of nested generators
    4. Context needs to be preserved across these boundaries
    """
    # Create context variables
    request_id = ContextVar("request_id", default=None)
    user_id = ContextVar("user_id", default=None)

    # Set initial values

    # Results container to verify values across thread boundaries
    results = []

    # Inner-most generator (level 2)
    async def inner_generator():
        # Should have the context from the outer scope
        yield (1, request_id.get(), user_id.get())

        # Modify one context variable
        user_id.set("user-modified")

        # Should reflect the modification
        yield (2, request_id.get(), user_id.get())

    # Middle generator (level 1)
    async def middle_generator():
        inner_gen = inner_generator()

        # Forward the first yield from inner
        item = await inner_gen.__anext__()
        yield item

        # Forward the second yield from inner
        item = await inner_gen.__anext__()
        yield item

        request_id.set("req-modified")

        # Add our own yield with both modified variables
        yield (3, request_id.get(), user_id.get())

    # Function to run in a separate thread with a new event loop
    def run_in_new_loop():
        # Create a new event loop for this thread
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        try:
            # Outer generator (runs in the new loop)
            async def outer_generator():
                request_id.set("req-12345")
                user_id.set("user-6789")

                # Wrap the middle generator
                wrapped_gen = preserve_contexts_async_generator(middle_generator(), [request_id, user_id])

                # Process all items from the middle generator
                async for item in wrapped_gen:
                    # Store results for verification
                    results.append(item)

            # Run the outer generator in the new loop
            loop.run_until_complete(outer_generator())
        finally:
            loop.close()

    # Run the generator chain in a separate thread with a new event loop
    with ThreadPoolExecutor(max_workers=1) as executor:
        future = executor.submit(run_in_new_loop)
        future.result()  # Wait for completion

    # Verify the results
    assert len(results) == 3

    # First yield should have original values
    assert results[0] == (1, "req-12345", "user-6789")

    # Second yield should have modified user_id
    assert results[1] == (2, "req-12345", "user-modified")

    # Third yield should have both modified values
    assert results[2] == (3, "req-modified", "user-modified")
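The new tests cover an exception raised mid-stream, an empty generator, and nested generators driven from a fresh event loop in a worker thread. Given the @pytest.mark.asyncio markers, they presumably require the pytest-asyncio plugin and can be run with something like: pytest llama_stack/distribution/utils/tests/test_context.py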