Mirror of https://github.com/meta-llama/llama-stack.git, synced 2026-01-02 05:39:59 +00:00

Commit 60e9f46856: Merge-related changes.
456 changed files with 38636 additions and 10892 deletions
@@ -41,10 +41,10 @@ async def execute_preprocessor_chain(
     preprocessor_inputs: List[PreprocessingDataElement],
 ) -> PreprocessorResponse:
     if not validate_chain(preprocessor_chain_impls):
-        return PreprocessorResponse(success=False, results=[])
+        return PreprocessorResponse(success=False, output_data_type=None, results=[])

     current_inputs = preprocessor_inputs
-    current_outputs = []
+    current_outputs: List[PreprocessingDataElement] | None = []
     current_result_type = None

     # TODO: replace with a parallel implementation
@@ -59,6 +59,9 @@ async def execute_preprocessor_chain(
             log.error(f"Preprocessor {current_params.preprocessor_id} returned an error")
             return PreprocessorResponse(success=False, output_data_type=response.output_data_type, results=[])
         current_outputs = response.results
+        if current_outputs is None:
+            log.error(f"Preprocessor {current_params.preprocessor_id} returned invalid results")
+            return PreprocessorResponse(success=False, output_data_type=response.output_data_type, results=[])
         current_inputs = current_outputs
         current_result_type = response.output_data_type
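For orientation, a self-contained toy version of the guarded chain loop above. Only the response field names and the two early-return guards are taken from the diff; the dataclass shape, the step callables, the final success return, and the way steps are paired with implementations are illustrative assumptions.

import asyncio
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class PreprocessorResponse:
    # Loosely modeled on the fields used in the hunk above
    success: bool
    output_data_type: Optional[str]
    results: Optional[List[str]]


async def run_chain(steps, inputs: List[str]) -> PreprocessorResponse:
    current_inputs = inputs
    current_outputs: Optional[List[str]] = []
    current_result_type = None
    for step in steps:  # sequential, mirroring the TODO about a parallel implementation
        response = await step(current_inputs)
        if not response.success:
            return PreprocessorResponse(success=False, output_data_type=response.output_data_type, results=[])
        current_outputs = response.results
        if current_outputs is None:  # the new guard: "success" with no results short-circuits
            return PreprocessorResponse(success=False, output_data_type=response.output_data_type, results=[])
        current_inputs = current_outputs
        current_result_type = response.output_data_type
    return PreprocessorResponse(success=True, output_data_type=current_result_type, results=current_outputs)


async def lowercase(items: List[str]) -> PreprocessorResponse:
    return PreprocessorResponse(success=True, output_data_type="str", results=[s.lower() for s in items])


async def broken(items: List[str]) -> PreprocessorResponse:
    # Simulates a preprocessor that reports success but returns no results
    return PreprocessorResponse(success=True, output_data_type="str", results=None)


print(asyncio.run(run_chain([lowercase, broken], ["A", "B"])))
# PreprocessorResponse(success=False, output_data_type='str', results=[])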
llama_stack/distribution/utils/context.py (new file, 37 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from contextvars import ContextVar
from typing import AsyncGenerator, List, TypeVar

T = TypeVar("T")


def preserve_contexts_async_generator(
    gen: AsyncGenerator[T, None], context_vars: List[ContextVar]
) -> AsyncGenerator[T, None]:
    """
    Wraps an async generator to preserve context variables across iterations.
    This is needed because we start a new asyncio event loop for each streaming request,
    and we need to preserve the context across the event loop boundary.
    """
    # Capture initial context values
    initial_context_values = {context_var.name: context_var.get() for context_var in context_vars}

    async def wrapper() -> AsyncGenerator[T, None]:
        while True:
            try:
                # Restore context values before any await
                for context_var in context_vars:
                    context_var.set(initial_context_values[context_var.name])

                item = await gen.__anext__()
                yield item

            except StopAsyncIteration:
                break

    return wrapper()
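For orientation, a minimal usage sketch of the new helper; the generator, variable names, and values below are illustrative, not taken from the diff. The wrapper captures the ContextVar values at wrap time and re-applies them before each iteration, so the consumer sees consistent values even when it drives the generator from a different event loop.

import asyncio
from contextvars import ContextVar

from llama_stack.distribution.utils.context import preserve_contexts_async_generator

request_id = ContextVar("request_id", default=None)


async def stream_chunks():
    # Each iteration sees the request_id captured when the wrapper was created
    for chunk in ("hello", "world"):
        yield f"{request_id.get()}: {chunk}"


async def main():
    request_id.set("req-42")
    wrapped = preserve_contexts_async_generator(stream_chunks(), [request_id])
    async for line in wrapped:
        print(line)  # "req-42: hello", then "req-42: world"


asyncio.run(main())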
@@ -4,13 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import errno
 import logging
 import os
-import select
 import signal
 import subprocess
 import sys

 from termcolor import cprint
@@ -88,13 +85,6 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
     return run_args


-def run_with_pty(command):
-    if sys.platform.startswith("win"):
-        return _run_with_pty_win(command)
-    else:
-        return _run_with_pty_unix(command)
-
-
 def in_notebook():
     try:
         from IPython import get_ipython
@@ -108,19 +98,19 @@ def in_notebook():
     return True


-# run a command in a pseudo-terminal, with interrupt handling,
-# useful when you want to run interactive things
-def _run_with_pty_unix(command):
-    import pty
-    import termios
-
-    master, slave = pty.openpty()
-
-    old_settings = termios.tcgetattr(sys.stdin)
+def run_command(command: list[str]) -> int:
+    """
+    Run a command with interrupt handling and output capture.
+    Uses subprocess.run with direct stream piping for better performance.
+
+    Args:
+        command (list): The command to run.
+
+    Returns:
+        int: The return code of the command.
+    """
     original_sigint = signal.getsignal(signal.SIGINT)

     ctrl_c_pressed = False
     process = None

     def sigint_handler(signum, frame):
         nonlocal ctrl_c_pressed
@@ -131,106 +121,19 @@ def _run_with_pty_unix(command):
     # Set up the signal handler
     signal.signal(signal.SIGINT, sigint_handler)

-    new_settings = termios.tcgetattr(sys.stdin)
-    new_settings[3] = new_settings[3] & ~termios.ECHO  # Disable echo
-    new_settings[3] = new_settings[3] & ~termios.ICANON  # Disable canonical mode
-    termios.tcsetattr(sys.stdin, termios.TCSADRAIN, new_settings)
-
-        process = subprocess.Popen(
+        # Run the command with stdout/stderr piped directly to system streams
+        result = subprocess.run(
             command,
-            stdin=slave,
-            stdout=slave,
-            stderr=slave,
-            universal_newlines=True,
-            preexec_fn=os.setsid,
+            text=True,
+            check=False,
         )
-
-        # Close the slave file descriptor as it's now owned by the subprocess
-        os.close(slave)
-
-        def handle_io():
-            while not ctrl_c_pressed:
-                try:
-                    rlist, _, _ = select.select([sys.stdin, master], [], [], 0.1)
-
-                    if sys.stdin in rlist:
-                        data = os.read(sys.stdin.fileno(), 1024)
-                        if not data:
-                            break
-                        os.write(master, data)
-
-                    if master in rlist:
-                        data = os.read(master, 1024)
-                        if not data:
-                            break
-                        sys.stdout.buffer.write(data)
-                        sys.stdout.flush()
-
-                except KeyboardInterrupt:
-                    # This will be raised when Ctrl+C is pressed
-                    break
-
-                if process.poll() is not None:
-                    break
-
-        handle_io()
-    except (EOFError, KeyboardInterrupt):
-        pass
-    except OSError as e:
-        if e.errno != errno.EIO:
-            raise
-    finally:
-        # Clean up
-        termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
-        signal.signal(signal.SIGINT, original_sigint)
-
-        os.close(master)
-        if process and process.poll() is None:
-            process.terminate()
-            process.wait()
-
-    return process.returncode
-
-
-# run a command in a pseudo-terminal in windows, with interrupt handling,
-def _run_with_pty_win(command):
-    """
-    Runs a command with interactive support using subprocess directly.
-    """
-    try:
-        # For shell scripts on Windows, use appropriate shell
-        if isinstance(command, (list, tuple)):
-            if command[0].endswith(".sh"):
-                if os.path.exists("/usr/bin/bash"):  # WSL
-                    command = ["bash"] + command
-                else:
-                    # Use cmd.exe with bash while preserving all arguments
-                    command = ["cmd.exe", "/c", "bash"] + command
-
-        process = subprocess.Popen(
-            command,
-            shell=True,
-            universal_newlines=True,
-        )
-
-        process.wait()
-
+        return result.returncode
+    except subprocess.SubprocessError as e:
+        log.error(f"Subprocess error: {e}")
+        return 1
     except Exception as e:
-        print(f"Error: {str(e)}")
+        log.exception(f"Unexpected error: {e}")
         return 1
     finally:
-        if process and process.poll() is None:
-            process.terminate()
-            process.wait()
-        return process.returncode
-
-
-def run_command(command):
-    try:
-        result = subprocess.run(command, capture_output=True, text=True, check=True)
-        print("Script Output\n", result.stdout)
-        return result.returncode
-    except subprocess.CalledProcessError as e:
-        print("Error running script:", e)
-        print("Error output:", e.stderr)
-        return e.returncode
+        # Restore the original signal handler
+        signal.signal(signal.SIGINT, original_sigint)
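The removed PTY plumbing (and the Windows variant and the old capture-everything run_command) is replaced by a plain subprocess.run call that inherits the parent's stdio and temporarily installs a SIGINT handler. A standalone sketch of that pattern, not the module itself (its path is not shown in this view) and with the Ctrl+C bookkeeping simplified:

import signal
import subprocess


def run_command_sketch(command: list[str]) -> int:
    """Run a command with inherited stdout/stderr, restoring SIGINT handling afterwards."""
    original_sigint = signal.getsignal(signal.SIGINT)
    try:
        # No pipes: the child writes straight to the parent's streams
        result = subprocess.run(command, text=True, check=False)
        return result.returncode
    except subprocess.SubprocessError as exc:
        print(f"Subprocess error: {exc}")
        return 1
    finally:
        # Restore the original signal handler, as the new run_command does
        signal.signal(signal.SIGINT, original_sigint)


if __name__ == "__main__":
    print(run_command_sketch(["echo", "hello"]))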
llama_stack/distribution/utils/tests/test_context.py (new file, 155 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
from concurrent.futures import ThreadPoolExecutor
from contextvars import ContextVar

import pytest

from llama_stack.distribution.utils.context import preserve_contexts_async_generator


@pytest.mark.asyncio
async def test_preserve_contexts_with_exception():
    # Create context variable
    context_var = ContextVar("exception_var", default="initial")
    token = context_var.set("start_value")

    # Create an async generator that raises an exception
    async def exception_generator():
        yield context_var.get()
        context_var.set("modified")
        raise ValueError("Test exception")
        yield None  # This will never be reached

    # Wrap the generator
    wrapped_gen = preserve_contexts_async_generator(exception_generator(), [context_var])

    # First iteration should work
    value = await wrapped_gen.__anext__()
    assert value == "start_value"

    # Second iteration should raise the exception
    with pytest.raises(ValueError, match="Test exception"):
        await wrapped_gen.__anext__()

    # Clean up
    context_var.reset(token)


@pytest.mark.asyncio
async def test_preserve_contexts_empty_generator():
    # Create context variable
    context_var = ContextVar("empty_var", default="initial")
    token = context_var.set("value")

    # Create an empty async generator
    async def empty_generator():
        if False:  # This condition ensures the generator yields nothing
            yield None

    # Wrap the generator
    wrapped_gen = preserve_contexts_async_generator(empty_generator(), [context_var])

    # The generator should raise StopAsyncIteration immediately
    with pytest.raises(StopAsyncIteration):
        await wrapped_gen.__anext__()

    # Context variable should remain unchanged
    assert context_var.get() == "value"

    # Clean up
    context_var.reset(token)


@pytest.mark.asyncio
async def test_preserve_contexts_across_event_loops():
    """
    Test that context variables are preserved across event loop boundaries with nested generators.
    This simulates the real-world scenario where:
    1. A new event loop is created for each streaming request
    2. The async generator runs inside that loop
    3. There are multiple levels of nested generators
    4. Context needs to be preserved across these boundaries
    """
    # Create context variables
    request_id = ContextVar("request_id", default=None)
    user_id = ContextVar("user_id", default=None)

    # Set initial values

    # Results container to verify values across thread boundaries
    results = []

    # Inner-most generator (level 2)
    async def inner_generator():
        # Should have the context from the outer scope
        yield (1, request_id.get(), user_id.get())

        # Modify one context variable
        user_id.set("user-modified")

        # Should reflect the modification
        yield (2, request_id.get(), user_id.get())

    # Middle generator (level 1)
    async def middle_generator():
        inner_gen = inner_generator()

        # Forward the first yield from inner
        item = await inner_gen.__anext__()
        yield item

        # Forward the second yield from inner
        item = await inner_gen.__anext__()
        yield item

        request_id.set("req-modified")

        # Add our own yield with both modified variables
        yield (3, request_id.get(), user_id.get())

    # Function to run in a separate thread with a new event loop
    def run_in_new_loop():
        # Create a new event loop for this thread
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        try:
            # Outer generator (runs in the new loop)
            async def outer_generator():
                request_id.set("req-12345")
                user_id.set("user-6789")

                # Wrap the middle generator
                wrapped_gen = preserve_contexts_async_generator(middle_generator(), [request_id, user_id])

                # Process all items from the middle generator
                async for item in wrapped_gen:
                    # Store results for verification
                    results.append(item)

            # Run the outer generator in the new loop
            loop.run_until_complete(outer_generator())
        finally:
            loop.close()

    # Run the generator chain in a separate thread with a new event loop
    with ThreadPoolExecutor(max_workers=1) as executor:
        future = executor.submit(run_in_new_loop)
        future.result()  # Wait for completion

    # Verify the results
    assert len(results) == 3

    # First yield should have original values
    assert results[0] == (1, "req-12345", "user-6789")

    # Second yield should have modified user_id
    assert results[1] == (2, "req-12345", "user-modified")

    # Third yield should have both modified values
    assert results[2] == (3, "req-modified", "user-modified")
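The new tests rely on the @pytest.mark.asyncio markers, so pytest-asyncio is assumed to be installed alongside pytest. A minimal way to run just this file from Python:

# Assumes pytest and pytest-asyncio are available in the current environment
import pytest

raise SystemExit(pytest.main(["-v", "llama_stack/distribution/utils/tests/test_context.py"]))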