implement full-passthrough in the server

This commit is contained in:
Ashwin Bharambe 2024-08-03 14:15:20 -07:00
parent 38fd76f85c
commit 9dafa6ad94
8 changed files with 69 additions and 71 deletions

View file

@ -25,8 +25,6 @@ COMMON_DEPENDENCIES = [
"flake8",
"httpx",
"huggingface-hub",
"hydra-core",
"hydra-zen",
"json-strong-typing",
"git+ssh://git@github.com/meta-llama/llama-models.git",
"omegaconf",
@ -67,9 +65,12 @@ def available_distributions() -> List[Distribution]:
"fairscale",
"fastapi",
"fire",
"flake8",
"httpx",
"huggingface-hub",
"json-strong-typing",
"pydantic==1.10.13",
"pydantic_core==2.18.2",
"uvicorn",
],
adapters={
ApiSurface.inference: PassthroughApiAdapter(

View file

@ -80,29 +80,59 @@ async def passthrough(
downstream_url: str,
downstream_headers: Optional[Dict[str, str]] = None,
):
client = httpx.AsyncClient()
headers = dict(request.headers)
headers.pop("host", None)
headers.update(downstream_headers or {})
body = await request.body()
content = await request.body()
try:
response = await client.request(
method=request.method,
url=downstream_url,
headers=headers,
content=body,
params=request.query_params,
)
return StreamingResponse(
response.iter_bytes(),
status_code=response.status_code,
headers=dict(response.headers),
)
finally:
await client.aclose()
async def iterating_response():
    """Relay the downstream HTTP response to the caller as a raw byte stream.

    Yields, in order: an ``HTTP/1.1`` status line built from the downstream
    status code and reason phrase, one ``name: value`` line per downstream
    response header, an empty line, and then the undecoded downstream body.

    If the downstream call fails before anything was relayed, a synthetic
    504 (read timeout) or 500 (any other error) message is yielded instead.
    If it fails after relaying has started, a plain-text error trailer is
    appended to whatever was already sent.

    Raises:
        asyncio.CancelledError: re-raised after logging so that cancellation
            of the consuming task is not silently swallowed.
    """

    def enc(x):
        # Strict latin-1: used for the status line, headers and fixed text.
        return x.encode("latin-1")

    def enc_error(x):
        # Exception messages may contain characters outside latin-1; a
        # strict encode here would raise UnicodeEncodeError from inside the
        # error handler itself, so escape unencodable characters instead.
        return x.encode("latin-1", errors="backslashreplace")

    async with httpx.AsyncClient() as client:
        # Tracks whether any bytes of the downstream response were already
        # relayed, which decides what kind of error output is still possible.
        response_started = False
        try:
            async with client.stream(
                method=request.method,
                url=downstream_url,
                headers=headers,
                content=content,
                params=request.query_params,
            ) as response:
                yield enc(
                    f"HTTP/1.1 {response.status_code} {response.reason_phrase}\r\n"
                )
                for k, v in response.headers.items():
                    yield enc(f"{k}: {v}\r\n")
                yield b"\r\n"

                response_started = True

                # A small chunk size lets server-sent events reach the
                # caller promptly. This is inefficient for large responses,
                # but those are not the common case here.
                async for chunk in response.aiter_raw(chunk_size=64):
                    yield chunk
                # No explicit response.aclose() needed: leaving the
                # `async with client.stream(...)` block closes the response.
        except ReadTimeout:
            if not response_started:
                yield enc(
                    "HTTP/1.1 504 Gateway Timeout\r\nContent-Type: text/plain\r\n\r\nDownstream server timed out"
                )
            else:
                yield enc("\r\n\r\nError: Downstream server timed out")
        except asyncio.CancelledError:
            print("Request cancelled")
            # Propagate cancellation so the surrounding task actually stops.
            raise
        except Exception as e:
            if not response_started:
                yield enc_error(
                    f"HTTP/1.1 500 Internal Server Error\r\nContent-Type: text/plain\r\n\r\nError: {str(e)}"
                )
            else:
                yield enc_error(f"\r\n\r\nError: {e}")
return StreamingResponse(iterating_response())
def handle_sigint(*args, **kwargs):
@ -134,6 +164,10 @@ def create_dynamic_typed_route(func: Any):
request_model = next(iter(hints.values()))
response_model = hints["return"]
# NOTE: I think it is better to just add a method within each ApiSurface
# "Protocol" / adapter-impl to tell what sort of a response this request
# is going to produce. /chat_completion can produce a streaming or
# non-streaming response depending on if request.stream is True / False.
is_streaming = is_async_iterator_type(response_model)
if is_streaming: