Merge-related changes.

This commit is contained in:
ilya-kolchinsky 2025-04-02 19:56:44 +02:00
commit 60e9f46856
456 changed files with 38636 additions and 10892 deletions

View file

@ -41,10 +41,10 @@ async def execute_preprocessor_chain(
preprocessor_inputs: List[PreprocessingDataElement],
) -> PreprocessorResponse:
if not validate_chain(preprocessor_chain_impls):
return PreprocessorResponse(success=False, results=[])
return PreprocessorResponse(success=False, output_data_type=None, results=[])
current_inputs = preprocessor_inputs
current_outputs = []
current_outputs: List[PreprocessingDataElement] | None = []
current_result_type = None
# TODO: replace with a parallel implementation
@ -59,6 +59,9 @@ async def execute_preprocessor_chain(
log.error(f"Preprocessor {current_params.preprocessor_id} returned an error")
return PreprocessorResponse(success=False, output_data_type=response.output_data_type, results=[])
current_outputs = response.results
if current_outputs is None:
log.error(f"Preprocessor {current_params.preprocessor_id} returned invalid results")
return PreprocessorResponse(success=False, output_data_type=response.output_data_type, results=[])
current_inputs = current_outputs
current_result_type = response.output_data_type

View file

@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from contextvars import ContextVar
from typing import AsyncGenerator, List, TypeVar
T = TypeVar("T")
def preserve_contexts_async_generator(
gen: AsyncGenerator[T, None], context_vars: List[ContextVar]
) -> AsyncGenerator[T, None]:
"""
Wraps an async generator to preserve context variables across iterations.
This is needed because we start a new asyncio event loop for each streaming request,
and we need to preserve the context across the event loop boundary.
"""
# Capture initial context values
initial_context_values = {context_var.name: context_var.get() for context_var in context_vars}
async def wrapper() -> AsyncGenerator[T, None]:
while True:
try:
# Restore context values before any await
for context_var in context_vars:
context_var.set(initial_context_values[context_var.name])
item = await gen.__anext__()
yield item
except StopAsyncIteration:
break
return wrapper()

View file

@ -4,13 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import errno
import logging
import os
import select
import signal
import subprocess
import sys
from termcolor import cprint
@ -88,13 +85,6 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
return run_args
def run_with_pty(command):
if sys.platform.startswith("win"):
return _run_with_pty_win(command)
else:
return _run_with_pty_unix(command)
def in_notebook():
try:
from IPython import get_ipython
@ -108,19 +98,19 @@ def in_notebook():
return True
# run a command in a pseudo-terminal, with interrupt handling,
# useful when you want to run interactive things
def _run_with_pty_unix(command):
import pty
import termios
def run_command(command: list[str]) -> int:
"""
Run a command with interrupt handling and output capture.
Uses subprocess.run with direct stream piping for better performance.
master, slave = pty.openpty()
Args:
command (list): The command to run.
old_settings = termios.tcgetattr(sys.stdin)
Returns:
int: The return code of the command.
"""
original_sigint = signal.getsignal(signal.SIGINT)
ctrl_c_pressed = False
process = None
def sigint_handler(signum, frame):
nonlocal ctrl_c_pressed
@ -131,106 +121,19 @@ def _run_with_pty_unix(command):
# Set up the signal handler
signal.signal(signal.SIGINT, sigint_handler)
new_settings = termios.tcgetattr(sys.stdin)
new_settings[3] = new_settings[3] & ~termios.ECHO # Disable echo
new_settings[3] = new_settings[3] & ~termios.ICANON # Disable canonical mode
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, new_settings)
process = subprocess.Popen(
# Run the command with stdout/stderr piped directly to system streams
result = subprocess.run(
command,
stdin=slave,
stdout=slave,
stderr=slave,
universal_newlines=True,
preexec_fn=os.setsid,
text=True,
check=False,
)
# Close the slave file descriptor as it's now owned by the subprocess
os.close(slave)
def handle_io():
while not ctrl_c_pressed:
try:
rlist, _, _ = select.select([sys.stdin, master], [], [], 0.1)
if sys.stdin in rlist:
data = os.read(sys.stdin.fileno(), 1024)
if not data:
break
os.write(master, data)
if master in rlist:
data = os.read(master, 1024)
if not data:
break
sys.stdout.buffer.write(data)
sys.stdout.flush()
except KeyboardInterrupt:
# This will be raised when Ctrl+C is pressed
break
if process.poll() is not None:
break
handle_io()
except (EOFError, KeyboardInterrupt):
pass
except OSError as e:
if e.errno != errno.EIO:
raise
finally:
# Clean up
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
signal.signal(signal.SIGINT, original_sigint)
os.close(master)
if process and process.poll() is None:
process.terminate()
process.wait()
return process.returncode
# run a command in a pseudo-terminal in windows, with interrupt handling,
def _run_with_pty_win(command):
"""
Runs a command with interactive support using subprocess directly.
"""
try:
# For shell scripts on Windows, use appropriate shell
if isinstance(command, (list, tuple)):
if command[0].endswith(".sh"):
if os.path.exists("/usr/bin/bash"): # WSL
command = ["bash"] + command
else:
# Use cmd.exe with bash while preserving all arguments
command = ["cmd.exe", "/c", "bash"] + command
process = subprocess.Popen(
command,
shell=True,
universal_newlines=True,
)
process.wait()
return result.returncode
except subprocess.SubprocessError as e:
log.error(f"Subprocess error: {e}")
return 1
except Exception as e:
print(f"Error: {str(e)}")
log.exception(f"Unexpected error: {e}")
return 1
finally:
if process and process.poll() is None:
process.terminate()
process.wait()
return process.returncode
def run_command(command):
try:
result = subprocess.run(command, capture_output=True, text=True, check=True)
print("Script Output\n", result.stdout)
return result.returncode
except subprocess.CalledProcessError as e:
print("Error running script:", e)
print("Error output:", e.stderr)
return e.returncode
# Restore the original signal handler
signal.signal(signal.SIGINT, original_sigint)

View file

@ -0,0 +1,155 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from concurrent.futures import ThreadPoolExecutor
from contextvars import ContextVar
import pytest
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
@pytest.mark.asyncio
async def test_preserve_contexts_with_exception():
# Create context variable
context_var = ContextVar("exception_var", default="initial")
token = context_var.set("start_value")
# Create an async generator that raises an exception
async def exception_generator():
yield context_var.get()
context_var.set("modified")
raise ValueError("Test exception")
yield None # This will never be reached
# Wrap the generator
wrapped_gen = preserve_contexts_async_generator(exception_generator(), [context_var])
# First iteration should work
value = await wrapped_gen.__anext__()
assert value == "start_value"
# Second iteration should raise the exception
with pytest.raises(ValueError, match="Test exception"):
await wrapped_gen.__anext__()
# Clean up
context_var.reset(token)
@pytest.mark.asyncio
async def test_preserve_contexts_empty_generator():
# Create context variable
context_var = ContextVar("empty_var", default="initial")
token = context_var.set("value")
# Create an empty async generator
async def empty_generator():
if False: # This condition ensures the generator yields nothing
yield None
# Wrap the generator
wrapped_gen = preserve_contexts_async_generator(empty_generator(), [context_var])
# The generator should raise StopAsyncIteration immediately
with pytest.raises(StopAsyncIteration):
await wrapped_gen.__anext__()
# Context variable should remain unchanged
assert context_var.get() == "value"
# Clean up
context_var.reset(token)
@pytest.mark.asyncio
async def test_preserve_contexts_across_event_loops():
"""
Test that context variables are preserved across event loop boundaries with nested generators.
This simulates the real-world scenario where:
1. A new event loop is created for each streaming request
2. The async generator runs inside that loop
3. There are multiple levels of nested generators
4. Context needs to be preserved across these boundaries
"""
# Create context variables
request_id = ContextVar("request_id", default=None)
user_id = ContextVar("user_id", default=None)
# Set initial values
# Results container to verify values across thread boundaries
results = []
# Inner-most generator (level 2)
async def inner_generator():
# Should have the context from the outer scope
yield (1, request_id.get(), user_id.get())
# Modify one context variable
user_id.set("user-modified")
# Should reflect the modification
yield (2, request_id.get(), user_id.get())
# Middle generator (level 1)
async def middle_generator():
inner_gen = inner_generator()
# Forward the first yield from inner
item = await inner_gen.__anext__()
yield item
# Forward the second yield from inner
item = await inner_gen.__anext__()
yield item
request_id.set("req-modified")
# Add our own yield with both modified variables
yield (3, request_id.get(), user_id.get())
# Function to run in a separate thread with a new event loop
def run_in_new_loop():
# Create a new event loop for this thread
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# Outer generator (runs in the new loop)
async def outer_generator():
request_id.set("req-12345")
user_id.set("user-6789")
# Wrap the middle generator
wrapped_gen = preserve_contexts_async_generator(middle_generator(), [request_id, user_id])
# Process all items from the middle generator
async for item in wrapped_gen:
# Store results for verification
results.append(item)
# Run the outer generator in the new loop
loop.run_until_complete(outer_generator())
finally:
loop.close()
# Run the generator chain in a separate thread with a new event loop
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(run_in_new_loop)
future.result() # Wait for completion
# Verify the results
assert len(results) == 3
# First yield should have original values
assert results[0] == (1, "req-12345", "user-6789")
# Second yield should have modified user_id
assert results[1] == (2, "req-12345", "user-modified")
# Third yield should have both modified values
assert results[2] == (3, "req-modified", "user-modified")