fixes and enable tool_runtime tests

This commit is contained in:
Ashwin Bharambe 2025-05-24 07:27:53 -07:00
parent 9f7ed4be43
commit bc7901e3bd
4 changed files with 38 additions and 7 deletions

View file

@ -24,7 +24,7 @@ jobs:
matrix: matrix:
# Listing tests manually since some of them currently fail # Listing tests manually since some of them currently fail
# TODO: generate matrix list from tests/integration when fixed # TODO: generate matrix list from tests/integration when fixed
test-type: [agents, inference, datasets, inspect, scoring, post_training, providers] test-type: [agents, inference, datasets, inspect, scoring, post_training, providers, tool_runtime]
client-type: [library, http] client-type: [library, http]
fail-fast: false # we want to run all tests regardless of failure fail-fast: false # we want to run all tests regardless of failure

1
.gitignore vendored
View file

@ -6,6 +6,7 @@ dev_requirements.txt
build build
.DS_Store .DS_Store
llama_stack/configs/* llama_stack/configs/*
.cursor/
xcuserdata/ xcuserdata/
*.hmap *.hmap
.DS_Store .DS_Store

View file

@ -7,7 +7,14 @@
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Any from typing import Any
import exceptiongroup try:
# for python < 3.11
import exceptiongroup
BaseExceptionGroup = exceptiongroup.BaseExceptionGroup
except ImportError:
pass
import httpx import httpx
from mcp import ClientSession from mcp import ClientSession
from mcp import types as mcp_types from mcp import types as mcp_types
@ -34,7 +41,7 @@ async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
await session.initialize() await session.initialize()
yield session yield session
except BaseException as e: except BaseException as e:
if isinstance(e, exceptiongroup.BaseExceptionGroup): if isinstance(e, BaseExceptionGroup):
for exc in e.exceptions: for exc in e.exceptions:
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401: if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401:
raise AuthenticationRequiredError(exc) from exc raise AuthenticationRequiredError(exc) from exc

View file

@ -109,8 +109,31 @@ def make_mcp_server(required_auth_token: str | None = None):
pass pass
time.sleep(0.1) time.sleep(0.1)
yield {"server_url": server_url} try:
yield {"server_url": server_url}
finally:
print("Telling SSE server to exit")
server_instance.should_exit = True
time.sleep(0.5)
# Tell server to exit # Force shutdown if still running
server_instance.should_exit = True if server_thread.is_alive():
server_thread.join(timeout=5) try:
if hasattr(server_instance, "servers") and server_instance.servers:
for srv in server_instance.servers:
srv.close()
# Wait for graceful shutdown
server_thread.join(timeout=3)
if server_thread.is_alive():
print("Warning: Server thread still alive after shutdown attempt")
except Exception as e:
print(f"Error during server shutdown: {e}")
# CRITICAL: Reset SSE global state to prevent event loop contamination
# Reset the SSE AppStatus singleton that stores anyio.Event objects
from sse_starlette.sse import AppStatus
AppStatus.should_exit = False
AppStatus.should_exit_event = None
print("SSE server exited")