From 95a7f225cfdd9f35dc4017e1e96dc336e70e21db Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 17 Sep 2024 14:12:22 -0700 Subject: [PATCH] stack configure fixes --- llama_stack/distribution/configure.py | 28 ++++++++++++++----- .../impls/meta_reference/agents/agents.py | 4 ++- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py index 2de8d47d1..3c77423d7 100644 --- a/llama_stack/distribution/configure.py +++ b/llama_stack/distribution/configure.py @@ -51,27 +51,31 @@ def configure_api_providers( apis = [v.value for v in stack_apis()] all_providers = api_providers() - apis_to_serve = req_apis.apis_to_serve + ["telemetry"] - for api_str in apis_to_serve: + for api_str in spec.providers.keys(): if api_str not in apis: raise ValueError(f"Unknown API `{api_str}`") cprint(f"Configuring API `{api_str}`...\n", "white", attrs=["bold"]) api = Api(api_str) - if isinstance(spec.providers[api_str], list): + + provider_or_providers = spec.providers[api_str] + if isinstance(provider_or_providers, list) and len(provider_or_providers) > 1: print( "You have specified multiple providers for this API. We will configure a routing table now. For each provider, provide a routing key followed by provider configuration.\n" ) + routing_entries = [] - for p in spec.providers[api_str]: + for p in provider_or_providers: print(f"Configuring provider `{p}`...") provider_spec = all_providers[api][p] config_type = instantiate_class_type(provider_spec.config_class) + # TODO: we need to validate the routing keys, and + # perhaps it is better if we break this out into asking + # for a routing key separately from the associated config wrapper_type = make_routing_entry_type(config_type) rt_entry = prompt_for_config(wrapper_type, None) - # TODO: we need to validate the routing keys routing_entries.append( ProviderRoutingEntry( provider_id=p, @@ -81,11 +85,21 @@ def configure_api_providers( ) config.provider_map[api_str] = routing_entries else: - p = spec.providers[api_str] + p = ( + provider_or_providers[0] + if isinstance(provider_or_providers, list) + else provider_or_providers + ) print(f"Configuring provider `{p}`...") provider_spec = all_providers[api][p] config_type = instantiate_class_type(provider_spec.config_class) - cfg = prompt_for_config(config_type, None) + try: + provider_config = config.provider_map.get(api_str) + if provider_config: + existing = config_type(**provider_config.config) + except Exception: + existing = None + cfg = prompt_for_config(config_type, existing) config.provider_map[api_str] = GenericProviderConfig( provider_id=p, config=cfg.dict(), diff --git a/llama_stack/providers/impls/meta_reference/agents/agents.py b/llama_stack/providers/impls/meta_reference/agents/agents.py index 25517ba6c..262afc611 100644 --- a/llama_stack/providers/impls/meta_reference/agents/agents.py +++ b/llama_stack/providers/impls/meta_reference/agents/agents.py @@ -69,7 +69,9 @@ class MetaReferenceAgentsImpl(Agents): elif tool_defn.engine == SearchEngineType.bing: key = self.config.bing_search_api_key if not key: - raise ValueError("API key not defined in config") + raise ValueError( + "Search (Brave or Bing) API key not defined in config" + ) tool = SearchTool(tool_defn.engine, key) elif isinstance(tool_defn, CodeInterpreterToolDefinition): tool = CodeInterpreterTool()