Fix passthrough streaming, send headers properly not part of body :facepalm

This commit is contained in:
Ashwin Bharambe 2024-08-06 16:39:38 -07:00
parent 039861f1c7
commit a0e61a3c7a
2 changed files with 36 additions and 45 deletions

View file

@ -12,6 +12,7 @@ from collections.abc import (
AsyncIterator as AsyncIteratorABC, AsyncIterator as AsyncIteratorABC,
) )
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from ssl import SSLError
from typing import ( from typing import (
Any, Any,
AsyncGenerator, AsyncGenerator,
@ -28,11 +29,10 @@ import httpx
import yaml import yaml
from dotenv import load_dotenv from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from termcolor import cprint from termcolor import cprint
@ -95,53 +95,43 @@ async def passthrough(
content = await request.body() content = await request.body()
async def iterating_response(): client = httpx.AsyncClient()
def enc(x): try:
return x.encode("latin-1") req = client.build_request(
method=request.method,
url=downstream_url,
headers=headers,
content=content,
params=request.query_params,
)
response = await client.send(req, stream=True)
async with httpx.AsyncClient() as client: async def stream_response():
response_started = False async for chunk in response.aiter_raw(chunk_size=64):
try: yield chunk
async with client.stream(
method=request.method,
url=downstream_url,
headers=headers,
content=content,
params=request.query_params,
) as response:
yield enc(
f"HTTP/1.1 {response.status_code} {response.reason_phrase}\r\n"
)
for k, v in response.headers.items():
yield enc(f"{k}: {v}\r\n")
yield b"\r\n"
response_started = True await response.aclose()
await client.aclose()
# using a small chunk size to allow for streaming SSE, this is not ideal return StreamingResponse(
# for large responses but we are not in that regime for the most part stream_response(),
async for chunk in response.aiter_raw(chunk_size=64): status_code=response.status_code,
yield chunk headers=dict(response.headers),
await response.aclose() media_type=response.headers.get("content-type"),
except ReadTimeout: )
if not response_started:
yield enc(
"HTTP/1.1 504 Gateway Timeout\r\nContent-Type: text/plain\r\n\r\nDownstream server timed out"
)
else:
yield enc("\r\n\r\nError: Downstream server timed out")
except asyncio.CancelledError:
print("Request cancelled")
return
except Exception as e:
if not response_started:
yield enc(
f"HTTP/1.1 500 Internal Server Error\r\nContent-Type: text/plain\r\n\r\nError: {str(e)}"
)
else:
yield enc(f"\r\n\r\nError: {e}")
return StreamingResponse(iterating_response()) except httpx.ReadTimeout:
return Response(content="Downstream server timed out", status_code=504)
except httpx.NetworkError as e:
return Response(content=f"Network error: {str(e)}", status_code=502)
except httpx.TooManyRedirects:
return Response(content="Too many redirects", status_code=502)
except SSLError as e:
return Response(content=f"SSL error: {str(e)}", status_code=502)
except httpx.HTTPStatusError as e:
return Response(content=str(e), status_code=e.response.status_code)
except Exception as e:
return Response(content=f"Unexpected error: {str(e)}", status_code=500)
def handle_sigint(*args, **kwargs): def handle_sigint(*args, **kwargs):

View file

@ -50,6 +50,7 @@ class InferenceClient(Inference):
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
timeout=20, timeout=20,
) as response: ) as response:
print("Headers", response.headers)
if response.status_code != 200: if response.status_code != 200:
content = await response.aread() content = await response.aread()
cprint( cprint(