fix configure for simple case

This commit is contained in:
Xi Yan 2024-09-22 21:29:47 -07:00
parent 00ef672509
commit 211abd27d5
2 changed files with 34 additions and 49 deletions

View file

@ -9,7 +9,11 @@ from typing import Any
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import api_providers, stack_apis from llama_stack.distribution.distribution import (
api_providers,
builtin_automatically_routed_apis,
stack_apis,
)
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
@ -29,7 +33,14 @@ def make_routing_entry_type(config_class: Any):
def configure_api_providers( def configure_api_providers(
config: StackRunConfig, spec: DistributionSpec config: StackRunConfig, spec: DistributionSpec
) -> StackRunConfig: ) -> StackRunConfig:
cprint(f"configure_api_providers {spec}", "red")
apis = config.apis_to_serve or list(spec.providers.keys()) apis = config.apis_to_serve or list(spec.providers.keys())
# append the builtin automatically routed APIs
for inf in builtin_automatically_routed_apis():
if inf.router_api.value in apis:
apis.append(inf.routing_table_api.value)
config.apis_to_serve = [a for a in apis if a != "telemetry"] config.apis_to_serve = [a for a in apis if a != "telemetry"]
apis = [v.value for v in stack_apis()] apis = [v.value for v in stack_apis()]
@ -43,52 +54,26 @@ def configure_api_providers(
api = Api(api_str) api = Api(api_str)
provider_or_providers = spec.providers[api_str] provider_or_providers = spec.providers[api_str]
if isinstance(provider_or_providers, list) and len(provider_or_providers) > 1: p = (
print( provider_or_providers[0]
"You have specified multiple providers for this API. We will configure a routing table now. For each provider, provide a routing key followed by provider configuration.\n" if isinstance(provider_or_providers, list)
) else provider_or_providers
)
routing_entries = [] print(f"Configuring provider `{p}`...")
for p in provider_or_providers: provider_spec = all_providers[api][p]
print(f"Configuring provider `{p}`...") config_type = instantiate_class_type(provider_spec.config_class)
provider_spec = all_providers[api][p] try:
config_type = instantiate_class_type(provider_spec.config_class) provider_config = config.api_providers.get(api_str)
if provider_config:
# TODO: we need to validate the routing keys, and existing = config_type(**provider_config.config)
# perhaps it is better if we break this out into asking else:
# for a routing key separately from the associated config
wrapper_type = make_routing_entry_type(config_type)
rt_entry = prompt_for_config(wrapper_type, None)
routing_entries.append(
ProviderRoutingEntry(
provider_id=p,
routing_key=rt_entry.routing_key,
config=rt_entry.config.dict(),
)
)
config.api_providers[api_str] = routing_entries
else:
p = (
provider_or_providers[0]
if isinstance(provider_or_providers, list)
else provider_or_providers
)
print(f"Configuring provider `{p}`...")
provider_spec = all_providers[api][p]
config_type = instantiate_class_type(provider_spec.config_class)
try:
provider_config = config.api_providers.get(api_str)
if provider_config:
existing = config_type(**provider_config.config)
else:
existing = None
except Exception:
existing = None existing = None
cfg = prompt_for_config(config_type, existing) except Exception:
config.api_providers[api_str] = GenericProviderConfig( existing = None
provider_id=p, cfg = prompt_for_config(config_type, existing)
config=cfg.dict(), config.api_providers[api_str] = GenericProviderConfig(
) provider_id=p,
config=cfg.dict(),
)
return config return config

View file

@ -8,8 +8,6 @@ import importlib
import inspect import inspect
from typing import Dict, List from typing import Dict, List
from pydantic import BaseModel
from llama_stack.apis.agents import Agents from llama_stack.apis.agents import Agents
from llama_stack.apis.inference import Inference from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory from llama_stack.apis.memory import Memory
@ -19,6 +17,8 @@ from llama_stack.apis.safety import Safety
from llama_stack.apis.shields import Shields from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry from llama_stack.apis.telemetry import Telemetry
from pydantic import BaseModel
from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec
# These are the dependencies needed by the distribution server. # These are the dependencies needed by the distribution server.