mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
cleanup: simplify connection code
Signed-off-by: Calum Murray <cmurray@redhat.com>
This commit is contained in:
parent
873bf8d95a
commit
caecaa8b31
1 changed files with 14 additions and 12 deletions
|
@ -46,19 +46,21 @@ async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerat
|
|||
for i, strategy in enumerate(connection_strategies):
|
||||
try:
|
||||
if strategy == PROTOCOL_STREAMABLEHTTP:
|
||||
async with streamablehttp_client(endpoint, headers=headers) as (read_stream, write_stream, _):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
await session.initialize()
|
||||
protocol_cache[endpoint] = PROTOCOL_STREAMABLEHTTP
|
||||
yield session
|
||||
return
|
||||
client = streamablehttp_client
|
||||
elif strategy == PROTOCOL_SSE:
|
||||
async with sse_client(endpoint, headers=headers) as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
await session.initialize()
|
||||
protocol_cache[endpoint] = PROTOCOL_SSE
|
||||
yield session
|
||||
return
|
||||
client = sse_client
|
||||
else:
|
||||
# this should not happen
|
||||
logger.warning(
|
||||
"tried to establish MCP connection with unknown protocol, defaulting to try with streamable_http"
|
||||
)
|
||||
client = streamablehttp_client
|
||||
async with client(endpoint, headers=headers) as client_streams:
|
||||
async with ClientSession(read_stream=client_streams[0], write_stream=client_streams[1]) as session:
|
||||
await session.initialize()
|
||||
protocol_cache[endpoint] = strategy
|
||||
yield session
|
||||
return
|
||||
except* httpx.HTTPStatusError as eg:
|
||||
for exc in eg.exceptions:
|
||||
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue