From 6a95edc80640cafab9ab8e6a9409768b133275de Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Fri, 20 Sep 2024 13:47:58 -0700 Subject: [PATCH] backward compatibility --- llama_stack/distribution/server/server.py | 140 +++++++++++++--------- llama_stack/examples/simple-run.yaml | 1 + 2 files changed, 82 insertions(+), 59 deletions(-) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 60855070e..64c1111e7 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -287,60 +287,60 @@ def snake_to_camel(snake_str): return "".join(word.capitalize() for word in snake_str.split("_")) -async def resolve_impls_with_routing( - stack_run_config: StackRunConfig, -) -> Dict[Api, Any]: +# async def resolve_impls_with_routing( +# stack_run_config: StackRunConfig, +# ) -> Dict[Api, Any]: - all_providers = api_providers() - specs = {} +# all_providers = api_providers() +# specs = {} - for api_str in stack_run_config.apis_to_serve: - api = Api(api_str) - providers = all_providers[api] +# for api_str in stack_run_config.apis_to_serve: +# api = Api(api_str) +# providers = all_providers[api] - # check for regular providers without routing - if api_str in stack_run_config.provider_map: - provider_map_entry = stack_run_config.provider_map[api_str] - if provider_map_entry.provider_id not in providers: - raise ValueError( - f"Unknown provider `{provider_id}` is not available for API `{api}`" - ) - specs[api] = providers[provider_map_entry.provider_id] +# # check for regular providers without routing +# if api_str in stack_run_config.provider_map: +# provider_map_entry = stack_run_config.provider_map[api_str] +# if provider_map_entry.provider_id not in providers: +# raise ValueError( +# f"Unknown provider `{provider_id}` is not available for API `{api}`" +# ) +# specs[api] = providers[provider_map_entry.provider_id] - # check for routing table, we need to pass routing table to the router implementation - if api_str in stack_run_config.provider_routing_table: - router_entry = stack_run_config.provider_routing_table[api_str] - inner_specs = [] - for rt_entry in router_entry: - if rt_entry.provider_id not in providers: - raise ValueError( - f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`" - ) - inner_specs.append(providers[rt_entry.provider_id]) +# # check for routing table, we need to pass routing table to the router implementation +# if api_str in stack_run_config.provider_routing_table: +# router_entry = stack_run_config.provider_routing_table[api_str] +# inner_specs = [] +# for rt_entry in router_entry: +# if rt_entry.provider_id not in providers: +# raise ValueError( +# f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`" +# ) +# inner_specs.append(providers[rt_entry.provider_id]) - specs[api] = RouterProviderSpec( - api=api, - module=f"llama_stack.distribution.routers.{api.value.lower()}", - api_dependencies=[], - inner_specs=inner_specs, - ) +# specs[api] = RouterProviderSpec( +# api=api, +# module=f"llama_stack.distribution.routers.{api.value.lower()}", +# api_dependencies=[], +# inner_specs=inner_specs, +# ) - sorted_specs = topological_sort(specs.values()) +# sorted_specs = topological_sort(specs.values()) - impls = {} - for spec in sorted_specs: - api = spec.api - deps = {api: impls[api] for api in spec.api_dependencies} - if api.value in stack_run_config.provider_map: - provider_config = stack_run_config.provider_map[api.value] - elif api.value in stack_run_config.provider_routing_table: - provider_config = stack_run_config.provider_routing_table[api.value] - else: - raise ValueError(f"Cannot find provider_config for Api {api.value}") - impl = await instantiate_provider(spec, deps, provider_config) - impls[api] = impl +# impls = {} +# for spec in sorted_specs: +# api = spec.api +# deps = {api: impls[api] for api in spec.api_dependencies} +# if api.value in stack_run_config.provider_map: +# provider_config = stack_run_config.provider_map[api.value] +# elif api.value in stack_run_config.provider_routing_table: +# provider_config = stack_run_config.provider_routing_table[api.value] +# else: +# raise ValueError(f"Cannot find provider_config for Api {api.value}") +# impl = await instantiate_provider(spec, deps, provider_config) +# impls[api] = impl - return impls, specs +# return impls, specs async def resolve_impls( @@ -352,12 +352,10 @@ async def resolve_impls( - for each API, produces either a (local, passthrough or router) implementation """ all_providers = api_providers() - specs = {} for api_str, item in provider_map.items(): api = Api(api_str) providers = all_providers[api] - if isinstance(item, GenericProviderConfig): if item.provider_id not in providers: raise ValueError( @@ -365,31 +363,39 @@ async def resolve_impls( ) specs[api] = providers[item.provider_id] else: - raise ValueError( - f"Please define routing table in provider_routing_table of run config" + assert isinstance(item, list) + inner_specs = [] + for rt_entry in item: + if rt_entry.provider_id not in providers: + raise ValueError( + f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`" + ) + inner_specs.append(providers[rt_entry.provider_id]) + + specs[api] = RouterProviderSpec( + api=api, + module=f"llama_stack.providers.routers.{api.value.lower()}", + api_dependencies=[], + inner_specs=inner_specs, ) sorted_specs = topological_sort(specs.values()) - impls = {} for spec in sorted_specs: api = spec.api - deps = {api: impls[api] for api in spec.api_dependencies} impl = await instantiate_provider(spec, deps, provider_map[api.value]) impls[api] = impl - return impls, specs -def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False): - with open(yaml_config, "r") as fp: - config = StackRunConfig(**yaml.safe_load(fp)) - - cprint(f"StackRunConfig: {config}", "blue") +# This runs +def run_main_DEPRECATED( + config: StackRunConfig, port: int = 5000, disable_ipv6: bool = False +): app = FastAPI() - impls, specs = asyncio.run(resolve_impls_with_routing(config)) + impls, specs = asyncio.run(resolve_impls(config.provider_map)) if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) @@ -449,5 +455,21 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False): uvicorn.run(app, host=listen_host, port=port) +def run_main(config: StackRunConfig, port: int = 5000, disable_ipv6: bool = False): + raise ValueError("Not implemented") + + +def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False): + with open(yaml_config, "r") as fp: + config = StackRunConfig(**yaml.safe_load(fp)) + + cprint(f"StackRunConfig: {config}", "blue") + + if not config.provider_routing_table: + run_main_DEPRECATED(config, port, disable_ipv6) + else: + run_main(config, port, disable_ipv6) + + if __name__ == "__main__": fire.Fire(main) diff --git a/llama_stack/examples/simple-run.yaml b/llama_stack/examples/simple-run.yaml index 5b5301592..1e50ef849 100644 --- a/llama_stack/examples/simple-run.yaml +++ b/llama_stack/examples/simple-run.yaml @@ -36,3 +36,4 @@ provider_map: telemetry: provider_id: meta-reference config: {} +provider_routing_table: {}