From 211abd27d528349436547f5b6ead9256023f7f51 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Sun, 22 Sep 2024 21:29:47 -0700 Subject: [PATCH] fix configure for simple case --- llama_stack/distribution/configure.py | 79 ++++++++++-------------- llama_stack/distribution/distribution.py | 4 +- 2 files changed, 34 insertions(+), 49 deletions(-) diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py index e8807b8e1..a64f91770 100644 --- a/llama_stack/distribution/configure.py +++ b/llama_stack/distribution/configure.py @@ -9,7 +9,11 @@ from typing import Any from pydantic import BaseModel from llama_stack.distribution.datatypes import * # noqa: F403 -from llama_stack.distribution.distribution import api_providers, stack_apis +from llama_stack.distribution.distribution import ( + api_providers, + builtin_automatically_routed_apis, + stack_apis, +) from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.prompt_for_config import prompt_for_config @@ -29,7 +33,14 @@ def make_routing_entry_type(config_class: Any): def configure_api_providers( config: StackRunConfig, spec: DistributionSpec ) -> StackRunConfig: + cprint(f"configure_api_providers {spec}", "red") apis = config.apis_to_serve or list(spec.providers.keys()) + + # append the bulitin automatically routed APIs + for inf in builtin_automatically_routed_apis(): + if inf.router_api.value in apis: + apis.append(inf.routing_table_api.value) + config.apis_to_serve = [a for a in apis if a != "telemetry"] apis = [v.value for v in stack_apis()] @@ -43,52 +54,26 @@ def configure_api_providers( api = Api(api_str) provider_or_providers = spec.providers[api_str] - if isinstance(provider_or_providers, list) and len(provider_or_providers) > 1: - print( - "You have specified multiple providers for this API. We will configure a routing table now. For each provider, provide a routing key followed by provider configuration.\n" - ) - - routing_entries = [] - for p in provider_or_providers: - print(f"Configuring provider `{p}`...") - provider_spec = all_providers[api][p] - config_type = instantiate_class_type(provider_spec.config_class) - - # TODO: we need to validate the routing keys, and - # perhaps it is better if we break this out into asking - # for a routing key separately from the associated config - wrapper_type = make_routing_entry_type(config_type) - rt_entry = prompt_for_config(wrapper_type, None) - - routing_entries.append( - ProviderRoutingEntry( - provider_id=p, - routing_key=rt_entry.routing_key, - config=rt_entry.config.dict(), - ) - ) - config.api_providers[api_str] = routing_entries - else: - p = ( - provider_or_providers[0] - if isinstance(provider_or_providers, list) - else provider_or_providers - ) - print(f"Configuring provider `{p}`...") - provider_spec = all_providers[api][p] - config_type = instantiate_class_type(provider_spec.config_class) - try: - provider_config = config.api_providers.get(api_str) - if provider_config: - existing = config_type(**provider_config.config) - else: - existing = None - except Exception: + p = ( + provider_or_providers[0] + if isinstance(provider_or_providers, list) + else provider_or_providers + ) + print(f"Configuring provider `{p}`...") + provider_spec = all_providers[api][p] + config_type = instantiate_class_type(provider_spec.config_class) + try: + provider_config = config.api_providers.get(api_str) + if provider_config: + existing = config_type(**provider_config.config) + else: existing = None - cfg = prompt_for_config(config_type, existing) - config.api_providers[api_str] = GenericProviderConfig( - provider_id=p, - config=cfg.dict(), - ) + except Exception: + existing = None + cfg = prompt_for_config(config_type, existing) + config.api_providers[api_str] = GenericProviderConfig( + provider_id=p, + config=cfg.dict(), + ) return config diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py index 6b72afed5..b641b6582 100644 --- a/llama_stack/distribution/distribution.py +++ b/llama_stack/distribution/distribution.py @@ -8,8 +8,6 @@ import importlib import inspect from typing import Dict, List -from pydantic import BaseModel - from llama_stack.apis.agents import Agents from llama_stack.apis.inference import Inference from llama_stack.apis.memory import Memory @@ -19,6 +17,8 @@ from llama_stack.apis.safety import Safety from llama_stack.apis.shields import Shields from llama_stack.apis.telemetry import Telemetry +from pydantic import BaseModel + from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec # These are the dependencies needed by the distribution server.