mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-06 20:44:58 +00:00
Adapter -> Provider
This commit is contained in:
parent
db3e6dda07
commit
65a9e40174
15 changed files with 119 additions and 110 deletions
|
@ -63,7 +63,7 @@ def configure_llama_distribution(dist: "Distribution", conda_env: str):
|
|||
from llama_toolchain.common.exec import run_command
|
||||
from llama_toolchain.common.prompt_for_config import prompt_for_config
|
||||
from llama_toolchain.common.serialize import EnumEncoder
|
||||
from llama_toolchain.distribution.datatypes import PassthroughApiAdapter
|
||||
from llama_toolchain.distribution.datatypes import RemoteProviderSpec
|
||||
from llama_toolchain.distribution.dynamic import instantiate_class_type
|
||||
|
||||
python_exe = run_command(shlex.split("which python"))
|
||||
|
@ -84,28 +84,28 @@ def configure_llama_distribution(dist: "Distribution", conda_env: str):
|
|||
with open(config_path, "r") as fp:
|
||||
existing_config = yaml.safe_load(fp)
|
||||
|
||||
adapter_configs = {}
|
||||
for api, adapter in dist.adapters.items():
|
||||
if isinstance(adapter, PassthroughApiAdapter):
|
||||
adapter_configs[api.value] = adapter.dict()
|
||||
provider_configs = {}
|
||||
for api, provider_spec in dist.provider_specs.items():
|
||||
if isinstance(provider_spec, RemoteProviderSpec):
|
||||
provider_configs[api.value] = provider_spec.dict()
|
||||
else:
|
||||
cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"])
|
||||
config_type = instantiate_class_type(adapter.config_class)
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
config = prompt_for_config(
|
||||
config_type,
|
||||
(
|
||||
config_type(**existing_config["adapters"][api.value])
|
||||
if existing_config and api.value in existing_config["adapters"]
|
||||
config_type(**existing_config["providers"][api.value])
|
||||
if existing_config and api.value in existing_config["providers"]
|
||||
else None
|
||||
),
|
||||
)
|
||||
adapter_configs[api.value] = {
|
||||
"adapter_id": adapter.adapter_id,
|
||||
provider_configs[api.value] = {
|
||||
"provider_id": provider_spec.provider_id,
|
||||
**config.dict(),
|
||||
}
|
||||
|
||||
dist_config = {
|
||||
"adapters": adapter_configs,
|
||||
"providers": provider_configs,
|
||||
"conda_env": conda_env,
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ class DistributionCreate(Subcommand):
|
|||
required=True,
|
||||
)
|
||||
# for each Api the user wants to support, we should
|
||||
# get the list of available adapters, ask which one the user
|
||||
# get the list of available providers, ask which one the user
|
||||
# wants to pick and then ask for their configuration.
|
||||
|
||||
def _run_distribution_create_cmd(self, args: argparse.Namespace) -> None:
|
||||
|
|
|
@ -33,17 +33,17 @@ class DistributionList(Subcommand):
|
|||
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
|
||||
headers = [
|
||||
"Name",
|
||||
"Adapters",
|
||||
"ProviderSpecs",
|
||||
"Description",
|
||||
]
|
||||
|
||||
rows = []
|
||||
for dist in available_distributions():
|
||||
adapters = {k.value: v.adapter_id for k, v in dist.adapters.items()}
|
||||
providers = {k.value: v.provider_id for k, v in dist.provider_specs.items()}
|
||||
rows.append(
|
||||
[
|
||||
dist.name,
|
||||
json.dumps(adapters, indent=2),
|
||||
json.dumps(providers, indent=2),
|
||||
dist.description,
|
||||
]
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue