mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00

Distribution server now functioning

parent 041cafbee3
commit 2cf9915806

21 changed files with 635 additions and 266 deletions

@@ -18,11 +18,13 @@ from termcolor import cprint
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.distribution.datatypes import Distribution, PassthroughApiAdapter
from llama_toolchain.distribution.registry import available_distributions
from llama_toolchain.distribution.registry import (
    available_distributions,
    resolve_distribution,
)
from llama_toolchain.utils import DISTRIBS_BASE_DIR
from .utils import run_command

DISTRIBS = available_distributions()
from .utils import run_command


class DistributionConfigure(Subcommand):

@@ -43,18 +45,13 @@ class DistributionConfigure(Subcommand):
        self.parser.add_argument(
            "--name",
            type=str,
            help="Mame of the distribution to configure",
            help="Name of the distribution to configure",
            default="local-source",
            choices=[d.name for d in available_distributions()],
        )

    def _run_distribution_configure_cmd(self, args: argparse.Namespace) -> None:
        dist = None
        for d in DISTRIBS:
            if d.name == args.name:
                dist = d
                break

        dist = resolve_distribution(args.name)
        if dist is None:
            self.parser.error(f"Could not find distribution {args.name}")
            return

@@ -7,6 +7,7 @@
import argparse

from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.distribution.registry import resolve_distribution


class DistributionCreate(Subcommand):

@@ -23,7 +24,20 @@ class DistributionCreate(Subcommand):
        self.parser.set_defaults(func=self._run_distribution_create_cmd)

    def _add_arguments(self):
        pass
        self.parser.add_argument(
            "--name",
            type=str,
            help="Name of the distribution to create",
            required=True,
        )
        # for each ApiSurface the user wants to support, we should
        # get the list of available adapters, ask which one the user
        # wants to pick and then ask for their configuration.

    def _run_distribution_create_cmd(self, args: argparse.Namespace) -> None:
        dist = resolve_distribution(args.name)
        if dist is not None:
            self.parser.error(f"Distribution with name {args.name} already exists")
            return

        raise NotImplementedError()

@@ -12,6 +12,7 @@ from .configure import DistributionConfigure
from .create import DistributionCreate
from .install import DistributionInstall
from .list import DistributionList
from .start import DistributionStart


class DistributionParser(Subcommand):

@@ -31,3 +32,4 @@ class DistributionParser(Subcommand):
        DistributionInstall.create(subparsers)
        DistributionCreate.create(subparsers)
        DistributionConfigure.create(subparsers)
        DistributionStart.create(subparsers)

@@ -11,9 +11,13 @@ import shlex
import pkg_resources

from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.distribution.datatypes import distribution_dependencies
from llama_toolchain.distribution.registry import available_distributions
from llama_toolchain.distribution.distribution import distribution_dependencies
from llama_toolchain.distribution.registry import (
    available_distributions,
    resolve_distribution,
)
from llama_toolchain.utils import DISTRIBS_BASE_DIR

from .utils import run_command, run_with_pty

DISTRIBS = available_distributions()

@@ -55,12 +59,7 @@ class DistributionInstall(Subcommand):
            "distribution/install_distribution.sh",
        )

        dist = None
        for d in DISTRIBS:
            if d.name == args.name:
                dist = d
                break

        dist = resolve_distribution(args.name)
        if dist is None:
            self.parser.error(f"Could not find distribution {args.name}")
            return

@@ -9,7 +9,7 @@ import argparse
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.cli.table import print_table

from llama_toolchain.distribution.datatypes import distribution_dependencies
from llama_toolchain.distribution.distribution import distribution_dependencies
from llama_toolchain.distribution.registry import available_distributions

87 llama_toolchain/cli/distribution/start.py Normal file

@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse
import shlex
from pathlib import Path

import yaml

from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.distribution.registry import resolve_distribution
from llama_toolchain.distribution.server import main as distribution_server_init
from llama_toolchain.utils import DISTRIBS_BASE_DIR

from .utils import run_command


class DistributionStart(Subcommand):

    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "start",
            prog="llama distribution start",
            description="""start the server for a Llama stack distribution. you should have already installed and configured the distribution""",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_distribution_start_cmd)

    def _add_arguments(self):
        self.parser.add_argument(
            "--name",
            type=str,
            help="Name of the distribution to start",
            required=True,
        )
        self.parser.add_argument(
            "--port",
            type=int,
            help="Port to run the server on. Defaults to 5000",
            default=5000,
        )
        self.parser.add_argument(
            "--disable-ipv6",
            action="store_true",
            help="Disable IPv6 support",
            default=False,
        )

    def _run_distribution_start_cmd(self, args: argparse.Namespace) -> None:
        dist = resolve_distribution(args.name)
        if dist is None:
            self.parser.error(f"Distribution with name {args.name} not found")
            return

        config_yaml = Path(DISTRIBS_BASE_DIR) / dist.name / "config.yaml"
        if not config_yaml.exists():
            raise ValueError(
                f"Configuration {config_yaml} does not exist. Please run `llama distribution install` or `llama distribution configure` first"
            )

        with open(config_yaml, "r") as fp:
            config = yaml.safe_load(fp)

        conda_env = config["conda_env"]
        python_exe = run_command(shlex.split("which python"))
        # simple check, unfortunate
        if conda_env not in python_exe:
            raise ValueError(
                f"Please re-run start after activating the `{conda_env}` conda environment first"
            )

        distribution_server_init(
            dist.name,
            config_yaml,
            args.port,
            disable_ipv6=args.disable_ipv6,
        )
        # run_with_pty(
        # shlex.split(
        # f"conda run -n {conda_env} python -m llama_toolchain.distribution.server {dist.name} {config_yaml} --port 5000"
        # )
        # )

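Note on the config file consumed above: the only structure this commit demonstrably relies on is a top-level `conda_env` key (checked in `_run_distribution_start_cmd`) and a top-level `adapters` mapping keyed by API surface name (read later by `server.main`). A minimal inspection sketch follows; "local-source" is the default name from the configure command, and everything else about the file's contents is an assumption for illustration.

from pathlib import Path

import yaml

from llama_toolchain.utils import DISTRIBS_BASE_DIR

# Sketch: peek at the config that `llama distribution start` will load.
# Only "conda_env" and "adapters" are known requirements; the rest of the
# file's layout is adapter-specific and not specified here.
config_yaml = Path(DISTRIBS_BASE_DIR) / "local-source" / "config.yaml"

with open(config_yaml, "r") as fp:
    config = yaml.safe_load(fp)

print("conda env:", config["conda_env"])  # required by start.py
for surface, adapter_config in config["adapters"].items():  # required by server.main
    print(surface, "->", adapter_config)
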
@@ -8,21 +8,35 @@ import errno
import os
import pty
import select
import signal
import subprocess
import sys
import termios
import tty

from termcolor import cprint


def run_with_pty(command):
    old_settings = termios.tcgetattr(sys.stdin)

    # Create a new pseudo-terminal
    master, slave = pty.openpty()

    old_settings = termios.tcgetattr(sys.stdin)
    original_sigint = signal.getsignal(signal.SIGINT)

    ctrl_c_pressed = False

    def sigint_handler(signum, frame):
        nonlocal ctrl_c_pressed
        ctrl_c_pressed = True
        cprint("\nCtrl-C detected. Aborting...", "white", attrs=["bold"])

    try:
        # ensure the terminal does not echo input
        tty.setraw(sys.stdin.fileno())
        # Set up the signal handler
        signal.signal(signal.SIGINT, sigint_handler)

        new_settings = termios.tcgetattr(sys.stdin)
        new_settings[3] = new_settings[3] & ~termios.ECHO # Disable echo
        new_settings[3] = new_settings[3] & ~termios.ICANON # Disable canonical mode
        termios.tcsetattr(sys.stdin, termios.TCSADRAIN, new_settings)

        process = subprocess.Popen(
            command,

@@ -30,18 +44,20 @@ def run_with_pty(command):
            stdout=slave,
            stderr=slave,
            universal_newlines=True,
            preexec_fn=os.setsid,
        )

        # Close the slave file descriptor as it's now owned by the subprocess
        os.close(slave)

        def handle_io():
            while True:
                rlist, _, _ = select.select([sys.stdin, master], [], [])
            while not ctrl_c_pressed:
                try:
                    rlist, _, _ = select.select([sys.stdin, master], [], [], 0.1)

                    if sys.stdin in rlist:
                        data = os.read(sys.stdin.fileno(), 1024)
                        if not data: # EOF
                        if not data:
                            break
                        os.write(master, data)

@@ -49,7 +65,15 @@ def run_with_pty(command):
                        data = os.read(master, 1024)
                        if not data:
                            break
                        os.write(sys.stdout.fileno(), data)
                        sys.stdout.buffer.write(data)
                        sys.stdout.flush()

                except KeyboardInterrupt:
                    # This will be raised when Ctrl+C is pressed
                    break

                if process.poll() is not None:
                    break

        handle_io()
    except (EOFError, KeyboardInterrupt):

@@ -58,11 +82,14 @@ def run_with_pty(command):
        if e.errno != errno.EIO:
            raise
    finally:
        # Restore original terminal settings
        # Clean up
        termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
        signal.signal(signal.SIGINT, original_sigint)

        process.wait()
        os.close(master)
        if process.poll() is None:
            process.terminate()
            process.wait()

    return process.returncode

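The reworked `run_with_pty` polls `select` with a 0.1 second timeout so the SIGINT handler can break the I/O loop, mirrors child output through `sys.stdout.buffer`, and terminates the child in the cleanup path if it is still running. A minimal usage sketch follows; the import path is inferred from the sibling `from .utils import run_command, run_with_pty` imports in this commit, and the command line is only an example.

import shlex

# Assumed module path for the CLI utils shown above.
from llama_toolchain.cli.distribution.utils import run_with_pty

# Run a child process under a pseudo-terminal; stdin/stdout are proxied until
# the child exits or Ctrl-C is pressed, and the child's return code is returned.
returncode = run_with_pty(shlex.split("bash -c 'echo hello from the pty'"))
print(f"child exited with {returncode}")
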
@@ -8,7 +8,8 @@ import argparse

from .distribution import DistributionParser
from .download import Download
from .inference import InferenceParser

# from .inference import InferenceParser
from .model import ModelParser


@@ -29,7 +30,7 @@ class LlamaCLIParser:

        # Add sub-commands
        Download.create(subparsers)
        InferenceParser.create(subparsers)
        # InferenceParser.create(subparsers)
        ModelParser.create(subparsers)
        DistributionParser.create(subparsers)

@@ -5,54 +5,10 @@
# the root directory of this source tree.

from enum import Enum
from typing import Any, Dict, List, Literal, Union
from typing import Dict, List

from pydantic import BaseModel, Field
from strong_typing.schema import json_schema_type
from typing_extensions import Annotated


@json_schema_type
class AdapterType(Enum):
    passthrough_api = "passthrough_api"
    python_impl = "python_impl"
    not_implemented = "not_implemented"


@json_schema_type
class PassthroughApiAdapterConfig(BaseModel):
    type: Literal[AdapterType.passthrough_api.value] = AdapterType.passthrough_api.value
    base_url: str = Field(..., description="The base URL for the llama stack provider")
    headers: Dict[str, str] = Field(
        default_factory=dict,
        description="Headers (e.g., authorization) to send with the request",
    )


@json_schema_type
class PythonImplAdapterConfig(BaseModel):
    type: Literal[AdapterType.python_impl.value] = AdapterType.python_impl.value
    adapter_id: str
    kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="kwargs to pass to the entrypoint"
    )


@json_schema_type
class NotImplementedAdapterConfig(BaseModel):
    type: Literal[AdapterType.not_implemented.value] = AdapterType.not_implemented.value


# should we define very granular typed classes for each of the PythonImplAdapters we will have?
# e.g., OllamaInference / vLLMInference / etc. might need very specific parameters
AdapterConfig = Annotated[
    Union[
        PassthroughApiAdapterConfig,
        NotImplementedAdapterConfig,
        PythonImplAdapterConfig,
    ],
    Field(discriminator="type"),
]


@json_schema_type

@@ -61,6 +17,13 @@ class ApiSurface(Enum):
    safety = "safety"


@json_schema_type
class ApiSurfaceEndpoint(BaseModel):
    route: str
    method: str
    name: str


@json_schema_type
class Adapter(BaseModel):
    api_surface: ApiSurface

@@ -108,13 +71,3 @@ class Distribution(BaseModel):
        default_factory=list,
        description="Additional pip packages beyond those required by the adapters",
    )


def distribution_dependencies(distribution: Distribution) -> List[str]:
    # only consider SourceAdapters when calculating dependencies
    return [
        dep
        for adapter in distribution.adapters.values()
        if isinstance(adapter, SourceAdapter)
        for dep in adapter.pip_packages
    ] + distribution.additional_pip_packages

51 llama_toolchain/distribution/distribution.py Normal file

@@ -0,0 +1,51 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import inspect
from typing import Dict, List

from llama_toolchain.inference.api.endpoints import Inference
from llama_toolchain.safety.api.endpoints import Safety

from .datatypes import ApiSurface, ApiSurfaceEndpoint, Distribution, SourceAdapter


def distribution_dependencies(distribution: Distribution) -> List[str]:
    # only consider SourceAdapters when calculating dependencies
    return [
        dep
        for adapter in distribution.adapters.values()
        if isinstance(adapter, SourceAdapter)
        for dep in adapter.pip_packages
    ] + distribution.additional_pip_packages


def api_surface_endpoints() -> Dict[ApiSurface, List[ApiSurfaceEndpoint]]:
    surfaces = {}

    protocols = {
        ApiSurface.inference: Inference,
        ApiSurface.safety: Safety,
    }

    for surface, protocol in protocols.items():
        endpoints = []
        protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)

        for name, method in protocol_methods:
            if not hasattr(method, "__webmethod__"):
                continue

            webmethod = method.__webmethod__
            route = webmethod.route

            # use `post` for all methods right now until we fix up the `webmethod` openapi
            # annotation and write our own openapi generator
            endpoints.append(ApiSurfaceEndpoint(route=route, method="post", name=name))

        surfaces[surface] = endpoints

    return surfaces

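`api_surface_endpoints` only requires that each protocol method carry a `__webmethod__` attribute exposing a `route`, which is what the `@webmethod(route=...)` decorator used in the safety endpoints later in this commit provides. A self-contained sketch of that contract (an illustration, not the actual llama_toolchain implementation) looks like this:

from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class WebMethod:
    route: Optional[str] = None


def webmethod(route: Optional[str] = None) -> Callable:
    # Illustrative only: attach routing metadata that api_surface_endpoints()
    # later discovers via hasattr(method, "__webmethod__").
    def wrap(fn: Callable) -> Callable:
        fn.__webmethod__ = WebMethod(route=route)
        return fn

    return wrap
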
26 llama_toolchain/distribution/dynamic.py Normal file

@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import importlib
from typing import Any, Dict

from .datatypes import SourceAdapter


def instantiate_class_type(fully_qualified_name):
    module_name, class_name = fully_qualified_name.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# returns a class implementing the protocol corresponding to the ApiSurface
def instantiate_adapter(adapter: SourceAdapter, adapter_config: Dict[str, Any]):
    module = importlib.import_module(adapter.module)

    config_type = instantiate_class_type(adapter.config_class)
    config = config_type(**adapter_config)
    return asyncio.run(module.get_adapter_impl(config))

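`instantiate_class_type` is just `importlib.import_module` plus `getattr`, so any importable dotted name resolves; `instantiate_adapter` then builds the adapter's typed config from the raw dict and drives the `get_adapter_impl` entry point with `asyncio.run`, which is why the inference implementation's `get_adapter_impl` becomes an async function later in this commit. A small, self-contained usage sketch of the class loader, using a stdlib class so it runs without any adapter installed:

from llama_toolchain.distribution.dynamic import instantiate_class_type

# Resolve a class from its dotted path; pathlib is used here purely for illustration.
PurePath = instantiate_class_type("pathlib.PurePosixPath")
print(PurePath("/tmp") / "config.yaml")  # -> /tmp/config.yaml
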
@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List
from typing import List, Optional

from llama_toolchain.inference.adapters import available_inference_adapters

@@ -63,3 +63,10 @@ def available_distributions() -> List[Distribution]:
            },
        ),
    ]


def resolve_distribution(name: str) -> Optional[Distribution]:
    for dist in available_distributions():
        if dist.name == name:
            return dist
    return None

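`resolve_distribution` is a linear lookup over `available_distributions()` that returns None for unknown names, which is why the configure, install, and start commands above all call it and error out when it returns None. A tiny usage sketch, reusing the "local-source" default name from the configure command:

from llama_toolchain.distribution.registry import resolve_distribution

dist = resolve_distribution("local-source")
if dist is None:
    raise SystemExit("unknown distribution name")
print(dist.name, [surface.value for surface in dist.adapters])
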
202 llama_toolchain/distribution/server.py Normal file

@@ -0,0 +1,202 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import json
import signal
from collections.abc import (
    AsyncGenerator as AsyncGeneratorABC,
    AsyncIterator as AsyncIteratorABC,
)
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional

import fire
import httpx
import yaml
from dotenv import load_dotenv

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from fastapi.routing import APIRoute

from pydantic import BaseModel
from termcolor import cprint

from .datatypes import PassthroughApiAdapter
from .distribution import api_surface_endpoints
from .dynamic import instantiate_adapter

from .registry import resolve_distribution

load_dotenv()


def is_async_iterator_type(typ):
    if hasattr(typ, "__origin__"):
        origin = typ.__origin__
        if isinstance(origin, type):
            return issubclass(
                origin,
                (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC),
            )
        return False
    return isinstance(
        typ, (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC)
    )


def create_sse_event(data: Any) -> str:
    if isinstance(data, BaseModel):
        data = data.json()
    else:
        data = json.dumps(data)

    return f"data: {data}\n\n"


async def passthrough(
    request: Request,
    downstream_url: str,
    downstream_headers: Optional[Dict[str, str]] = None,
):
    client = httpx.AsyncClient()

    headers = dict(request.headers)
    headers.pop("host", None)
    headers.update(downstream_headers or {})

    body = await request.body()

    try:
        response = await client.request(
            method=request.method,
            url=downstream_url,
            headers=headers,
            content=body,
            params=request.query_params,
        )
        return StreamingResponse(
            response.iter_bytes(),
            status_code=response.status_code,
            headers=dict(response.headers),
        )
    finally:
        await client.aclose()


def handle_sigint(*args, **kwargs):
    print("SIGINT or CTRL-C detected. Exiting gracefully", args)
    loop = asyncio.get_event_loop()
    for task in asyncio.all_tasks(loop):
        task.cancel()
    loop.stop()


@asynccontextmanager
async def lifespan(app: FastAPI):
    print("Starting up")
    yield
    print("Shutting down")


def create_dynamic_passthrough(
    downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
):
    async def endpoint(request: Request):
        return await passthrough(request, downstream_url, downstream_headers)

    return endpoint


def create_dynamic_typed_route(func: Any):
    hints = get_type_hints(func)
    request_model = next(iter(hints.values()))
    response_model = hints["return"]

    is_streaming = is_async_iterator_type(response_model)

    if is_streaming:

        async def endpoint(request: request_model):
            async def event_generator():
                async for item in func(request):
                    yield create_sse_event(item)
                    await asyncio.sleep(0.001)

            return StreamingResponse(event_generator(), media_type="text/event-stream")

    else:

        async def endpoint(request: request_model):
            return func(request)

    return endpoint


def main(
    dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False
):
    dist = resolve_distribution(dist_name)
    if dist is None:
        raise ValueError(f"Could not find distribution {dist_name}")

    with open(yaml_config, "r") as fp:
        config = yaml.safe_load(fp)

    app = FastAPI()

    all_endpoints = api_surface_endpoints()

    adapter_configs = config["adapters"]
    for surface, adapter in dist.adapters.items():
        if surface.value not in adapter_configs:
            raise ValueError(
                f"Could not find adapter config for {surface}. Please add it to the config"
            )

        adapter_config = adapter_configs[surface.value]
        endpoints = all_endpoints[surface]
        if isinstance(adapter, PassthroughApiAdapter):
            for endpoint in endpoints:
                url = adapter.base_url.rstrip("/") + endpoint.route
                getattr(app, endpoint.method)(endpoint.route)(
                    create_dynamic_passthrough(url)
                )
        else:
            impl = instantiate_adapter(adapter, adapter_config)
            for endpoint in endpoints:
                if not hasattr(impl, endpoint.name):
                    # ideally this should be a typing violation already
                    raise ValueError(
                        f"Could not find method {endpoint.name} on {impl}!!"
                    )

                impl_method = getattr(impl, endpoint.name)
                getattr(app, endpoint.method)(endpoint.route, response_model=None)(
                    create_dynamic_typed_route(impl_method)
                )

    for route in app.routes:
        if isinstance(route, APIRoute):
            cprint(
                f"Serving {next(iter(route.methods))} {route.path}",
                "white",
                attrs=["bold"],
            )

    signal.signal(signal.SIGINT, handle_sigint)

    import uvicorn

    # FYI this does not do hot-reloads
    listen_host = "::" if not disable_ipv6 else "0.0.0.0"
    print(f"Listening on {listen_host}:{port}")
    uvicorn.run(app, host=listen_host, port=port)


if __name__ == "__main__":
    fire.Fire(main)

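For streaming protocol methods, `create_dynamic_typed_route` wraps the async iterator in a `text/event-stream` response whose lines have the `data: {...}` shape produced by `create_sse_event`. A hedged client sketch for consuming such a route is below; the route path and request payload are placeholders rather than values confirmed by this commit, and the model descriptor should come from `llama model list`.

import json

import httpx

# Illustrative client for a streaming route registered by create_dynamic_typed_route().
# URL and payload are assumptions for the sake of the example.
with httpx.stream(
    "POST",
    "http://localhost:5000/inference/chat_completion",
    json={
        "model": "<descriptor from `llama model list`>",
        "messages": [],
        "stream": True,
    },
    timeout=None,
) as response:
    for line in response.iter_lines():
        if line.startswith("data: "):
            chunk = json.loads(line[len("data: "):])
            print(chunk)
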
@@ -19,7 +19,7 @@ def available_inference_adapters() -> List[Adapter]:
                "zmq",
            ],
            module="llama_toolchain.inference.inference",
            config_class="llama_toolchain.inference.inference.InlineImplConfig",
            config_class="llama_toolchain.inference.inference.MetaReferenceImplConfig",
        ),
        SourceAdapter(
            api_surface=ApiSurface.inference,

@@ -7,9 +7,6 @@
from enum import Enum
from typing import Literal, Optional, Union

from hydra.core.config_store import ConfigStore

from hydra_zen import builds
from llama_models.llama3_1.api.datatypes import CheckpointQuantizationFormat

from pydantic import BaseModel, Field

@@ -19,13 +16,6 @@ from typing_extensions import Annotated
from .datatypes import QuantizationConfig


@json_schema_type
class ImplType(Enum):
    inline = "inline"
    remote = "remote"
    ollama = "ollama"


@json_schema_type
class CheckpointType(Enum):
    pytorch = "pytorch"

@@ -66,8 +56,8 @@ class ModelCheckpointConfig(BaseModel):


@json_schema_type
class InlineImplConfig(BaseModel):
    impl_type: Literal[ImplType.inline.value] = ImplType.inline.value
class MetaReferenceImplConfig(BaseModel):
    model: str
    checkpoint_config: ModelCheckpointConfig
    quantization: Optional[QuantizationConfig] = None
    torch_seed: Optional[int] = None

@@ -75,28 +65,7 @@ class InlineImplConfig(BaseModel):
    max_batch_size: int = 1


@json_schema_type
class RemoteImplConfig(BaseModel):
    impl_type: Literal[ImplType.remote.value] = ImplType.remote.value
    url: str = Field(..., description="The URL of the remote module")


@json_schema_type
class OllamaImplConfig(BaseModel):
    impl_type: Literal[ImplType.ollama.value] = ImplType.ollama.value
    model: str = Field(..., description="The name of the model in ollama catalog")
    url: str = Field(..., description="The URL for the ollama server")


@json_schema_type
class InferenceConfig(BaseModel):
    impl_config: Annotated[
        Union[InlineImplConfig, RemoteImplConfig, OllamaImplConfig],
        Field(discriminator="impl_type"),
    ]


InferenceHydraConfig = builds(InferenceConfig)

cs = ConfigStore.instance()
cs.store(name="inference_config", node=InferenceHydraConfig)

@@ -4,19 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .api.config import ImplType, InferenceConfig
# from .api.config import ImplType, InferenceConfig


async def get_inference_api_instance(config: InferenceConfig):
    if config.impl_config.impl_type == ImplType.inline.value:
        from .inference import InferenceImpl
# async def get_inference_api_instance(config: InferenceConfig):
# if config.impl_config.impl_type == ImplType.inline.value:
# from .inference import InferenceImpl

        return InferenceImpl(config.impl_config)
    elif config.impl_config.impl_type == ImplType.ollama.value:
        from .ollama import OllamaInference
# return InferenceImpl(config.impl_config)
# elif config.impl_config.impl_type == ImplType.ollama.value:
# from .ollama import OllamaInference

        return OllamaInference(config.impl_config)
# return OllamaInference(config.impl_config)

    from .client import InferenceClient
# from .client import InferenceClient

    return InferenceClient(config.impl_config.url)
# return InferenceClient(config.impl_config.url)

@@ -29,7 +29,7 @@ from llama_models.llama3_1.api.model import Transformer
from llama_models.llama3_1.api.tokenizer import Tokenizer
from termcolor import cprint

from .api.config import CheckpointType, InlineImplConfig
from .api.config import CheckpointType, MetaReferenceImplConfig
from .api.datatypes import QuantizationType


@@ -42,7 +42,7 @@ class TokenResult:

class Llama:
    @staticmethod
    def build(config: InlineImplConfig):
    def build(config: MetaReferenceImplConfig):
        """
        Build a Llama instance by initializing and loading a model checkpoint.

@@ -4,11 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import AsyncGenerator
import asyncio

from typing import AsyncIterator, Union

from llama_models.llama3_1.api.datatypes import StopReason
from llama_models.sku_list import resolve_model

from .api.config import InlineImplConfig
from .api.config import MetaReferenceImplConfig
from .api.datatypes import (
    ChatCompletionResponseEvent,
    ChatCompletionResponseEventType,

@@ -19,23 +22,35 @@ from .api.endpoints import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseStreamChunk,
    CompletionRequest,
    Inference,
)
from .model_parallel import LlamaModelParallelGenerator


def get_adapter_impl(config: InlineImplConfig) -> Inference:
async def get_adapter_impl(config: MetaReferenceImplConfig):
    assert isinstance(
        config, InlineImplConfig
        config, MetaReferenceImplConfig
    ), f"Unexpected config type: {type(config)}"
    return InferenceImpl(config)

    impl = MetaReferenceInferenceImpl(config)
    await impl.initialize()
    return impl


class InferenceImpl(Inference):
# there's a single model parallel process running serving the model. for now,
# we don't support multiple concurrent requests to this process.
SEMAPHORE = asyncio.Semaphore(1)

    def __init__(self, config: InlineImplConfig) -> None:

class MetaReferenceInferenceImpl(Inference):

    def __init__(self, config: MetaReferenceImplConfig) -> None:
        self.config = config
        model = resolve_model(config.model)
        if model is None:
            raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
        self.model = model
        # verify that the checkpoint actually is for this model lol

    async def initialize(self) -> None:
        self.generator = LlamaModelParallelGenerator(self.config)

@@ -44,10 +59,27 @@ class InferenceImpl(Inference):
    async def shutdown(self) -> None:
        self.generator.stop()

    async def completion(self, request: CompletionRequest) -> AsyncGenerator:
        raise NotImplementedError()
    # hm, when stream=False, we should not be doing SSE :/ which is what the
    # top-level server is going to do. make the typing more specific here
    async def chat_completion(
        self, request: ChatCompletionRequest
    ) -> AsyncIterator[
        Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
    ]:
        model = resolve_model(request.model)
        if model is None:
            raise RuntimeError(
                f"Unknown model: {request.model}, Run `llama model list`"
            )
        elif model.descriptor() != self.model.descriptor():
            raise RuntimeError(
                f"Model mismatch: {request.model} != {self.model.descriptor()}"
            )

    async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
        if SEMAPHORE.locked():
            raise RuntimeError("Only one concurrent request is supported")

        async with SEMAPHORE:
            if request.stream:
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(

@@ -125,7 +157,9 @@ class InferenceImpl(Inference):

        # TODO(ashwin): parse tool calls separately here and report errors?
        # if someone breaks the iteration before coming here we are toast
        message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
        message = self.generator.formatter.decode_assistant_message(
            tokens, stop_reason
        )
        if request.stream:
            parsed_tool_calls = len(message.tool_calls) > 0
            if ipython and not parsed_tool_calls:

@@ -13,7 +13,7 @@ from llama_models.llama3_1.api.chat_format import ChatFormat
from llama_models.llama3_1.api.datatypes import Message
from llama_models.llama3_1.api.tokenizer import Tokenizer

from .api.config import InlineImplConfig
from .api.config import MetaReferenceImplConfig
from .generation import Llama
from .parallel_utils import ModelParallelProcessGroup

@@ -42,7 +42,7 @@ class ModelRunner:
        )


def init_model_cb(config: InlineImplConfig):
def init_model_cb(config: MetaReferenceImplConfig):
    llama = Llama.build(config)
    return ModelRunner(llama)

@@ -58,7 +58,7 @@ class LlamaModelParallelGenerator:
    clear at the callsite why we need to use a context manager.
    """

    def __init__(self, config: InlineImplConfig):
    def __init__(self, config: MetaReferenceImplConfig):
        self.config = config

        # this is a hack because Agent's loop uses this to tokenize and check if input is too long

@@ -17,7 +17,7 @@ from llama_models.llama3_1.api.model import Transformer, TransformerBlock

from llama_toolchain.inference.api.config import (
    CheckpointQuantizationFormat,
    InlineImplConfig,
    MetaReferenceImplConfig,
)
from llama_toolchain.inference.api.datatypes import QuantizationType

@@ -46,7 +46,7 @@ def swiglu_wrapper(

def convert_to_quantized_model(
    model: Transformer,
    config: InlineImplConfig,
    config: MetaReferenceImplConfig,
    fp8_activation_scale_ub: Optional[float] = 1200.0,
) -> Transformer:
    if config.quantization.type == QuantizationType.bf16.value:

@@ -5,7 +5,7 @@
# the root directory of this source tree.

from .datatypes import * # noqa: F403
from typing import Protocol
from typing import List, Protocol

from llama_models.llama3_1.api.datatypes import Message

@@ -19,7 +19,7 @@ class RunShieldRequest(BaseModel):
    messages: List[Message]


class SafetyCheck(Protocol):
class Safety(Protocol):

    @webmethod(route="/safety/run_shield")
    async def run_shield(