mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-31 16:01:46 +00:00
fixes and enable tool_runtime tests
This commit is contained in:
parent
9f7ed4be43
commit
bc7901e3bd
4 changed files with 38 additions and 7 deletions
2
.github/workflows/integration-tests.yml
vendored
2
.github/workflows/integration-tests.yml
vendored
|
@ -24,7 +24,7 @@ jobs:
|
|||
matrix:
|
||||
# Listing tests manually since some of them currently fail
|
||||
# TODO: generate matrix list from tests/integration when fixed
|
||||
test-type: [agents, inference, datasets, inspect, scoring, post_training, providers]
|
||||
test-type: [agents, inference, datasets, inspect, scoring, post_training, providers, tool_runtime]
|
||||
client-type: [library, http]
|
||||
fail-fast: false # we want to run all tests regardless of failure
|
||||
|
||||
|
|
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -6,6 +6,7 @@ dev_requirements.txt
|
|||
build
|
||||
.DS_Store
|
||||
llama_stack/configs/*
|
||||
.cursor/
|
||||
xcuserdata/
|
||||
*.hmap
|
||||
.DS_Store
|
||||
|
|
|
@ -7,7 +7,14 @@
|
|||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
import exceptiongroup
|
||||
try:
|
||||
# for python < 3.11
|
||||
import exceptiongroup
|
||||
|
||||
BaseExceptionGroup = exceptiongroup.BaseExceptionGroup
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
import httpx
|
||||
from mcp import ClientSession
|
||||
from mcp import types as mcp_types
|
||||
|
@ -34,7 +41,7 @@ async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
|
|||
await session.initialize()
|
||||
yield session
|
||||
except BaseException as e:
|
||||
if isinstance(e, exceptiongroup.BaseExceptionGroup):
|
||||
if isinstance(e, BaseExceptionGroup):
|
||||
for exc in e.exceptions:
|
||||
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401:
|
||||
raise AuthenticationRequiredError(exc) from exc
|
||||
|
|
|
@ -109,8 +109,31 @@ def make_mcp_server(required_auth_token: str | None = None):
|
|||
pass
|
||||
time.sleep(0.1)
|
||||
|
||||
yield {"server_url": server_url}
|
||||
try:
|
||||
yield {"server_url": server_url}
|
||||
finally:
|
||||
print("Telling SSE server to exit")
|
||||
server_instance.should_exit = True
|
||||
time.sleep(0.5)
|
||||
|
||||
# Tell server to exit
|
||||
server_instance.should_exit = True
|
||||
server_thread.join(timeout=5)
|
||||
# Force shutdown if still running
|
||||
if server_thread.is_alive():
|
||||
try:
|
||||
if hasattr(server_instance, "servers") and server_instance.servers:
|
||||
for srv in server_instance.servers:
|
||||
srv.close()
|
||||
|
||||
# Wait for graceful shutdown
|
||||
server_thread.join(timeout=3)
|
||||
if server_thread.is_alive():
|
||||
print("Warning: Server thread still alive after shutdown attempt")
|
||||
except Exception as e:
|
||||
print(f"Error during server shutdown: {e}")
|
||||
|
||||
# CRITICAL: Reset SSE global state to prevent event loop contamination
|
||||
# Reset the SSE AppStatus singleton that stores anyio.Event objects
|
||||
from sse_starlette.sse import AppStatus
|
||||
|
||||
AppStatus.should_exit = False
|
||||
AppStatus.should_exit_event = None
|
||||
print("SSE server exited")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue