mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
implement full-passthrough in the server
This commit is contained in:
parent
38fd76f85c
commit
9dafa6ad94
8 changed files with 69 additions and 71 deletions
|
@ -8,8 +8,6 @@ fire
|
|||
flake8
|
||||
huggingface-hub
|
||||
httpx
|
||||
hydra-core
|
||||
hydra-zen
|
||||
json-strong-typing
|
||||
matplotlib
|
||||
omegaconf
|
||||
|
|
|
@ -11,11 +11,12 @@ import json
|
|||
import shlex
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Annotated, get_args, get_origin, Literal, Optional, Union
|
||||
from typing import get_args, get_origin, Literal, Optional, Union
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_toolchain.cli.subcommand import Subcommand
|
||||
from llama_toolchain.utils import DISTRIBS_BASE_DIR, EnumEncoder
|
||||
|
@ -37,6 +38,7 @@ class DistributionConfigure(Subcommand):
|
|||
|
||||
def _add_arguments(self):
|
||||
from llama_toolchain.distribution.registry import available_distributions
|
||||
|
||||
self.parser.add_argument(
|
||||
"--name",
|
||||
type=str,
|
||||
|
@ -64,6 +66,7 @@ class DistributionConfigure(Subcommand):
|
|||
|
||||
def configure_llama_distribution(dist: "Distribution", conda_env: str):
|
||||
from llama_toolchain.distribution.datatypes import PassthroughApiAdapter
|
||||
|
||||
from .utils import run_command
|
||||
|
||||
python_exe = run_command(shlex.split("which python"))
|
||||
|
@ -101,7 +104,10 @@ def configure_llama_distribution(dist: "Distribution", conda_env: str):
|
|||
else None
|
||||
),
|
||||
)
|
||||
adapter_configs[api_surface.value] = config.dict()
|
||||
adapter_configs[api_surface.value] = {
|
||||
adapter_id: adapter.adapter_id,
|
||||
**config.dict(),
|
||||
}
|
||||
|
||||
dist_config = {
|
||||
"adapters": adapter_configs,
|
||||
|
@ -138,6 +144,8 @@ def get_non_none_type(field_type):
|
|||
return next(arg for arg in get_args(field_type) if arg is not type(None))
|
||||
|
||||
|
||||
# TODO: maybe support List values (for simple types, accept a comma-separated string; for complex types,
|
||||
# prompt iteratively, asking each time whether the user wants to add another value)
|
||||
def prompt_for_config(
|
||||
config_type: type[BaseModel], existing_config: Optional[BaseModel] = None
|
||||
) -> BaseModel:
|
||||
|
|
|
@ -30,6 +30,7 @@ class DistributionInstall(Subcommand):
|
|||
|
||||
def _add_arguments(self):
|
||||
from llama_toolchain.distribution.registry import available_distributions
|
||||
|
||||
self.parser.add_argument(
|
||||
"--name",
|
||||
type=str,
|
||||
|
@ -63,7 +64,7 @@ class DistributionInstall(Subcommand):
|
|||
os.makedirs(DISTRIBS_BASE_DIR / dist.name, exist_ok=True)
|
||||
|
||||
deps = distribution_dependencies(dist)
|
||||
run_command([script, args.conda_env, " ".join(deps)])
|
||||
run_with_pty([script, args.conda_env, " ".join(deps)])
|
||||
with open(DISTRIBS_BASE_DIR / dist.name / "conda.env", "w") as f:
|
||||
f.write(f"{args.conda_env}\n")
|
||||
|
||||
|
|
|
@ -25,8 +25,6 @@ COMMON_DEPENDENCIES = [
|
|||
"flake8",
|
||||
"httpx",
|
||||
"huggingface-hub",
|
||||
"hydra-core",
|
||||
"hydra-zen",
|
||||
"json-strong-typing",
|
||||
"git+ssh://git@github.com/meta-llama/llama-models.git",
|
||||
"omegaconf",
|
||||
|
@ -67,9 +65,12 @@ def available_distributions() -> List[Distribution]:
|
|||
"fairscale",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"flake8",
|
||||
"httpx",
|
||||
"huggingface-hub",
|
||||
"json-strong-typing",
|
||||
"pydantic==1.10.13",
|
||||
"pydantic_core==2.18.2",
|
||||
"uvicorn",
|
||||
],
|
||||
adapters={
|
||||
ApiSurface.inference: PassthroughApiAdapter(
|
||||
|
|
|
@ -80,29 +80,59 @@ async def passthrough(
|
|||
downstream_url: str,
|
||||
downstream_headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
client = httpx.AsyncClient()
|
||||
|
||||
headers = dict(request.headers)
|
||||
headers.pop("host", None)
|
||||
headers.update(downstream_headers or {})
|
||||
|
||||
body = await request.body()
|
||||
content = await request.body()
|
||||
|
||||
try:
|
||||
response = await client.request(
|
||||
method=request.method,
|
||||
url=downstream_url,
|
||||
headers=headers,
|
||||
content=body,
|
||||
params=request.query_params,
|
||||
)
|
||||
return StreamingResponse(
|
||||
response.iter_bytes(),
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
)
|
||||
finally:
|
||||
await client.aclose()
|
||||
async def iterating_response():
    # Async generator that forwards the incoming request to `downstream_url`
    # and yields the downstream reply as raw bytes: a hand-written HTTP/1.1
    # status line, one line per response header, a blank line, and then the
    # body chunks exactly as received.
    #
    # `request`, `downstream_url`, `headers` and `content` are not defined
    # here; they are read from the enclosing scope.
    def enc(x: str) -> bytes:
        # Encode text to bytes using latin-1.
        return x.encode("latin-1")

    async with httpx.AsyncClient() as client:
        # Flips to True once the status line and headers have been yielded,
        # so the error handlers below know whether they may still emit a
        # status line of their own or can only append trailing text.
        response_started = False
        try:
            async with client.stream(
                method=request.method,
                url=downstream_url,
                headers=headers,
                content=content,
                params=request.query_params,
            ) as response:
                # Status line followed by the downstream headers, one per line.
                yield enc(
                    f"HTTP/1.1 {response.status_code} {response.reason_phrase}\r\n"
                )
                for k, v in response.headers.items():
                    yield enc(f"{k}: {v}\r\n")
                # Blank line separating the headers from the body.
                yield b"\r\n"

                response_started = True

                # using a small chunk size to allow for streaming SSE, this is not ideal
                # for large responses but we are not in that regime for the most part
                async for chunk in response.aiter_raw(chunk_size=64):
                    yield chunk
                # NOTE(review): presumably redundant, since the enclosing
                # `async with client.stream(...)` should already close the
                # response on exit -- confirm against the httpx docs.
                await response.aclose()
        except ReadTimeout:
            if not response_started:
                # Nothing has been sent yet: emit a complete 504 reply.
                yield enc(
                    "HTTP/1.1 504 Gateway Timeout\r\nContent-Type: text/plain\r\n\r\nDownstream server timed out"
                )
            else:
                # Headers already went out: only trailing text can be appended.
                yield enc("\r\n\r\nError: Downstream server timed out")
        except asyncio.CancelledError:
            # Log the cancellation and end the generator without yielding more.
            print("Request cancelled")
            return
        except Exception as e:
            if not response_started:
                # Nothing has been sent yet: emit a complete 500 reply.
                yield enc(
                    f"HTTP/1.1 500 Internal Server Error\r\nContent-Type: text/plain\r\n\r\nError: {str(e)}"
                )
            else:
                # Headers already went out: only trailing text can be appended.
                yield enc(f"\r\n\r\nError: {e}")
|
||||
return StreamingResponse(iterating_response())
|
||||
|
||||
|
||||
def handle_sigint(*args, **kwargs):
|
||||
|
@ -134,6 +164,10 @@ def create_dynamic_typed_route(func: Any):
|
|||
request_model = next(iter(hints.values()))
|
||||
response_model = hints["return"]
|
||||
|
||||
# NOTE: I think it is better to just add a method within each ApiSurface
|
||||
# "Protocol" / adapter-impl to tell what sort of a response this request
|
||||
# is going to produce. /chat_completion can produce a streaming or
|
||||
# non-streaming response depending on if request.stream is True / False.
|
||||
is_streaming = is_async_iterator_type(response_model)
|
||||
|
||||
if is_streaming:
|
||||
|
|
|
@ -75,7 +75,7 @@ class InferenceClient(Inference):
|
|||
async def run_main(host: str, port: int, stream: bool):
|
||||
client = InferenceClient(f"http://{host}:{port}")
|
||||
|
||||
message = UserMessage(content="hello world, help me out here")
|
||||
message = UserMessage(content="hello world, troll me in two-paragraphs about 42")
|
||||
cprint(f"User>{message.content}", "green")
|
||||
iterator = client.chat_completion(
|
||||
ChatCompletionRequest(
|
||||
|
|
|
@ -4,17 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import getpass
|
||||
import json
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from hydra import compose, initialize, MissingConfigException
|
||||
from hydra.core.global_hydra import GlobalHydra
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
LLAMA_STACK_CONFIG_DIR = os.path.expanduser("~/.llama/")
|
||||
|
@ -34,41 +27,6 @@ def get_default_config_dir():
|
|||
return os.path.join(LLAMA_STACK_CONFIG_DIR, "configs")
|
||||
|
||||
|
||||
def parse_config(config_dir: str, config_path: Optional[str] = None) -> str:
    """Compose, print and return the Hydra config found under ``config_dir``.

    The config name is resolved in this order:

    1. ``config_path`` when given (a name relative to ``config_dir``,
       without the ``.yaml`` suffix);
    2. otherwise ``{user}.yaml`` for the current login name;
    3. otherwise ``default.yaml`` when no user-specific config exists.

    Args:
        config_dir: Directory that holds the YAML config files.
        config_path: Optional explicit config name; when set, the
            user/default lookup is skipped.

    Returns:
        Whatever ``compose`` returns for the selected config. The ``-> str``
        annotation is kept for backward compatibility only; the value
        returned is the composed config object, not a string.

    Raises:
        MissingConfigException: Propagated from ``compose`` when an explicit
            ``config_path`` cannot be found, or when the ``default`` fallback
            is missing as well.
    """

    def _load(name: str):
        # Announce the file being loaded before composing it.
        abs_path = os.path.abspath(os.path.join(config_dir, f"{name}.yaml"))
        print(f"Loading config from : {abs_path}")
        return compose(config_name=name)

    # initialize() is given the config directory as a path relative to this file.
    current_file_directory = os.path.dirname(os.path.abspath(__file__))
    relative_path = os.path.relpath(config_dir, current_file_directory)

    # Reset any previous Hydra state so initialize() can be called again.
    GlobalHydra.instance().clear()
    initialize(config_path=relative_path)

    if config_path is None:
        user = getpass.getuser()
        try:
            # The compose call is what can fail on a missing config, so it
            # has to sit inside the try for the fallback to take effect.
            config = _load(user)
        except MissingConfigException:
            print(f"No user-specific {user}.yaml, using default")
            config = _load("default")
    else:
        config = _load(config_path)

    print("Yaml config:")
    print("------------------------")
    print(OmegaConf.to_yaml(config, resolve=True))
    print("------------------------")

    return config
|
||||
|
||||
|
||||
class EnumEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, Enum):
|
||||
|
|
|
@ -8,8 +8,6 @@ fire
|
|||
flake8
|
||||
httpx
|
||||
huggingface-hub
|
||||
hydra-core
|
||||
hydra-zen
|
||||
json-strong-typing
|
||||
llama-models
|
||||
matplotlib
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue