Adapter -> Provider

This commit is contained in:
Ashwin Bharambe 2024-08-05 13:26:29 -07:00
parent db3e6dda07
commit 65a9e40174
15 changed files with 119 additions and 110 deletions

View file

@ -63,7 +63,7 @@ def configure_llama_distribution(dist: "Distribution", conda_env: str):
from llama_toolchain.common.exec import run_command
from llama_toolchain.common.prompt_for_config import prompt_for_config
from llama_toolchain.common.serialize import EnumEncoder
from llama_toolchain.distribution.datatypes import PassthroughApiAdapter
from llama_toolchain.distribution.datatypes import RemoteProviderSpec
from llama_toolchain.distribution.dynamic import instantiate_class_type
python_exe = run_command(shlex.split("which python"))
@ -84,28 +84,28 @@ def configure_llama_distribution(dist: "Distribution", conda_env: str):
with open(config_path, "r") as fp:
existing_config = yaml.safe_load(fp)
adapter_configs = {}
for api, adapter in dist.adapters.items():
if isinstance(adapter, PassthroughApiAdapter):
adapter_configs[api.value] = adapter.dict()
provider_configs = {}
for api, provider_spec in dist.provider_specs.items():
if isinstance(provider_spec, RemoteProviderSpec):
provider_configs[api.value] = provider_spec.dict()
else:
cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"])
config_type = instantiate_class_type(adapter.config_class)
config_type = instantiate_class_type(provider_spec.config_class)
config = prompt_for_config(
config_type,
(
config_type(**existing_config["adapters"][api.value])
if existing_config and api.value in existing_config["adapters"]
config_type(**existing_config["providers"][api.value])
if existing_config and api.value in existing_config["providers"]
else None
),
)
adapter_configs[api.value] = {
"adapter_id": adapter.adapter_id,
provider_configs[api.value] = {
"provider_id": provider_spec.provider_id,
**config.dict(),
}
dist_config = {
"adapters": adapter_configs,
"providers": provider_configs,
"conda_env": conda_env,
}