Mirror of https://github.com/meta-llama/llama-stack.git

commit 9dafa6ad94 (parent 38fd76f85c)

    implement full-passthrough in the server

8 changed files with 69 additions and 71 deletions
```diff
@@ -25,8 +25,6 @@ COMMON_DEPENDENCIES = [
     "flake8",
     "httpx",
     "huggingface-hub",
-    "hydra-core",
-    "hydra-zen",
     "json-strong-typing",
     "git+ssh://git@github.com/meta-llama/llama-models.git",
     "omegaconf",
```
```diff
@@ -67,9 +65,12 @@ def available_distributions() -> List[Distribution]:
             "fairscale",
             "fastapi",
             "fire",
             "flake8",
             "httpx",
             "huggingface-hub",
             "json-strong-typing",
             "pydantic==1.10.13",
             "pydantic_core==2.18.2",
             "uvicorn",
         ],
+        adapters={
+            ApiSurface.inference: PassthroughApiAdapter(
```
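The `PassthroughApiAdapter` registration above is cut off in this view. Purely to illustrate the shape of this configuration, a distribution that proxies the inference surface might look roughly like the sketch below; the `Distribution` fields and the adapter's constructor arguments are invented for illustration, not taken from the commit.

```python
# Hypothetical sketch only: the real constructor arguments are truncated in
# the hunk above, so every parameter name here is a guess.
Distribution(
    name="full-passthrough",
    pip_packages=["fastapi", "httpx", "uvicorn"],
    adapters={
        ApiSurface.inference: PassthroughApiAdapter(
            base_url="http://localhost:5000",  # invented downstream endpoint
        ),
    },
)
```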
```diff
@@ -80,29 +80,59 @@ async def passthrough(
     downstream_url: str,
     downstream_headers: Optional[Dict[str, str]] = None,
 ):
-    client = httpx.AsyncClient()
-
     headers = dict(request.headers)
     headers.pop("host", None)
     headers.update(downstream_headers or {})
 
-    body = await request.body()
+    content = await request.body()
 
-    try:
-        response = await client.request(
-            method=request.method,
-            url=downstream_url,
-            headers=headers,
-            content=body,
-            params=request.query_params,
-        )
-        return StreamingResponse(
-            response.iter_bytes(),
-            status_code=response.status_code,
-            headers=dict(response.headers),
-        )
-    finally:
-        await client.aclose()
+    async def iterating_response():
+        def enc(x):
+            return x.encode("latin-1")
+
+        async with httpx.AsyncClient() as client:
+            response_started = False
+            try:
+                async with client.stream(
+                    method=request.method,
+                    url=downstream_url,
+                    headers=headers,
+                    content=content,
+                    params=request.query_params,
+                ) as response:
+                    yield enc(
+                        f"HTTP/1.1 {response.status_code} {response.reason_phrase}\r\n"
+                    )
+                    for k, v in response.headers.items():
+                        yield enc(f"{k}: {v}\r\n")
+                    yield b"\r\n"
+
+                    response_started = True
+
+                    # using a small chunk size to allow for streaming SSE; this is not
+                    # ideal for large responses but we are not in that regime for the
+                    # most part
+                    async for chunk in response.aiter_raw(chunk_size=64):
+                        yield chunk
+                    await response.aclose()
+            except ReadTimeout:
+                if not response_started:
+                    yield enc(
+                        "HTTP/1.1 504 Gateway Timeout\r\nContent-Type: text/plain\r\n\r\nDownstream server timed out"
+                    )
+                else:
+                    yield enc("\r\n\r\nError: Downstream server timed out")
+            except asyncio.CancelledError:
+                print("Request cancelled")
+                return
+            except Exception as e:
+                if not response_started:
+                    yield enc(
+                        f"HTTP/1.1 500 Internal Server Error\r\nContent-Type: text/plain\r\n\r\nError: {str(e)}"
+                    )
+                else:
+                    yield enc(f"\r\n\r\nError: {e}")
+
+    return StreamingResponse(iterating_response())
 
 
 def handle_sigint(*args, **kwargs):
```
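For orientation, here is a rough usage sketch of how such a route could be mounted, assuming the `passthrough()` coroutine from the hunk above takes the incoming `Request` as its first argument; the route path, downstream URL, and HTTP methods below are illustrative guesses, not from the commit.

```python
# Usage sketch: wire the passthrough() coroutine from the diff above into an
# app. Assumes passthrough(request, downstream_url=...) is importable/in scope.
from functools import partial

from fastapi import FastAPI

app = FastAPI()

# Each request hitting this route is forwarded verbatim to the downstream
# server, and the downstream bytes are relayed back as they arrive.
app.add_route(
    "/inference/chat_completion",  # hypothetical route and downstream URL
    partial(passthrough, downstream_url="http://localhost:5000/inference/chat_completion"),
    methods=["POST"],
)
```

Note the design choice in `iterating_response()`: the downstream status line and headers are re-emitted as raw bytes at the start of the body, so the caller sees the downstream framing verbatim, and `aiter_raw(chunk_size=64)` keeps SSE tokens flushing promptly at the cost of more overhead on large payloads (as the in-code comment acknowledges).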
```diff
@@ -134,6 +164,10 @@ def create_dynamic_typed_route(func: Any):
     request_model = next(iter(hints.values()))
     response_model = hints["return"]
 
+    # NOTE: I think it is better to just add a method within each ApiSurface
+    # "Protocol" / adapter-impl to tell what sort of a response this request
+    # is going to produce. /chat_completion can produce a streaming or
+    # non-streaming response depending on if request.stream is True / False.
     is_streaming = is_async_iterator_type(response_model)
 
     if is_streaming:
```
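This hunk relies on `is_async_iterator_type()` to decide whether the annotated return type implies a streaming response, but the helper itself is not shown in this commit view. Below is a minimal sketch of how such a check could be written with the standard `typing` helpers; it is a guess at the helper, not the commit's implementation.

```python
# Hypothetical sketch of is_async_iterator_type(); the commit does not show
# the real implementation.
import typing
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator


def is_async_iterator_type(tp: typing.Any) -> bool:
    # get_origin() unwraps parameterized hints, e.g. AsyncIterator[Chunk]
    # becomes collections.abc.AsyncIterator; bare classes have no origin,
    # so fall back to the hint itself.
    origin = typing.get_origin(tp) or tp
    return isinstance(origin, type) and issubclass(
        origin, (AsyncIterator, AsyncIterable, AsyncGenerator)
    )


# e.g. an endpoint annotated `-> AsyncIterator[bytes]` is treated as
# streaming, while a plain response model class is not.
assert is_async_iterator_type(AsyncIterator[bytes])
assert not is_async_iterator_type(dict)
```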