mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
cleanup: simplify connection code
Signed-off-by: Calum Murray <cmurray@redhat.com>
This commit is contained in:
parent
873bf8d95a
commit
caecaa8b31
1 changed files with 14 additions and 12 deletions
|
@ -46,19 +46,21 @@ async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerat
|
||||||
for i, strategy in enumerate(connection_strategies):
|
for i, strategy in enumerate(connection_strategies):
|
||||||
try:
|
try:
|
||||||
if strategy == PROTOCOL_STREAMABLEHTTP:
|
if strategy == PROTOCOL_STREAMABLEHTTP:
|
||||||
async with streamablehttp_client(endpoint, headers=headers) as (read_stream, write_stream, _):
|
client = streamablehttp_client
|
||||||
async with ClientSession(read_stream, write_stream) as session:
|
|
||||||
await session.initialize()
|
|
||||||
protocol_cache[endpoint] = PROTOCOL_STREAMABLEHTTP
|
|
||||||
yield session
|
|
||||||
return
|
|
||||||
elif strategy == PROTOCOL_SSE:
|
elif strategy == PROTOCOL_SSE:
|
||||||
async with sse_client(endpoint, headers=headers) as streams:
|
client = sse_client
|
||||||
async with ClientSession(*streams) as session:
|
else:
|
||||||
await session.initialize()
|
# this should not happen
|
||||||
protocol_cache[endpoint] = PROTOCOL_SSE
|
logger.warning(
|
||||||
yield session
|
"tried to establish MCP connection with unknown protocol, defaulting to try with streamable_http"
|
||||||
return
|
)
|
||||||
|
client = streamablehttp_client
|
||||||
|
async with client(endpoint, headers=headers) as client_streams:
|
||||||
|
async with ClientSession(read_stream=client_streams[0], write_stream=client_streams[1]) as session:
|
||||||
|
await session.initialize()
|
||||||
|
protocol_cache[endpoint] = strategy
|
||||||
|
yield session
|
||||||
|
return
|
||||||
except* httpx.HTTPStatusError as eg:
|
except* httpx.HTTPStatusError as eg:
|
||||||
for exc in eg.exceptions:
|
for exc in eg.exceptions:
|
||||||
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
|
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue