mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
Adapter -> Provider
This commit is contained in:
parent
db3e6dda07
commit
65a9e40174
15 changed files with 119 additions and 110 deletions
|
@ -63,7 +63,7 @@ def configure_llama_distribution(dist: "Distribution", conda_env: str):
|
|||
from llama_toolchain.common.exec import run_command
|
||||
from llama_toolchain.common.prompt_for_config import prompt_for_config
|
||||
from llama_toolchain.common.serialize import EnumEncoder
|
||||
from llama_toolchain.distribution.datatypes import PassthroughApiAdapter
|
||||
from llama_toolchain.distribution.datatypes import RemoteProviderSpec
|
||||
from llama_toolchain.distribution.dynamic import instantiate_class_type
|
||||
|
||||
python_exe = run_command(shlex.split("which python"))
|
||||
|
@ -84,28 +84,28 @@ def configure_llama_distribution(dist: "Distribution", conda_env: str):
|
|||
with open(config_path, "r") as fp:
|
||||
existing_config = yaml.safe_load(fp)
|
||||
|
||||
adapter_configs = {}
|
||||
for api, adapter in dist.adapters.items():
|
||||
if isinstance(adapter, PassthroughApiAdapter):
|
||||
adapter_configs[api.value] = adapter.dict()
|
||||
provider_configs = {}
|
||||
for api, provider_spec in dist.provider_specs.items():
|
||||
if isinstance(provider_spec, RemoteProviderSpec):
|
||||
provider_configs[api.value] = provider_spec.dict()
|
||||
else:
|
||||
cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"])
|
||||
config_type = instantiate_class_type(adapter.config_class)
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
config = prompt_for_config(
|
||||
config_type,
|
||||
(
|
||||
config_type(**existing_config["adapters"][api.value])
|
||||
if existing_config and api.value in existing_config["adapters"]
|
||||
config_type(**existing_config["providers"][api.value])
|
||||
if existing_config and api.value in existing_config["providers"]
|
||||
else None
|
||||
),
|
||||
)
|
||||
adapter_configs[api.value] = {
|
||||
"adapter_id": adapter.adapter_id,
|
||||
provider_configs[api.value] = {
|
||||
"provider_id": provider_spec.provider_id,
|
||||
**config.dict(),
|
||||
}
|
||||
|
||||
dist_config = {
|
||||
"adapters": adapter_configs,
|
||||
"providers": provider_configs,
|
||||
"conda_env": conda_env,
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue