From 5f9a7dcdccbf59d219526d79639e133265d0b97b Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Sat, 21 Sep 2024 12:42:08 -0700
Subject: [PATCH] Revert "backward compatibility"

This reverts commit 6a95edc80640cafab9ab8e6a9409768b133275de.
---
Reviewer notes (text between the "---" marker and the diffstat is not
part of the commit message):
  * The line structure and indentation of this patch were restored from a
    whitespace-collapsed copy. NOTE(review): indentation is inferred from
    Python/YAML syntax; hunk line counts and the diffstat were re-checked,
    but confirm against the target tree before applying.
  * One added line deviates from a byte-exact revert: the restored
    resolve_impls_with_routing() formatted its "Unknown provider" error
    with the undefined name `provider_id`, which would raise NameError
    instead of the intended ValueError. It now uses
    `provider_map_entry.provider_id`.

 llama_stack/distribution/server/server.py | 140 +++++++++-------------
 llama_stack/examples/simple-run.yaml      |   1 -
 2 files changed, 59 insertions(+), 82 deletions(-)

diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index 64c1111e7..60855070e 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -287,60 +287,60 @@ def snake_to_camel(snake_str):
     return "".join(word.capitalize() for word in snake_str.split("_"))
 
 
-# async def resolve_impls_with_routing(
-#     stack_run_config: StackRunConfig,
-# ) -> Dict[Api, Any]:
+async def resolve_impls_with_routing(
+    stack_run_config: StackRunConfig,
+) -> Dict[Api, Any]:
 
-#     all_providers = api_providers()
-#     specs = {}
+    all_providers = api_providers()
+    specs = {}
 
-#     for api_str in stack_run_config.apis_to_serve:
-#         api = Api(api_str)
-#         providers = all_providers[api]
+    for api_str in stack_run_config.apis_to_serve:
+        api = Api(api_str)
+        providers = all_providers[api]
 
-#         # check for regular providers without routing
-#         if api_str in stack_run_config.provider_map:
-#             provider_map_entry = stack_run_config.provider_map[api_str]
-#             if provider_map_entry.provider_id not in providers:
-#                 raise ValueError(
-#                     f"Unknown provider `{provider_id}` is not available for API `{api}`"
-#                 )
-#             specs[api] = providers[provider_map_entry.provider_id]
+        # check for regular providers without routing
+        if api_str in stack_run_config.provider_map:
+            provider_map_entry = stack_run_config.provider_map[api_str]
+            if provider_map_entry.provider_id not in providers:
+                raise ValueError(
+                    f"Unknown provider `{provider_map_entry.provider_id}` is not available for API `{api}`"
+                )
+            specs[api] = providers[provider_map_entry.provider_id]
 
-#         # check for routing table, we need to pass routing table to the router implementation
-#         if api_str in stack_run_config.provider_routing_table:
-#             router_entry = stack_run_config.provider_routing_table[api_str]
-#             inner_specs = []
-#             for rt_entry in router_entry:
-#                 if rt_entry.provider_id not in providers:
-#                     raise ValueError(
-#                         f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
-#                     )
-#                 inner_specs.append(providers[rt_entry.provider_id])
+        # check for routing table, we need to pass routing table to the router implementation
+        if api_str in stack_run_config.provider_routing_table:
+            router_entry = stack_run_config.provider_routing_table[api_str]
+            inner_specs = []
+            for rt_entry in router_entry:
+                if rt_entry.provider_id not in providers:
+                    raise ValueError(
+                        f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
+                    )
+                inner_specs.append(providers[rt_entry.provider_id])
 
-#             specs[api] = RouterProviderSpec(
-#                 api=api,
-#                 module=f"llama_stack.distribution.routers.{api.value.lower()}",
-#                 api_dependencies=[],
-#                 inner_specs=inner_specs,
-#             )
+            specs[api] = RouterProviderSpec(
+                api=api,
+                module=f"llama_stack.distribution.routers.{api.value.lower()}",
+                api_dependencies=[],
+                inner_specs=inner_specs,
+            )
 
-#     sorted_specs = topological_sort(specs.values())
+    sorted_specs = topological_sort(specs.values())
 
-#     impls = {}
-#     for spec in sorted_specs:
-#         api = spec.api
-#         deps = {api: impls[api] for api in spec.api_dependencies}
-#         if api.value in stack_run_config.provider_map:
-#             provider_config = stack_run_config.provider_map[api.value]
-#         elif api.value in stack_run_config.provider_routing_table:
-#             provider_config = stack_run_config.provider_routing_table[api.value]
-#         else:
-#             raise ValueError(f"Cannot find provider_config for Api {api.value}")
-#         impl = await instantiate_provider(spec, deps, provider_config)
-#         impls[api] = impl
+    impls = {}
+    for spec in sorted_specs:
+        api = spec.api
+        deps = {api: impls[api] for api in spec.api_dependencies}
+        if api.value in stack_run_config.provider_map:
+            provider_config = stack_run_config.provider_map[api.value]
+        elif api.value in stack_run_config.provider_routing_table:
+            provider_config = stack_run_config.provider_routing_table[api.value]
+        else:
+            raise ValueError(f"Cannot find provider_config for Api {api.value}")
+        impl = await instantiate_provider(spec, deps, provider_config)
+        impls[api] = impl
 
-#     return impls, specs
+    return impls, specs
 
 
 async def resolve_impls(
@@ -352,10 +352,12 @@ async def resolve_impls(
     - for each API, produces either a (local, passthrough or router) implementation
     """
     all_providers = api_providers()
+
     specs = {}
     for api_str, item in provider_map.items():
         api = Api(api_str)
         providers = all_providers[api]
+
         if isinstance(item, GenericProviderConfig):
             if item.provider_id not in providers:
                 raise ValueError(
@@ -363,39 +365,31 @@ async def resolve_impls(
                 )
             specs[api] = providers[item.provider_id]
         else:
-            assert isinstance(item, list)
-            inner_specs = []
-            for rt_entry in item:
-                if rt_entry.provider_id not in providers:
-                    raise ValueError(
-                        f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
-                    )
-                inner_specs.append(providers[rt_entry.provider_id])
-
-            specs[api] = RouterProviderSpec(
-                api=api,
-                module=f"llama_stack.providers.routers.{api.value.lower()}",
-                api_dependencies=[],
-                inner_specs=inner_specs,
+            raise ValueError(
+                f"Please define routing table in provider_routing_table of run config"
             )
 
     sorted_specs = topological_sort(specs.values())
+
     impls = {}
     for spec in sorted_specs:
         api = spec.api
+
         deps = {api: impls[api] for api in spec.api_dependencies}
         impl = await instantiate_provider(spec, deps, provider_map[api.value])
         impls[api] = impl
+
     return impls, specs
 
 
-# This runs
-def run_main_DEPRECATED(
-    config: StackRunConfig, port: int = 5000, disable_ipv6: bool = False
-):
+def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
+    with open(yaml_config, "r") as fp:
+        config = StackRunConfig(**yaml.safe_load(fp))
+
+    cprint(f"StackRunConfig: {config}", "blue")
     app = FastAPI()
 
-    impls, specs = asyncio.run(resolve_impls(config.provider_map))
+    impls, specs = asyncio.run(resolve_impls_with_routing(config))
     if Api.telemetry in impls:
         setup_logger(impls[Api.telemetry])
 
@@ -455,21 +449,5 @@ def run_main_DEPRECATED(
     uvicorn.run(app, host=listen_host, port=port)
 
 
-def run_main(config: StackRunConfig, port: int = 5000, disable_ipv6: bool = False):
-    raise ValueError("Not implemented")
-
-
-def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
-    with open(yaml_config, "r") as fp:
-        config = StackRunConfig(**yaml.safe_load(fp))
-
-    cprint(f"StackRunConfig: {config}", "blue")
-
-    if not config.provider_routing_table:
-        run_main_DEPRECATED(config, port, disable_ipv6)
-    else:
-        run_main(config, port, disable_ipv6)
-
-
 if __name__ == "__main__":
     fire.Fire(main)
diff --git a/llama_stack/examples/simple-run.yaml b/llama_stack/examples/simple-run.yaml
index 1e50ef849..5b5301592 100644
--- a/llama_stack/examples/simple-run.yaml
+++ b/llama_stack/examples/simple-run.yaml
@@ -36,4 +36,3 @@ provider_map:
   telemetry:
     provider_id: meta-reference
     config: {}
-provider_routing_table: {}