From a0e61a3c7a5944642ff9731ad9f4760357cd6111 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 6 Aug 2024 16:39:38 -0700 Subject: [PATCH] Fix passthrough streaming, send headers properly not part of body :facepalm --- llama_toolchain/distribution/server.py | 80 +++++++++++--------------- llama_toolchain/inference/client.py | 1 + 2 files changed, 36 insertions(+), 45 deletions(-) diff --git a/llama_toolchain/distribution/server.py b/llama_toolchain/distribution/server.py index 63ddcbe01..a087b3a64 100644 --- a/llama_toolchain/distribution/server.py +++ b/llama_toolchain/distribution/server.py @@ -12,6 +12,7 @@ from collections.abc import ( AsyncIterator as AsyncIteratorABC, ) from contextlib import asynccontextmanager +from ssl import SSLError from typing import ( Any, AsyncGenerator, @@ -28,11 +29,10 @@ import httpx import yaml from dotenv import load_dotenv -from fastapi import FastAPI, HTTPException, Request +from fastapi import FastAPI, HTTPException, Request, Response from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, StreamingResponse from fastapi.routing import APIRoute - from pydantic import BaseModel, ValidationError from termcolor import cprint @@ -95,53 +95,43 @@ async def passthrough( content = await request.body() - async def iterating_response(): - def enc(x): - return x.encode("latin-1") + client = httpx.AsyncClient() + try: + req = client.build_request( + method=request.method, + url=downstream_url, + headers=headers, + content=content, + params=request.query_params, + ) + response = await client.send(req, stream=True) - async with httpx.AsyncClient() as client: - response_started = False - try: - async with client.stream( - method=request.method, - url=downstream_url, - headers=headers, - content=content, - params=request.query_params, - ) as response: - yield enc( - f"HTTP/1.1 {response.status_code} {response.reason_phrase}\r\n" - ) - for k, v in response.headers.items(): - yield enc(f"{k}: {v}\r\n") - yield b"\r\n" + async def stream_response(): + async for chunk in response.aiter_raw(chunk_size=64): + yield chunk - response_started = True + await response.aclose() + await client.aclose() - # using a small chunk size to allow for streaming SSE, this is not ideal - # for large responses but we are not in that regime for the most part - async for chunk in response.aiter_raw(chunk_size=64): - yield chunk - await response.aclose() - except ReadTimeout: - if not response_started: - yield enc( - "HTTP/1.1 504 Gateway Timeout\r\nContent-Type: text/plain\r\n\r\nDownstream server timed out" - ) - else: - yield enc("\r\n\r\nError: Downstream server timed out") - except asyncio.CancelledError: - print("Request cancelled") - return - except Exception as e: - if not response_started: - yield enc( - f"HTTP/1.1 500 Internal Server Error\r\nContent-Type: text/plain\r\n\r\nError: {str(e)}" - ) - else: - yield enc(f"\r\n\r\nError: {e}") + return StreamingResponse( + stream_response(), + status_code=response.status_code, + headers=dict(response.headers), + media_type=response.headers.get("content-type"), + ) - return StreamingResponse(iterating_response()) + except httpx.ReadTimeout: + return Response(content="Downstream server timed out", status_code=504) + except httpx.NetworkError as e: + return Response(content=f"Network error: {str(e)}", status_code=502) + except httpx.TooManyRedirects: + return Response(content="Too many redirects", status_code=502) + except SSLError as e: + return Response(content=f"SSL error: {str(e)}", status_code=502) + except httpx.HTTPStatusError as e: + return Response(content=str(e), status_code=e.response.status_code) + except Exception as e: + return Response(content=f"Unexpected error: {str(e)}", status_code=500) def handle_sigint(*args, **kwargs): diff --git a/llama_toolchain/inference/client.py b/llama_toolchain/inference/client.py index aa84f906d..36ee6225a 100644 --- a/llama_toolchain/inference/client.py +++ b/llama_toolchain/inference/client.py @@ -50,6 +50,7 @@ class InferenceClient(Inference): headers={"Content-Type": "application/json"}, timeout=20, ) as response: + print("Headers", response.headers) if response.status_code != 200: content = await response.aread() cprint(