Further generalize Xi's changes (#88)

* Further generalize Xi's changes

- introduce a slightly more general notion of an AutoRouted provider
- the AutoRouted provider is associated with a RoutingTable provider
- e.g. inference -> models
- Introduced safety -> shields and memory -> memory_banks
  correspondences

* typo

* Basic build and run succeeded
This commit is contained in:
Ashwin Bharambe 2024-09-22 16:31:18 -07:00 committed by GitHub
parent b8914bb56f
commit c1ab66f1e6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 597 additions and 418 deletions

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import asyncio
import importlib
import inspect
import json
import signal
@ -36,6 +35,9 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
@ -43,18 +45,15 @@ from llama_stack.providers.utils.telemetry.tracing import (
SpanStatus,
start_trace,
)
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import api_endpoints, api_providers
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.utils.dynamic import (
instantiate_builtin_provider,
instantiate_provider,
instantiate_router,
from llama_stack.distribution.distribution import (
api_endpoints,
api_providers,
builtin_automatically_routed_apis,
)
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.utils.dynamic import instantiate_provider
def is_async_iterator_type(typ):
@ -292,9 +291,7 @@ def snake_to_camel(snake_str):
return "".join(word.capitalize() for word in snake_str.split("_"))
async def resolve_impls_with_routing(
stack_run_config: StackRunConfig,
) -> Dict[Api, Any]:
async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
"""
Does two things:
- flatmaps, sorts and resolves the providers in dependency order
@ -302,48 +299,80 @@ async def resolve_impls_with_routing(
"""
all_providers = api_providers()
specs = {}
configs = {}
for api_str in stack_run_config.apis_to_serve:
for api_str, config in run_config.api_providers.items():
api = Api(api_str)
# TODO: check that these APIs are not in the routing table part of the config
providers = all_providers[api]
# check for routing table, we need to pass routing table to the router implementation
if api_str in stack_run_config.provider_routing_table:
specs[api] = RouterProviderSpec(
api=api,
module=f"llama_stack.distribution.routers",
api_dependencies=[],
routing_table=stack_run_config.provider_routing_table[api_str],
if config.provider_id not in providers:
raise ValueError(
f"Unknown provider `{config.provider_id}` is not available for API `{api}`"
)
else:
if api_str in stack_run_config.provider_map:
provider_map_entry = stack_run_config.provider_map[api_str]
provider_id = provider_map_entry.provider_id
else:
# not defined in config, will be a builtin provider, assign builtin provider id
provider_id = "builtin"
specs[api] = providers[config.provider_id]
configs[api] = config
if provider_id not in providers:
apis_to_serve = run_config.apis_to_serve or set(
list(specs.keys()) + list(run_config.routing_tables.keys())
)
print("apis_to_serve", apis_to_serve)
for info in builtin_automatically_routed_apis():
source_api = info.routing_table_api
assert (
source_api not in specs
), f"Routing table API {source_api} specified in wrong place?"
assert (
info.router_api not in specs
), f"Auto-routed API {info.router_api} specified in wrong place?"
if info.router_api.value not in apis_to_serve:
continue
if source_api.value not in run_config.routing_tables:
raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
routing_table = run_config.routing_tables[source_api.value]
providers = all_providers[info.router_api]
inner_specs = []
for rt_entry in routing_table.entries:
if rt_entry.provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api}`"
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
)
specs[api] = providers[provider_id]
inner_specs.append(providers[rt_entry.provider_id])
specs[source_api] = RoutingTableProviderSpec(
api=source_api,
module="llama_stack.distribution.routers",
api_dependencies=[],
inner_specs=inner_specs,
)
configs[source_api] = routing_table
specs[info.router_api] = AutoRoutedProviderSpec(
api=info.router_api,
module="llama_stack.distribution.routers",
routing_table_api=source_api,
api_dependencies=[source_api],
)
configs[info.router_api] = {}
sorted_specs = topological_sort(specs.values())
print(f"Resolved {len(sorted_specs)} providers in topological order")
for spec in sorted_specs:
print(f" {spec.api}: {spec.provider_id}")
print("")
impls = {}
for spec in sorted_specs:
api = spec.api
deps = {api: impls[api] for api in spec.api_dependencies}
if api.value in stack_run_config.provider_map:
provider_config = stack_run_config.provider_map[api.value]
impl = await instantiate_provider(spec, deps, provider_config)
elif api.value in stack_run_config.provider_routing_table:
impl = await instantiate_router(
spec, api.value, stack_run_config.provider_routing_table
)
else:
impl = await instantiate_builtin_provider(spec, stack_run_config)
impl = await instantiate_provider(spec, deps, configs[api])
impls[api] = impl
return impls, specs
@ -355,16 +384,23 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
app = FastAPI()
# impls, specs = asyncio.run(resolve_impls(config.provider_map))
impls, specs = asyncio.run(resolve_impls_with_routing(config))
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
all_endpoints = api_endpoints()
apis_to_serve = config.apis_to_serve or list(config.provider_map.keys())
if config.apis_to_serve:
apis_to_serve = set(config.apis_to_serve)
for inf in builtin_automatically_routed_apis():
if inf.router_api.value in apis_to_serve:
apis_to_serve.add(inf.routing_table_api)
else:
apis_to_serve = set(impls.keys())
for api_str in apis_to_serve:
api = Api(api_str)
endpoints = all_endpoints[api]
impl = impls[api]
@ -391,7 +427,11 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
create_dynamic_typed_route(
impl_method,
endpoint.method,
provider_spec.provider_data_validator,
(
provider_spec.provider_data_validator
if not isinstance(provider_spec, RoutingTableProviderSpec)
else None
),
)
)