forked from phoenix-oss/llama-stack-mirror
# What does this PR do? This PR adds back the changes in #1300 which were reverted in #1476 . It also adds logic to preserve context variables across asyncio boundary. this is needed with the library client since the async generator logic yields control to code outside the event loop, and on resuming, does not have the same context as before and this requires preserving the context vars. address #1477 ## Test Plan ``` curl --request POST \ --url http://localhost:8321/v1/inference/chat-completion \ --header 'content-type: application/json' \ --data '{ "model_id": "meta-llama/Llama-3.1-70B-Instruct", "messages": [ { "role": "user", "content": { "type": "text", "text": "where do humans live" } } ], "stream": false }' | jq . { "metrics": [ { "trace_id": "kCZwO3tyQC-FuAGb", "span_id": "bsP_5a5O", "timestamp": "2025-03-11T16:47:38.549084Z", "attributes": { "model_id": "meta-llama/Llama-3.1-70B-Instruct", "provider_id": "fireworks" }, "type": "metric", "metric": "prompt_tokens", "value": 10, "unit": "tokens" }, { "trace_id": "kCZwO3tyQC-FuAGb", "span_id": "bsP_5a5O", "timestamp": "2025-03-11T16:47:38.549449Z", "attributes": { "model_id": "meta-llama/Llama-3.1-70B-Instruct", "provider_id": "fireworks" }, "type": "metric", "metric": "completion_tokens", "value": 369, "unit": "tokens" }, { "trace_id": "kCZwO3tyQC-FuAGb", "span_id": "bsP_5a5O", "timestamp": "2025-03-11T16:47:38.549457Z", "attributes": { "model_id": "meta-llama/Llama-3.1-70B-Instruct", "provider_id": "fireworks" }, "type": "metric", "metric": "total_tokens", "value": 379, "unit": "tokens" } ], "completion_message": { "role": "assistant", "content": "Humans live on the planet Earth, specifically on its landmasses and in its oceans. Here's a breakdown of where humans live:\n\n1. **Continents:** Humans inhabit all seven continents:\n\t* Africa\n\t* Antarctica ( temporary residents, mostly scientists and researchers)\n\t* Asia\n\t* Australia\n\t* Europe\n\t* North America\n\t* South America\n2. 
**Countries:** There are 196 countries recognized by the United Nations, and humans live in almost all of them.\n3. **Cities and towns:** Many humans live in urban areas, such as cities and towns, which are often located near coastlines, rivers, or other bodies of water.\n4. **Rural areas:** Some humans live in rural areas, such as villages, farms, and countryside.\n5. **Islands:** Humans inhabit many islands around the world, including those in the Pacific, Indian, and Atlantic Oceans.\n6. **Mountains and highlands:** Humans live in mountainous regions, such as the Himalayas, the Andes, and the Rocky Mountains.\n7. **Deserts:** Some humans live in desert regions, such as the Sahara, the Mojave, and the Atacama.\n8. **Coastal areas:** Many humans live in coastal areas, such as beaches, ports, and coastal cities.\n9. **Underwater habitats:** A few humans live in underwater habitats, such as research stations and submarines.\n10. **Space:** A small number of humans have lived in space, including astronauts on the International Space Station and those who have visited the Moon.\n\nOverall, humans can be found living in almost every environment on Earth, from the frozen tundra to the hottest deserts, and from the highest mountains to the deepest oceans.", "stop_reason": "end_of_turn", "tool_calls": [] }, "logprobs": null } ``` Original repro no longer shows any error: ``` LLAMA_STACK_DISABLE_VERSION_CHECK=true llama stack run ~/.llama/distributions/fireworks/fireworks-run.yaml python -m examples.agents.e2e_loop_with_client_tools localhost 8321 ``` client logs: https://gist.github.com/dineshyv/047c7e87b18a5792aa660e311ea53166 server logs: https://gist.github.com/dineshyv/97a2174099619e9916c7c490be26e559
155 lines
5 KiB
Python
155 lines
5 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import asyncio
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from contextvars import ContextVar
|
|
|
|
import pytest
|
|
|
|
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_preserve_contexts_with_exception():
    """An exception raised inside the wrapped generator must reach the consumer.

    The first item is produced normally and reflects the value the caller set
    on the context variable. On the second step the generator modifies the
    variable and raises ``ValueError``. The test expects that exception to
    surface from ``preserve_contexts_async_generator``'s wrapper unchanged.
    """
    context_var = ContextVar("exception_var", default="initial")
    token = context_var.set("start_value")

    # The first ``yield`` already makes this an async generator, so no
    # unreachable trailing ``yield`` is needed after the ``raise``.
    async def exception_generator():
        yield context_var.get()
        context_var.set("modified")
        raise ValueError("Test exception")

    try:
        wrapped_gen = preserve_contexts_async_generator(exception_generator(), [context_var])

        # First step: the generator sees the caller's value.
        value = await wrapped_gen.__anext__()
        assert value == "start_value"

        # Second step: the generator raises after touching the variable.
        with pytest.raises(ValueError, match="Test exception"):
            await wrapped_gen.__anext__()
    finally:
        # Restore the variable even when an assertion above fails.
        context_var.reset(token)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_preserve_contexts_empty_generator():
    """Wrapping a generator that produces nothing ends iteration immediately.

    The caller's value on the context variable must be left as it was.
    """
    empty_var = ContextVar("empty_var", default="initial")
    reset_token = empty_var.set("value")

    async def empty_generator():
        # Returning before the ``yield`` keeps this an async generator that
        # finishes without producing a single item.
        return
        yield None

    wrapped = preserve_contexts_async_generator(empty_generator(), [empty_var])

    # Exhausted on the very first step.
    with pytest.raises(StopAsyncIteration):
        await wrapped.__anext__()

    # The caller still observes the value it set before wrapping.
    assert empty_var.get() == "value"

    empty_var.reset(reset_token)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_preserve_contexts_across_event_loops():
    """
    Test that context variables are preserved across event loop boundaries with nested generators.

    This simulates the real-world scenario where:
    1. A new event loop is created for each streaming request
    2. The async generator runs inside that loop
    3. There are multiple levels of nested generators
    4. Context needs to be preserved across these boundaries
    """
    request_id = ContextVar("request_id", default=None)
    user_id = ContextVar("user_id", default=None)

    # Items collected in the worker thread; inspected here after the thread finishes.
    results = []

    # Inner-most generator (level 2).
    async def inner_generator():
        # Should see the values set by the outer consumer.
        yield (1, request_id.get(), user_id.get())

        user_id.set("user-modified")

        # Should reflect the modification made just above.
        yield (2, request_id.get(), user_id.get())

    # Middle generator (level 1): forwards two items from the inner generator, then adds its own.
    async def middle_generator():
        inner_gen = inner_generator()

        item = await inner_gen.__anext__()
        yield item

        item = await inner_gen.__anext__()
        yield item

        # inner_gen is still suspended after its second yield. Close it
        # explicitly. Otherwise its cleanup is left to the loop's async-generator
        # finalizer hook, which schedules aclose() as a task that may never run
        # before the loop is closed.
        await inner_gen.aclose()

        request_id.set("req-modified")

        # Both variables now carry modified values.
        yield (3, request_id.get(), user_id.get())

    # Runs in a worker thread with its own, freshly created event loop.
    def run_in_new_loop():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        # A coroutine (not a generator) that drives the wrapped chain inside the new loop.
        async def outer_consumer():
            # Initial values are set here, inside the new loop's task.
            request_id.set("req-12345")
            user_id.set("user-6789")

            wrapped_gen = preserve_contexts_async_generator(middle_generator(), [request_id, user_id])
            async for item in wrapped_gen:
                results.append(item)

        try:
            loop.run_until_complete(outer_consumer())
        finally:
            # Follow the teardown asyncio.run performs: finalize async
            # generators still registered with the loop, detach the loop from
            # this thread, then close it.
            try:
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                asyncio.set_event_loop(None)
                loop.close()

    # A separate thread is needed because run_until_complete cannot be called
    # from a thread whose event loop is already running (the test's own loop).
    with ThreadPoolExecutor(max_workers=1) as executor:
        # result() waits for completion and re-raises any failure from the worker.
        executor.submit(run_in_new_loop).result()

    assert len(results) == 3

    # First yield should have the original values.
    assert results[0] == (1, "req-12345", "user-6789")

    # Second yield should have the modified user_id.
    assert results[1] == (2, "req-12345", "user-modified")

    # Third yield should have both modified values.
    assert results[2] == (3, "req-modified", "user-modified")
|