implement full-passthrough in the server

Ashwin Bharambe 2024-08-03 14:15:20 -07:00
parent 38fd76f85c
commit 9dafa6ad94
8 changed files with 69 additions and 71 deletions

View file

@@ -8,8 +8,6 @@ fire
flake8
huggingface-hub
httpx
hydra-core
hydra-zen
json-strong-typing
matplotlib
omegaconf

View file

@@ -11,11 +11,12 @@ import json
import shlex
from pathlib import Path
from typing import Annotated, get_args, get_origin, Literal, Optional, Union
from typing import get_args, get_origin, Literal, Optional, Union
import yaml
from pydantic import BaseModel
from termcolor import cprint
from typing_extensions import Annotated
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.utils import DISTRIBS_BASE_DIR, EnumEncoder
@@ -37,6 +38,7 @@ class DistributionConfigure(Subcommand):
def _add_arguments(self):
from llama_toolchain.distribution.registry import available_distributions
self.parser.add_argument(
"--name",
type=str,
@@ -64,6 +66,7 @@ class DistributionConfigure(Subcommand):
def configure_llama_distribution(dist: "Distribution", conda_env: str):
from llama_toolchain.distribution.datatypes import PassthroughApiAdapter
from .utils import run_command
python_exe = run_command(shlex.split("which python"))
@@ -101,7 +104,10 @@ def configure_llama_distribution(dist: "Distribution", conda_env: str):
else None
),
)
adapter_configs[api_surface.value] = config.dict()
adapter_configs[api_surface.value] = {
"adapter_id": adapter.adapter_id,
**config.dict(),
}
dist_config = {
"adapters": adapter_configs,
@@ -138,6 +144,8 @@ def get_non_none_type(field_type):
return next(arg for arg in get_args(field_type) if arg is not type(None))
# TODO: maybe support List values (for simple types, it should be comma-separated and for complex ones,
# it should prompt iteratively if the user wants to add more values)
def prompt_for_config(
config_type: type[BaseModel], existing_config: Optional[BaseModel] = None
) -> BaseModel:
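
As a rough illustration of the prompting flow behind prompt_for_config (names and coercion details are assumptions; the real function handles Optional/Literal fields and, per the TODO above, does not yet support lists), here is a minimal sketch against the pydantic v1 API pinned elsewhere in this commit:

from typing import Optional

from pydantic import BaseModel


def prompt_for_config_sketch(
    config_type: type[BaseModel], existing_config: Optional[BaseModel] = None
) -> BaseModel:
    # Ask for each field, falling back to the existing value (or the field
    # default) when the user just presses enter.
    values = {}
    for name, field in config_type.__fields__.items():  # pydantic v1 field API
        default = getattr(existing_config, name, field.default)
        raw = input(f"{name} [{default}]: ").strip()
        values[name] = raw if raw else default
    return config_type(**values)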

View file

@@ -30,6 +30,7 @@ class DistributionInstall(Subcommand):
def _add_arguments(self):
from llama_toolchain.distribution.registry import available_distributions
self.parser.add_argument(
"--name",
type=str,
@@ -63,7 +64,7 @@ class DistributionInstall(Subcommand):
os.makedirs(DISTRIBS_BASE_DIR / dist.name, exist_ok=True)
deps = distribution_dependencies(dist)
run_command([script, args.conda_env, " ".join(deps)])
run_with_pty([script, args.conda_env, " ".join(deps)])
with open(DISTRIBS_BASE_DIR / dist.name / "conda.env", "w") as f:
f.write(f"{args.conda_env}\n")

View file

@@ -25,8 +25,6 @@ COMMON_DEPENDENCIES = [
"flake8",
"httpx",
"huggingface-hub",
"hydra-core",
"hydra-zen",
"json-strong-typing",
"git+ssh://git@github.com/meta-llama/llama-models.git",
"omegaconf",
@@ -67,9 +65,12 @@ def available_distributions() -> List[Distribution]:
"fairscale",
"fastapi",
"fire",
"flake8",
"httpx",
"huggingface-hub",
"json-strong-typing",
"pydantic==1.10.13",
"pydantic_core==2.18.2",
"uvicorn",
],
adapters={
ApiSurface.inference: PassthroughApiAdapter(

View file

@@ -80,29 +80,59 @@ async def passthrough(
downstream_url: str,
downstream_headers: Optional[Dict[str, str]] = None,
):
client = httpx.AsyncClient()
headers = dict(request.headers)
headers.pop("host", None)
headers.update(downstream_headers or {})
body = await request.body()
content = await request.body()
async def iterating_response():
def enc(x):
return x.encode("latin-1")
async with httpx.AsyncClient() as client:
response_started = False
try:
response = await client.request(
async with client.stream(
method=request.method,
url=downstream_url,
headers=headers,
content=body,
content=content,
params=request.query_params,
) as response:
yield enc(
f"HTTP/1.1 {response.status_code} {response.reason_phrase}\r\n"
)
return StreamingResponse(
response.iter_bytes(),
status_code=response.status_code,
headers=dict(response.headers),
for k, v in response.headers.items():
yield enc(f"{k}: {v}\r\n")
yield b"\r\n"
response_started = True
# using a small chunk size to allow for streaming SSE, this is not ideal
# for large responses but we are not in that regime for the most part
async for chunk in response.aiter_raw(chunk_size=64):
yield chunk
await response.aclose()
except ReadTimeout:
if not response_started:
yield enc(
"HTTP/1.1 504 Gateway Timeout\r\nContent-Type: text/plain\r\n\r\nDownstream server timed out"
)
finally:
await client.aclose()
else:
yield enc("\r\n\r\nError: Downstream server timed out")
except asyncio.CancelledError:
print("Request cancelled")
return
except Exception as e:
if not response_started:
yield enc(
f"HTTP/1.1 500 Internal Server Error\r\nContent-Type: text/plain\r\n\r\nError: {str(e)}"
)
else:
yield enc(f"\r\n\r\nError: {e}")
return StreamingResponse(iterating_response())
def handle_sigint(*args, **kwargs):
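
For reference, a condensed, self-contained sketch of the streaming passthrough pattern introduced above, assuming FastAPI's Request/StreamingResponse and httpx; the raw HTTP/1.1 status-line and header framing and the timeout/cancellation handling from the diff are omitted for brevity:

from typing import Dict, Optional

import httpx
from fastapi import Request
from fastapi.responses import StreamingResponse


async def passthrough_sketch(
    request: Request,
    downstream_url: str,
    downstream_headers: Optional[Dict[str, str]] = None,
) -> StreamingResponse:
    headers = dict(request.headers)
    headers.pop("host", None)
    headers.update(downstream_headers or {})
    content = await request.body()

    async def iterate():
        # Open the client inside the generator so it lives for the full
        # duration of the streamed response.
        async with httpx.AsyncClient() as client:
            async with client.stream(
                method=request.method,
                url=downstream_url,
                headers=headers,
                content=content,
                params=request.query_params,
            ) as response:
                # Small chunks keep SSE tokens flowing promptly.
                async for chunk in response.aiter_raw(chunk_size=64):
                    yield chunk

    return StreamingResponse(iterate())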
@@ -134,6 +164,10 @@ def create_dynamic_typed_route(func: Any):
request_model = next(iter(hints.values()))
response_model = hints["return"]
# NOTE: I think it is better to just add a method within each ApiSurface
# "Protocol" / adapter-impl to tell what sort of a response this request
# is going to produce. /chat_completion can produce a streaming or
# non-streaming response depending on if request.stream is True / False.
is_streaming = is_async_iterator_type(response_model)
if is_streaming:
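
A rough sketch of what a helper like is_async_iterator_type might check, i.e. whether the handler's declared return type is an AsyncIterator and the route should therefore stream; the real helper may differ:

from collections.abc import AsyncIterator
from typing import get_origin


def is_async_iterator_type_sketch(tp) -> bool:
    origin = get_origin(tp) or tp
    return isinstance(origin, type) and issubclass(origin, AsyncIterator)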

View file

@@ -75,7 +75,7 @@ class InferenceClient(Inference):
async def run_main(host: str, port: int, stream: bool):
client = InferenceClient(f"http://{host}:{port}")
message = UserMessage(content="hello world, help me out here")
message = UserMessage(content="hello world, troll me in two-paragraphs about 42")
cprint(f"User>{message.content}", "green")
iterator = client.chat_completion(
ChatCompletionRequest(

View file

@@ -4,17 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import getpass
import json
import os
from enum import Enum
from pathlib import Path
from typing import Optional
from hydra import compose, initialize, MissingConfigException
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf
LLAMA_STACK_CONFIG_DIR = os.path.expanduser("~/.llama/")
@@ -34,41 +27,6 @@ def get_default_config_dir():
return os.path.join(LLAMA_STACK_CONFIG_DIR, "configs")
def parse_config(config_dir: str, config_path: Optional[str] = None) -> str:
# Configs can be
# 1. relative paths in {config_dir}/
# 2. or default to file {config_dir}/{user}.yaml
# 3. or ultimate default to {config_dir}/default.yaml
# Get the relative path from the current file to the config directory
current_file_directory = os.path.dirname(os.path.abspath(__file__))
relative_path = os.path.relpath(config_dir, current_file_directory)
GlobalHydra.instance().clear()
initialize(config_path=relative_path)
if config_path is None:
try:
user = getpass.getuser()
config_name = user
except MissingConfigException:
print(f"No user-specific {user}.yaml, using default")
config_name = "default"
else:
config_name = config_path
config_abs_path = os.path.abspath(os.path.join(config_dir, f"{config_name}.yaml"))
print(f"Loading config from : {config_abs_path}")
config = compose(config_name=config_name)
print("Yaml config:")
print("------------------------")
print(OmegaConf.to_yaml(config, resolve=True))
print("------------------------")
return config
class EnumEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, Enum):
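
For context, the EnumEncoder kept in utils.py follows the standard JSONEncoder pattern for serializing Enum members; a minimal self-contained version (assuming it returns obj.value, which the truncated hunk does not show) looks like:

import json
from enum import Enum


class ApiSurfaceExample(Enum):
    inference = "inference"


class EnumEncoderSketch(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, Enum):
            return obj.value
        return super().default(obj)


print(json.dumps({"surface": ApiSurfaceExample.inference}, cls=EnumEncoderSketch))
# -> {"surface": "inference"}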

View file

@@ -8,8 +8,6 @@ fire
flake8
httpx
huggingface-hub
hydra-core
hydra-zen
json-strong-typing
llama-models
matplotlib