diff --git a/fp8_requirements.txt b/fp8_requirements.txt
index 17637f27e..8a58cff62 100644
--- a/fp8_requirements.txt
+++ b/fp8_requirements.txt
@@ -8,8 +8,6 @@ fire
 flake8
 huggingface-hub
 httpx
-hydra-core
-hydra-zen
 json-strong-typing
 matplotlib
 omegaconf
diff --git a/llama_toolchain/cli/distribution/configure.py b/llama_toolchain/cli/distribution/configure.py
index 20dc6955c..a3d449c50 100644
--- a/llama_toolchain/cli/distribution/configure.py
+++ b/llama_toolchain/cli/distribution/configure.py
@@ -11,11 +11,12 @@
 import json
 import shlex
 from pathlib import Path
-from typing import Annotated, get_args, get_origin, Literal, Optional, Union
+from typing import get_args, get_origin, Literal, Optional, Union
 
 import yaml
 from pydantic import BaseModel
 from termcolor import cprint
+from typing_extensions import Annotated
 
 from llama_toolchain.cli.subcommand import Subcommand
 from llama_toolchain.utils import DISTRIBS_BASE_DIR, EnumEncoder
@@ -37,6 +38,7 @@ class DistributionConfigure(Subcommand):
 
     def _add_arguments(self):
         from llama_toolchain.distribution.registry import available_distributions
+
         self.parser.add_argument(
             "--name",
             type=str,
@@ -64,6 +66,7 @@ class DistributionConfigure(Subcommand):
 
 def configure_llama_distribution(dist: "Distribution", conda_env: str):
     from llama_toolchain.distribution.datatypes import PassthroughApiAdapter
+
     from .utils import run_command
 
     python_exe = run_command(shlex.split("which python"))
@@ -101,7 +104,10 @@ def configure_llama_distribution(dist: "Distribution", conda_env: str):
                     else None
                 ),
             )
-            adapter_configs[api_surface.value] = config.dict()
+            adapter_configs[api_surface.value] = {
+                "adapter_id": adapter.adapter_id,
+                **config.dict(),
+            }
 
     dist_config = {
         "adapters": adapter_configs,
@@ -138,6 +144,8 @@ def get_non_none_type(field_type):
     return next(arg for arg in get_args(field_type) if arg is not type(None))
 
 
+# TODO: maybe support List values (for simple types, it should be comma-separated and for complex ones,
+# it should prompt iteratively if the user wants to add more values)
 def prompt_for_config(
     config_type: type[BaseModel], existing_config: Optional[BaseModel] = None
 ) -> BaseModel:
diff --git a/llama_toolchain/cli/distribution/install.py b/llama_toolchain/cli/distribution/install.py
index 367906e32..8584e7517 100644
--- a/llama_toolchain/cli/distribution/install.py
+++ b/llama_toolchain/cli/distribution/install.py
@@ -30,6 +30,7 @@ class DistributionInstall(Subcommand):
 
     def _add_arguments(self):
         from llama_toolchain.distribution.registry import available_distributions
+
        self.parser.add_argument(
             "--name",
             type=str,
@@ -63,7 +64,7 @@ class DistributionInstall(Subcommand):
         os.makedirs(DISTRIBS_BASE_DIR / dist.name, exist_ok=True)
 
         deps = distribution_dependencies(dist)
-        run_command([script, args.conda_env, " ".join(deps)])
+        run_with_pty([script, args.conda_env, " ".join(deps)])
 
         with open(DISTRIBS_BASE_DIR / dist.name / "conda.env", "w") as f:
             f.write(f"{args.conda_env}\n")
diff --git a/llama_toolchain/distribution/registry.py b/llama_toolchain/distribution/registry.py
index bb25bb3c0..bcfa6eaa4 100644
--- a/llama_toolchain/distribution/registry.py
+++ b/llama_toolchain/distribution/registry.py
@@ -25,8 +25,6 @@ COMMON_DEPENDENCIES = [
     "flake8",
     "httpx",
     "huggingface-hub",
-    "hydra-core",
-    "hydra-zen",
     "json-strong-typing",
     "git+ssh://git@github.com/meta-llama/llama-models.git",
     "omegaconf",
@@ -67,9 +65,12 @@ def available_distributions() -> List[Distribution]:
                 "fairscale",
                 "fastapi",
                 "fire",
-                "flake8",
                 "httpx",
                 "huggingface-hub",
"json-strong-typing", + "pydantic==1.10.13", + "pydantic_core==2.18.2", + "uvicorn", ], adapters={ ApiSurface.inference: PassthroughApiAdapter( diff --git a/llama_toolchain/distribution/server.py b/llama_toolchain/distribution/server.py index 128b78112..c3e898692 100644 --- a/llama_toolchain/distribution/server.py +++ b/llama_toolchain/distribution/server.py @@ -80,29 +80,59 @@ async def passthrough( downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None, ): - client = httpx.AsyncClient() - headers = dict(request.headers) headers.pop("host", None) headers.update(downstream_headers or {}) - body = await request.body() + content = await request.body() - try: - response = await client.request( - method=request.method, - url=downstream_url, - headers=headers, - content=body, - params=request.query_params, - ) - return StreamingResponse( - response.iter_bytes(), - status_code=response.status_code, - headers=dict(response.headers), - ) - finally: - await client.aclose() + async def iterating_response(): + def enc(x): + return x.encode("latin-1") + + async with httpx.AsyncClient() as client: + response_started = False + try: + async with client.stream( + method=request.method, + url=downstream_url, + headers=headers, + content=content, + params=request.query_params, + ) as response: + yield enc( + f"HTTP/1.1 {response.status_code} {response.reason_phrase}\r\n" + ) + for k, v in response.headers.items(): + yield enc(f"{k}: {v}\r\n") + yield b"\r\n" + + response_started = True + + # using a small chunk size to allow for streaming SSE, this is not ideal + # for large responses but we are not in that regime for the most part + async for chunk in response.aiter_raw(chunk_size=64): + yield chunk + await response.aclose() + except ReadTimeout: + if not response_started: + yield enc( + "HTTP/1.1 504 Gateway Timeout\r\nContent-Type: text/plain\r\n\r\nDownstream server timed out" + ) + else: + yield enc("\r\n\r\nError: Downstream server timed out") + except asyncio.CancelledError: + print("Request cancelled") + return + except Exception as e: + if not response_started: + yield enc( + f"HTTP/1.1 500 Internal Server Error\r\nContent-Type: text/plain\r\n\r\nError: {str(e)}" + ) + else: + yield enc(f"\r\n\r\nError: {e}") + + return StreamingResponse(iterating_response()) def handle_sigint(*args, **kwargs): @@ -134,6 +164,10 @@ def create_dynamic_typed_route(func: Any): request_model = next(iter(hints.values())) response_model = hints["return"] + # NOTE: I think it is better to just add a method within each ApiSurface + # "Protocol" / adapter-impl to tell what sort of a response this request + # is going to produce. /chat_completion can produce a streaming or + # non-streaming response depending on if request.stream is True / False. 
     is_streaming = is_async_iterator_type(response_model)
 
     if is_streaming:
diff --git a/llama_toolchain/inference/client.py b/llama_toolchain/inference/client.py
index 1dfe47b24..178452fde 100644
--- a/llama_toolchain/inference/client.py
+++ b/llama_toolchain/inference/client.py
@@ -75,7 +75,7 @@ class InferenceClient(Inference):
 async def run_main(host: str, port: int, stream: bool):
     client = InferenceClient(f"http://{host}:{port}")
 
-    message = UserMessage(content="hello world, help me out here")
+    message = UserMessage(content="hello world, troll me in two-paragraphs about 42")
     cprint(f"User>{message.content}", "green")
     iterator = client.chat_completion(
         ChatCompletionRequest(
diff --git a/llama_toolchain/utils.py b/llama_toolchain/utils.py
index 0b4df3b30..19d6fe976 100644
--- a/llama_toolchain/utils.py
+++ b/llama_toolchain/utils.py
@@ -4,17 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import getpass
 import json
 import os
 from enum import Enum
 from pathlib import Path
-from typing import Optional
-
-from hydra import compose, initialize, MissingConfigException
-from hydra.core.global_hydra import GlobalHydra
-
-from omegaconf import OmegaConf
 
 
 LLAMA_STACK_CONFIG_DIR = os.path.expanduser("~/.llama/")
@@ -34,41 +27,6 @@ def get_default_config_dir():
     return os.path.join(LLAMA_STACK_CONFIG_DIR, "configs")
 
 
-def parse_config(config_dir: str, config_path: Optional[str] = None) -> str:
-    # Configs can be
-    # 1. relative paths in {config_dir}/
-    # 2. or default to file {config_dir}/{user}.yaml
-    # 3. or ultimate default to {config_dir}/default.yaml
-
-    # Get the relative path from the current file to the config directory
-    current_file_directory = os.path.dirname(os.path.abspath(__file__))
-    relative_path = os.path.relpath(config_dir, current_file_directory)
-
-    GlobalHydra.instance().clear()
-    initialize(config_path=relative_path)
-
-    if config_path is None:
-        try:
-            user = getpass.getuser()
-            config_name = user
-        except MissingConfigException:
-            print(f"No user-specific {user}.yaml, using default")
-            config_name = "default"
-    else:
-        config_name = config_path
-
-    config_abs_path = os.path.abspath(os.path.join(config_dir, f"{config_name}.yaml"))
-    print(f"Loading config from : {config_abs_path}")
-    config = compose(config_name=config_name)
-
-    print("Yaml config:")
-    print("------------------------")
-    print(OmegaConf.to_yaml(config, resolve=True))
-    print("------------------------")
-
-    return config
-
-
 class EnumEncoder(json.JSONEncoder):
     def default(self, obj):
         if isinstance(obj, Enum):
diff --git a/requirements.txt b/requirements.txt
index 05d642f81..fa78213bf 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,8 +8,6 @@ fire
 flake8
 httpx
 huggingface-hub
-hydra-core
-hydra-zen
 json-strong-typing
 llama-models
 matplotlib
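
A possible follow-up to the NOTE added in server.py: instead of create_dynamic_typed_route inferring streaming behaviour from the handler's return annotation via is_async_iterator_type, each ApiSurface protocol / adapter implementation could expose a hook that reports whether a given request will stream. The sketch below is illustrative only and not part of the patch; the is_streaming_request method and ExampleInferenceAdapter class are hypothetical names, and ChatCompletionRequest is reduced to the single field the NOTE mentions.

# Hypothetical sketch (not part of this commit): let each adapter declare
# whether a request will produce a streaming response, instead of the server
# inspecting type hints on the handler's return annotation.
from typing import Protocol

from pydantic import BaseModel


class ChatCompletionRequest(BaseModel):
    # simplified stand-in for the real request type in llama_toolchain
    stream: bool = False


class Inference(Protocol):
    def is_streaming_request(self, request: ChatCompletionRequest) -> bool:
        ...


class ExampleInferenceAdapter:
    def is_streaming_request(self, request: ChatCompletionRequest) -> bool:
        # /chat_completion streams exactly when the caller sets stream=True
        return request.stream

The server could then branch on adapter.is_streaming_request(request) when choosing between a StreamingResponse and a plain JSON response.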