skeleton unified routing table, api routers

This commit is contained in:
Xi Yan 2024-09-21 13:44:33 -07:00
parent 2dc14cba2c
commit 85d927adde
11 changed files with 210 additions and 231 deletions

View file

@ -50,7 +50,10 @@ from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import api_endpoints, api_providers
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.utils.dynamic import instantiate_provider
from llama_stack.distribution.utils.dynamic import (
instantiate_provider,
instantiate_router,
)
def is_async_iterator_type(typ):
@ -288,8 +291,8 @@ def snake_to_camel(snake_str):
return "".join(word.capitalize() for word in snake_str.split("_"))
async def resolve_impls(
provider_map: Dict[str, ProviderMapEntry],
async def resolve_impls_with_routing(
stack_run_config: StackRunConfig,
) -> Dict[Api, Any]:
"""
Does two things:
@ -297,33 +300,28 @@ async def resolve_impls(
- for each API, produces either a (local, passthrough or router) implementation
"""
all_providers = api_providers()
specs = {}
for api_str, item in provider_map.items():
for api_str in stack_run_config.apis_to_serve:
api = Api(api_str)
providers = all_providers[api]
if isinstance(item, GenericProviderConfig):
if item.provider_id not in providers:
# check for regular providers without routing
if api_str in stack_run_config.provider_map:
provider_map_entry = stack_run_config.provider_map[api_str]
if provider_map_entry.provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api}`"
)
specs[api] = providers[item.provider_id]
else:
assert isinstance(item, list)
inner_specs = []
for rt_entry in item:
if rt_entry.provider_id not in providers:
raise ValueError(
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
)
inner_specs.append(providers[rt_entry.provider_id])
specs[api] = providers[provider_map_entry.provider_id]
# check for routing table, we need to pass routing table to the router implementation
if api_str in stack_run_config.provider_routing_table:
specs[api] = RouterProviderSpec(
api=api,
module=f"llama_stack.providers.routers.{api.value.lower()}",
module=f"llama_stack.distribution.routers",
api_dependencies=[],
inner_specs=inner_specs,
routing_table=stack_run_config.provider_routing_table[api_str],
)
sorted_specs = topological_sort(specs.values())
@ -331,9 +329,16 @@ async def resolve_impls(
impls = {}
for spec in sorted_specs:
api = spec.api
deps = {api: impls[api] for api in spec.api_dependencies}
impl = await instantiate_provider(spec, deps, provider_map[api.value])
if api.value in stack_run_config.provider_map:
provider_config = stack_run_config.provider_map[api.value]
impl = await instantiate_provider(spec, deps, provider_config)
elif api.value in stack_run_config.provider_routing_table:
impl = await instantiate_router(
spec, api.value, stack_run_config.provider_routing_table
)
else:
raise ValueError(f"Cannot find provider_config for Api {api.value}")
impls[api] = impl
return impls, specs
@ -345,7 +350,8 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
app = FastAPI()
impls, specs = asyncio.run(resolve_impls(config.provider_map))
# impls, specs = asyncio.run(resolve_impls(config.provider_map))
impls, specs = asyncio.run(resolve_impls_with_routing(config))
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])