implement full-passthrough in the server

This commit is contained in:
Ashwin Bharambe 2024-08-03 14:15:20 -07:00
parent 38fd76f85c
commit 9dafa6ad94
8 changed files with 69 additions and 71 deletions

View file

@ -11,11 +11,12 @@ import json
import shlex
from pathlib import Path
from typing import Annotated, get_args, get_origin, Literal, Optional, Union
from typing import get_args, get_origin, Literal, Optional, Union
import yaml
from pydantic import BaseModel
from termcolor import cprint
from typing_extensions import Annotated
from llama_toolchain.cli.subcommand import Subcommand
from llama_toolchain.utils import DISTRIBS_BASE_DIR, EnumEncoder
@ -37,6 +38,7 @@ class DistributionConfigure(Subcommand):
def _add_arguments(self):
from llama_toolchain.distribution.registry import available_distributions
self.parser.add_argument(
"--name",
type=str,
@ -64,6 +66,7 @@ class DistributionConfigure(Subcommand):
def configure_llama_distribution(dist: "Distribution", conda_env: str):
from llama_toolchain.distribution.datatypes import PassthroughApiAdapter
from .utils import run_command
python_exe = run_command(shlex.split("which python"))
@ -101,7 +104,10 @@ def configure_llama_distribution(dist: "Distribution", conda_env: str):
else None
),
)
adapter_configs[api_surface.value] = config.dict()
adapter_configs[api_surface.value] = {
adapter_id: adapter.adapter_id,
**config.dict(),
}
dist_config = {
"adapters": adapter_configs,
@ -138,6 +144,8 @@ def get_non_none_type(field_type):
return next(arg for arg in get_args(field_type) if arg is not type(None))
# TODO: maybe support List values (for simple types, it should be comma-separated and for complex ones,
# it should prompt iteratively if the user wants to add more values)
def prompt_for_config(
config_type: type[BaseModel], existing_config: Optional[BaseModel] = None
) -> BaseModel: