fixes and enable tool_runtime tests

This commit is contained in:
Ashwin Bharambe 2025-05-24 07:27:53 -07:00
parent 9f7ed4be43
commit bc7901e3bd
4 changed files with 38 additions and 7 deletions

View file

@ -24,7 +24,7 @@ jobs:
matrix:
# Listing tests manually since some of them currently fail
# TODO: generate matrix list from tests/integration when fixed
test-type: [agents, inference, datasets, inspect, scoring, post_training, providers]
test-type: [agents, inference, datasets, inspect, scoring, post_training, providers, tool_runtime]
client-type: [library, http]
fail-fast: false # we want to run all tests regardless of failure

1
.gitignore vendored
View file

@ -6,6 +6,7 @@ dev_requirements.txt
build
.DS_Store
llama_stack/configs/*
.cursor/
xcuserdata/
*.hmap
.DS_Store

View file

@ -7,7 +7,14 @@
from contextlib import asynccontextmanager
from typing import Any
import exceptiongroup
try:
# for python < 3.11
import exceptiongroup
BaseExceptionGroup = exceptiongroup.BaseExceptionGroup
except ImportError:
pass
import httpx
from mcp import ClientSession
from mcp import types as mcp_types
@ -34,7 +41,7 @@ async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
await session.initialize()
yield session
except BaseException as e:
if isinstance(e, exceptiongroup.BaseExceptionGroup):
if isinstance(e, BaseExceptionGroup):
for exc in e.exceptions:
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401:
raise AuthenticationRequiredError(exc) from exc

View file

@ -109,8 +109,31 @@ def make_mcp_server(required_auth_token: str | None = None):
pass
time.sleep(0.1)
yield {"server_url": server_url}
try:
yield {"server_url": server_url}
finally:
print("Telling SSE server to exit")
server_instance.should_exit = True
time.sleep(0.5)
# Tell server to exit
server_instance.should_exit = True
server_thread.join(timeout=5)
# Force shutdown if still running
if server_thread.is_alive():
try:
if hasattr(server_instance, "servers") and server_instance.servers:
for srv in server_instance.servers:
srv.close()
# Wait for graceful shutdown
server_thread.join(timeout=3)
if server_thread.is_alive():
print("Warning: Server thread still alive after shutdown attempt")
except Exception as e:
print(f"Error during server shutdown: {e}")
# CRITICAL: Reset SSE global state to prevent event loop contamination
# Reset the SSE AppStatus singleton that stores anyio.Event objects
from sse_starlette.sse import AppStatus
AppStatus.should_exit = False
AppStatus.should_exit_event = None
print("SSE server exited")