mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Fix passthrough streaming: send headers properly instead of as part of the body :facepalm:
This commit is contained in:
parent
039861f1c7
commit
a0e61a3c7a
2 changed files with 36 additions and 45 deletions
|
@ -12,6 +12,7 @@ from collections.abc import (
|
||||||
AsyncIterator as AsyncIteratorABC,
|
AsyncIterator as AsyncIteratorABC,
|
||||||
)
|
)
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from ssl import SSLError
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
AsyncGenerator,
|
AsyncGenerator,
|
||||||
|
@ -28,11 +29,10 @@ import httpx
|
||||||
import yaml
|
import yaml
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException, Request
|
from fastapi import FastAPI, HTTPException, Request, Response
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
from fastapi.routing import APIRoute
|
from fastapi.routing import APIRoute
|
||||||
|
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
|
@ -95,53 +95,43 @@ async def passthrough(
|
||||||
|
|
||||||
content = await request.body()
|
content = await request.body()
|
||||||
|
|
||||||
async def iterating_response():
|
client = httpx.AsyncClient()
|
||||||
def enc(x):
|
try:
|
||||||
return x.encode("latin-1")
|
req = client.build_request(
|
||||||
|
method=request.method,
|
||||||
|
url=downstream_url,
|
||||||
|
headers=headers,
|
||||||
|
content=content,
|
||||||
|
params=request.query_params,
|
||||||
|
)
|
||||||
|
response = await client.send(req, stream=True)
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async def stream_response():
|
||||||
response_started = False
|
async for chunk in response.aiter_raw(chunk_size=64):
|
||||||
try:
|
yield chunk
|
||||||
async with client.stream(
|
|
||||||
method=request.method,
|
|
||||||
url=downstream_url,
|
|
||||||
headers=headers,
|
|
||||||
content=content,
|
|
||||||
params=request.query_params,
|
|
||||||
) as response:
|
|
||||||
yield enc(
|
|
||||||
f"HTTP/1.1 {response.status_code} {response.reason_phrase}\r\n"
|
|
||||||
)
|
|
||||||
for k, v in response.headers.items():
|
|
||||||
yield enc(f"{k}: {v}\r\n")
|
|
||||||
yield b"\r\n"
|
|
||||||
|
|
||||||
response_started = True
|
await response.aclose()
|
||||||
|
await client.aclose()
|
||||||
|
|
||||||
# using a small chunk size to allow for streaming SSE, this is not ideal
|
return StreamingResponse(
|
||||||
# for large responses but we are not in that regime for the most part
|
stream_response(),
|
||||||
async for chunk in response.aiter_raw(chunk_size=64):
|
status_code=response.status_code,
|
||||||
yield chunk
|
headers=dict(response.headers),
|
||||||
await response.aclose()
|
media_type=response.headers.get("content-type"),
|
||||||
except ReadTimeout:
|
)
|
||||||
if not response_started:
|
|
||||||
yield enc(
|
|
||||||
"HTTP/1.1 504 Gateway Timeout\r\nContent-Type: text/plain\r\n\r\nDownstream server timed out"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
yield enc("\r\n\r\nError: Downstream server timed out")
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
print("Request cancelled")
|
|
||||||
return
|
|
||||||
except Exception as e:
|
|
||||||
if not response_started:
|
|
||||||
yield enc(
|
|
||||||
f"HTTP/1.1 500 Internal Server Error\r\nContent-Type: text/plain\r\n\r\nError: {str(e)}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
yield enc(f"\r\n\r\nError: {e}")
|
|
||||||
|
|
||||||
return StreamingResponse(iterating_response())
|
except httpx.ReadTimeout:
|
||||||
|
return Response(content="Downstream server timed out", status_code=504)
|
||||||
|
except httpx.NetworkError as e:
|
||||||
|
return Response(content=f"Network error: {str(e)}", status_code=502)
|
||||||
|
except httpx.TooManyRedirects:
|
||||||
|
return Response(content="Too many redirects", status_code=502)
|
||||||
|
except SSLError as e:
|
||||||
|
return Response(content=f"SSL error: {str(e)}", status_code=502)
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
return Response(content=str(e), status_code=e.response.status_code)
|
||||||
|
except Exception as e:
|
||||||
|
return Response(content=f"Unexpected error: {str(e)}", status_code=500)
|
||||||
|
|
||||||
|
|
||||||
def handle_sigint(*args, **kwargs):
|
def handle_sigint(*args, **kwargs):
|
||||||
|
|
|
@ -50,6 +50,7 @@ class InferenceClient(Inference):
|
||||||
headers={"Content-Type": "application/json"},
|
headers={"Content-Type": "application/json"},
|
||||||
timeout=20,
|
timeout=20,
|
||||||
) as response:
|
) as response:
|
||||||
|
print("Headers", response.headers)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
content = await response.aread()
|
content = await response.aread()
|
||||||
cprint(
|
cprint(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue