diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index a230dacf7..708787374 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -209,7 +209,8 @@ class ProviderRoutingEntry(GenericProviderConfig): routing_key: str -ProviderMapEntry = Union[GenericProviderConfig, List[ProviderRoutingEntry]] +ProviderMapEntry = Union[GenericProviderConfig, str] +ProviderRoutingTableEntry = List[ProviderRoutingEntry] @json_schema_type @@ -248,6 +249,12 @@ As examples: The key may support wild-cards alsothe routing_key to route to the correct provider.""", ) + provider_routing_table: Dict[str, ProviderRoutingTableEntry] = Field( + description=""" + Map of API to a list of providers backing the API. + Each provider is a (routing_key, provider_config) tuple. + """ + ) @json_schema_type diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index deb7bf787..4c9e97899 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -290,18 +290,18 @@ def snake_to_camel(snake_str): async def resolve_impls_with_routing( stack_run_config: StackRunConfig, ) -> Dict[Api, Any]: + raise NotImplementedError("This is not implemented yet") async def resolve_impls( - stack_run_config: StackRunConfig, + provider_map: Dict[str, ProviderMapEntry], ) -> Dict[Api, Any]: """ Does two things: - flatmaps, sorts and resolves the providers in dependency order - for each API, produces either a (local, passthrough or router) implementation """ - provider_map = stack_run_config.provider_map all_providers = api_providers() specs = {} @@ -349,9 +349,15 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False): with open(yaml_config, "r") as fp: config = StackRunConfig(**yaml.safe_load(fp)) + cprint(f"StackRunConfig: {config}", "blue") app = FastAPI() - impls, specs = asyncio.run(resolve_impls(config)) + # check if routing table exists + if config.provider_routing_table is not None: + impls, specs = asyncio.run(resolve_impls_with_routing(config)) + else: + impls, specs = asyncio.run(resolve_impls(config.provider_map)) + if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) diff --git a/llama_stack/examples/router-table-run.yaml b/llama_stack/examples/router-table-run.yaml index aec0fca7e..a94c883cb 100644 --- a/llama_stack/examples/router-table-run.yaml +++ b/llama_stack/examples/router-table-run.yaml @@ -9,7 +9,7 @@ provider_map: # use builtin-router as dummy field memory: builtin-router inference: builtin-router -routing_table: +provider_routing_table: inference: - routing_key: Meta-Llama3.1-8B-Instruct provider_id: meta-reference @@ -91,4 +91,3 @@ routing_table: # api: safety # config: # model: Prompt-Guard-86M -