implement full-passthrough in the server

This commit is contained in:
Ashwin Bharambe 2024-08-03 14:15:20 -07:00
parent 38fd76f85c
commit 9dafa6ad94
8 changed files with 69 additions and 71 deletions

View file

@ -11,11 +11,12 @@ import json
import shlex
from pathlib import Path
from typing import Annotated, get_args, get_origin, Literal, Optional, Union
from typing import get_args, get_origin, Literal, Optional, Union
import yaml
from pydantic import BaseModel
from termcolor import cprint
from typing_extensions import Annotated
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.utils import DISTRIBS_BASE_DIR, EnumEncoder
@ -37,6 +38,7 @@ class DistributionConfigure(Subcommand):
def _add_arguments(self):
from llama_toolchain.distribution.registry import available_distributions
self.parser.add_argument(
"--name",
type=str,
@ -64,6 +66,7 @@ class DistributionConfigure(Subcommand):
def configure_llama_distribution(dist: "Distribution", conda_env: str):
from llama_toolchain.distribution.datatypes import PassthroughApiAdapter
from .utils import run_command
python_exe = run_command(shlex.split("which python"))
@ -101,7 +104,10 @@ def configure_llama_distribution(dist: "Distribution", conda_env: str):
else None
),
)
adapter_configs[api_surface.value] = config.dict()
adapter_configs[api_surface.value] = {
adapter_id: adapter.adapter_id,
**config.dict(),
}
dist_config = {
"adapters": adapter_configs,
@ -138,6 +144,8 @@ def get_non_none_type(field_type):
return next(arg for arg in get_args(field_type) if arg is not type(None))
# TODO: maybe support List values (for simple types, it should be comma-separated and for complex ones,
# it should prompt iteratively if the user wants to add more values)
def prompt_for_config(
config_type: type[BaseModel], existing_config: Optional[BaseModel] = None
) -> BaseModel:

View file

@ -30,6 +30,7 @@ class DistributionInstall(Subcommand):
def _add_arguments(self):
from llama_toolchain.distribution.registry import available_distributions
self.parser.add_argument(
"--name",
type=str,
@ -63,7 +64,7 @@ class DistributionInstall(Subcommand):
os.makedirs(DISTRIBS_BASE_DIR / dist.name, exist_ok=True)
deps = distribution_dependencies(dist)
run_command([script, args.conda_env, " ".join(deps)])
run_with_pty([script, args.conda_env, " ".join(deps)])
with open(DISTRIBS_BASE_DIR / dist.name / "conda.env", "w") as f:
f.write(f"{args.conda_env}\n")