From abfa1379d1f10126ba8cafcabeb88d0dc21647d6 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Sun, 22 Sep 2024 22:01:48 -0700 Subject: [PATCH] configure w/ routing --- llama_stack/distribution/configure.py | 42 +++++++++++++++------- llama_stack/examples/router-local-run.yaml | 2 +- llama_stack/examples/simple-local-run.yaml | 2 +- 3 files changed, 32 insertions(+), 14 deletions(-) diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py index a64f91770..4c372056b 100644 --- a/llama_stack/distribution/configure.py +++ b/llama_stack/distribution/configure.py @@ -17,6 +17,7 @@ from llama_stack.distribution.distribution import ( from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.prompt_for_config import prompt_for_config +from prompt_toolkit import prompt from termcolor import cprint @@ -28,6 +29,16 @@ def make_routing_entry_type(config_class: Any): return BaseModelWithConfig +def get_builtin_apis(provider_backed_apis: List[str]) -> List[str]: + """Get corresponding builtin APIs given provider backed APIs""" + res = [] + for inf in builtin_automatically_routed_apis(): + if inf.router_api.value in provider_backed_apis: + res.append(inf.routing_table_api.value) + + return res + + # TODO: make sure we can deal with existing configuration values correctly # instead of just overwriting them def configure_api_providers( @@ -35,31 +46,30 @@ def configure_api_providers( ) -> StackRunConfig: cprint(f"configure_api_providers {spec}", "red") apis = config.apis_to_serve or list(spec.providers.keys()) + # append the builtin routing APIs + apis += get_builtin_apis(apis) - # append the bulitin automatically routed APIs - for inf in builtin_automatically_routed_apis(): - if inf.router_api.value in apis: - apis.append(inf.routing_table_api.value) + router_api2builtin_api = { + inf.router_api.value: inf.routing_table_api.value + for inf in builtin_automatically_routed_apis() + } 
config.apis_to_serve = [a for a in apis if a != "telemetry"] apis = [v.value for v in stack_apis()] all_providers = api_providers() + # configure the simple case of non-routing providers in api_providers for api_str in spec.providers.keys(): if api_str not in apis: raise ValueError(f"Unknown API `{api_str}`") - cprint(f"Configuring API `{api_str}`...\n", "white", attrs=["bold"]) + cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"]) api = Api(api_str) - provider_or_providers = spec.providers[api_str] - p = ( - provider_or_providers[0] - if isinstance(provider_or_providers, list) - else provider_or_providers - ) - print(f"Configuring provider `{p}`...") + p = spec.providers[api_str] + cprint(f"=== Configuring provider `{p}` for API {api_str}...", "green") + provider_spec = all_providers[api][p] config_type = instantiate_class_type(provider_spec.config_class) try: @@ -76,4 +86,12 @@ def configure_api_providers( config=cfg.dict(), ) + if api_str in router_api2builtin_api: + # a routing api, we need to assign it a routing_key and put it in the routing_table + routing_key = prompt( + "> Enter routing key for the {} provider: ".format(api_str), + ) + + print("") + return config diff --git a/llama_stack/examples/router-local-run.yaml b/llama_stack/examples/router-local-run.yaml index 08cf9a804..9dec673ea 100644 --- a/llama_stack/examples/router-local-run.yaml +++ b/llama_stack/examples/router-local-run.yaml @@ -9,7 +9,7 @@ apis_to_serve: - agents - safety - models -provider_map: +api_providers: telemetry: provider_id: meta-reference config: {} diff --git a/llama_stack/examples/simple-local-run.yaml b/llama_stack/examples/simple-local-run.yaml index f517116aa..28a4f3825 100644 --- a/llama_stack/examples/simple-local-run.yaml +++ b/llama_stack/examples/simple-local-run.yaml @@ -9,7 +9,7 @@ apis_to_serve: - memory - models - telemetry -provider_map: +api_providers: inference: provider_id: meta-reference config: