Further generalize Xi's changes

- introduce a slightly more general notion of an AutoRouted provider
- the AutoRouted provider is associated with a RoutingTable provider
- e.g. inference -> models
- Introduced safety -> shields and memory -> memory_banks
  correspondences
This commit is contained in:
Ashwin Bharambe 2024-09-22 12:06:43 -07:00
parent b8914bb56f
commit e1966b90d9
19 changed files with 559 additions and 388 deletions

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import asyncio
import importlib
import inspect
import json
import signal
@ -36,6 +35,9 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
@ -43,18 +45,11 @@ from llama_stack.providers.utils.telemetry.tracing import (
SpanStatus,
start_trace,
)
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import api_endpoints, api_providers
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.utils.dynamic import (
instantiate_builtin_provider,
instantiate_provider,
instantiate_router,
)
from llama_stack.distribution.utils.dynamic import instantiate_provider
def is_async_iterator_type(typ):
@ -292,9 +287,7 @@ def snake_to_camel(snake_str):
return "".join(word.capitalize() for word in snake_str.split("_"))
async def resolve_impls_with_routing(
stack_run_config: StackRunConfig,
) -> Dict[Api, Any]:
async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
"""
Does two things:
- flatmaps, sorts and resolves the providers in dependency order
@ -302,32 +295,61 @@ async def resolve_impls_with_routing(
"""
all_providers = api_providers()
specs = {}
configs = {}
for api_str in stack_run_config.apis_to_serve:
for api_str, config in run_config.api_providers.items():
api = Api(api_str)
# TODO: check that these APIs are not in the routing table part of the config
providers = all_providers[api]
# check for routing table, we need to pass routing table to the router implementation
if api_str in stack_run_config.provider_routing_table:
specs[api] = RouterProviderSpec(
api=api,
module=f"llama_stack.distribution.routers",
api_dependencies=[],
routing_table=stack_run_config.provider_routing_table[api_str],
if config.provider_id not in providers:
raise ValueError(
f"Unknown provider `{config.provider_id}` is not available for API `{api}`"
)
else:
if api_str in stack_run_config.provider_map:
provider_map_entry = stack_run_config.provider_map[api_str]
provider_id = provider_map_entry.provider_id
else:
# not defined in config, will be a builtin provider, assign builtin provider id
provider_id = "builtin"
specs[api] = providers[config.provider_id]
configs[api] = config
if provider_id not in providers:
for info in builtin_automatically_routed_apis():
source_api = info.api_with_routing_table
assert (
source_api not in specs
), f"Routing table API {source_api} specified in wrong place?"
assert (
info.router_api not in specs
), f"Auto-routed API {info.router_api} specified in wrong place?"
if source_api.value not in run_config.routing_tables:
raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
routing_table = run_config.routing_tables[source_api.value]
providers = all_providers[info.router_api]
inner_specs = []
for rt_entry in routing_table.entries:
if rt_entry.provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api}`"
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
)
specs[api] = providers[provider_id]
inner_specs.append(providers[rt_entry.provider_id])
specs[source_api] = RoutingTableProviderSpec(
api=source_api,
module="llama_stack.distribution.routers",
api_dependencies=[],
inner_specs=inner_specs,
)
configs[source_api] = routing_table
specs[info.router_api] = AutoRoutedProviderSpec(
api=info.router_api,
module="llama_stack.distribution.routers",
routing_table_api=source_api,
api_dependencies=[source_api],
)
configs[info.router_api] = {}
sorted_specs = topological_sort(specs.values())
@ -335,15 +357,8 @@ async def resolve_impls_with_routing(
for spec in sorted_specs:
api = spec.api
deps = {api: impls[api] for api in spec.api_dependencies}
if api.value in stack_run_config.provider_map:
provider_config = stack_run_config.provider_map[api.value]
impl = await instantiate_provider(spec, deps, provider_config)
elif api.value in stack_run_config.provider_routing_table:
impl = await instantiate_router(
spec, api.value, stack_run_config.provider_routing_table
)
else:
impl = await instantiate_builtin_provider(spec, stack_run_config)
impl = await instantiate_provider(spec, deps, configs[api])
impls[api] = impl
return impls, specs
@ -355,7 +370,6 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
app = FastAPI()
# impls, specs = asyncio.run(resolve_impls(config.provider_map))
impls, specs = asyncio.run(resolve_impls_with_routing(config))
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])