From 32b9907d699795825d9cb932afdea586686571a2 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Sat, 21 Sep 2024 12:42:10 -0700 Subject: [PATCH] Revert "add new resolve_impls_with_routing" This reverts commit 34f0c11001cfa969cf91bcb4255854e471d73051. --- llama_stack/distribution/datatypes.py | 9 +-------- llama_stack/distribution/server/server.py | 12 +++--------- llama_stack/examples/router-table-run.yaml | 3 ++- 3 files changed, 6 insertions(+), 18 deletions(-) diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 708787374..a230dacf7 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -209,8 +209,7 @@ class ProviderRoutingEntry(GenericProviderConfig): routing_key: str -ProviderMapEntry = Union[GenericProviderConfig, str] -ProviderRoutingTableEntry = List[ProviderRoutingEntry] +ProviderMapEntry = Union[GenericProviderConfig, List[ProviderRoutingEntry]] @json_schema_type @@ -249,12 +248,6 @@ As examples: The key may support wild-cards alsothe routing_key to route to the correct provider.""", ) - provider_routing_table: Dict[str, ProviderRoutingTableEntry] = Field( - description=""" - Map of API to a list of providers backing the API. - Each provider is a (routing_key, provider_config) tuple. - """ - ) @json_schema_type diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 4c9e97899..deb7bf787 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -290,18 +290,18 @@ def snake_to_camel(snake_str): async def resolve_impls_with_routing( stack_run_config: StackRunConfig, ) -> Dict[Api, Any]: - raise NotImplementedError("This is not implemented yet") async def resolve_impls( - provider_map: Dict[str, ProviderMapEntry], + stack_run_config: StackRunConfig, ) -> Dict[Api, Any]: """ Does two things: - flatmaps, sorts and resolves the providers in dependency order - for each API, produces either a (local, passthrough or router) implementation """ + provider_map = stack_run_config.provider_map all_providers = api_providers() specs = {} @@ -349,15 +349,9 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False): with open(yaml_config, "r") as fp: config = StackRunConfig(**yaml.safe_load(fp)) - cprint(f"StackRunConfig: {config}", "blue") app = FastAPI() - # check if routing table exists - if config.provider_routing_table is not None: - impls, specs = asyncio.run(resolve_impls_with_routing(config)) - else: - impls, specs = asyncio.run(resolve_impls(config.provider_map)) - + impls, specs = asyncio.run(resolve_impls(config)) if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) diff --git a/llama_stack/examples/router-table-run.yaml b/llama_stack/examples/router-table-run.yaml index a94c883cb..aec0fca7e 100644 --- a/llama_stack/examples/router-table-run.yaml +++ b/llama_stack/examples/router-table-run.yaml @@ -9,7 +9,7 @@ provider_map: # use builtin-router as dummy field memory: builtin-router inference: builtin-router -provider_routing_table: +routing_table: inference: - routing_key: Meta-Llama3.1-8B-Instruct provider_id: meta-reference @@ -91,3 +91,4 @@ provider_routing_table: # api: safety # config: # model: Prompt-Guard-86M +