mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
Further generalize Xi's changes (#88)
* Further generalize Xi's changes - introduce a slightly more general notion of an AutoRouted provider - the AutoRouted provider is associated with a RoutingTable provider - e.g. inference -> models - Introduced safety -> shields and memory -> memory_banks correspondences * typo * Basic build and run succeeded
This commit is contained in:
parent
b8914bb56f
commit
c1ab66f1e6
21 changed files with 597 additions and 418 deletions
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import signal
|
||||
|
@ -36,6 +35,9 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response
|
|||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.routing import APIRoute
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
end_trace,
|
||||
|
@ -43,18 +45,15 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
|||
SpanStatus,
|
||||
start_trace,
|
||||
)
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.distribution import api_endpoints, api_providers
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.utils.dynamic import (
|
||||
instantiate_builtin_provider,
|
||||
instantiate_provider,
|
||||
instantiate_router,
|
||||
from llama_stack.distribution.distribution import (
|
||||
api_endpoints,
|
||||
api_providers,
|
||||
builtin_automatically_routed_apis,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_provider
|
||||
|
||||
|
||||
def is_async_iterator_type(typ):
|
||||
|
@ -292,9 +291,7 @@ def snake_to_camel(snake_str):
|
|||
return "".join(word.capitalize() for word in snake_str.split("_"))
|
||||
|
||||
|
||||
async def resolve_impls_with_routing(
|
||||
stack_run_config: StackRunConfig,
|
||||
) -> Dict[Api, Any]:
|
||||
async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
|
||||
"""
|
||||
Does two things:
|
||||
- flatmaps, sorts and resolves the providers in dependency order
|
||||
|
@ -302,48 +299,80 @@ async def resolve_impls_with_routing(
|
|||
"""
|
||||
all_providers = api_providers()
|
||||
specs = {}
|
||||
configs = {}
|
||||
|
||||
for api_str in stack_run_config.apis_to_serve:
|
||||
for api_str, config in run_config.api_providers.items():
|
||||
api = Api(api_str)
|
||||
|
||||
# TODO: check that these APIs are not in the routing table part of the config
|
||||
providers = all_providers[api]
|
||||
|
||||
# check for routing table, we need to pass routing table to the router implementation
|
||||
if api_str in stack_run_config.provider_routing_table:
|
||||
specs[api] = RouterProviderSpec(
|
||||
api=api,
|
||||
module=f"llama_stack.distribution.routers",
|
||||
api_dependencies=[],
|
||||
routing_table=stack_run_config.provider_routing_table[api_str],
|
||||
if config.provider_id not in providers:
|
||||
raise ValueError(
|
||||
f"Unknown provider `{config.provider_id}` is not available for API `{api}`"
|
||||
)
|
||||
else:
|
||||
if api_str in stack_run_config.provider_map:
|
||||
provider_map_entry = stack_run_config.provider_map[api_str]
|
||||
provider_id = provider_map_entry.provider_id
|
||||
else:
|
||||
# not defined in config, will be a builtin provider, assign builtin provider id
|
||||
provider_id = "builtin"
|
||||
specs[api] = providers[config.provider_id]
|
||||
configs[api] = config
|
||||
|
||||
if provider_id not in providers:
|
||||
apis_to_serve = run_config.apis_to_serve or set(
|
||||
list(specs.keys()) + list(run_config.routing_tables.keys())
|
||||
)
|
||||
print("apis_to_serve", apis_to_serve)
|
||||
for info in builtin_automatically_routed_apis():
|
||||
source_api = info.routing_table_api
|
||||
|
||||
assert (
|
||||
source_api not in specs
|
||||
), f"Routing table API {source_api} specified in wrong place?"
|
||||
assert (
|
||||
info.router_api not in specs
|
||||
), f"Auto-routed API {info.router_api} specified in wrong place?"
|
||||
|
||||
if info.router_api.value not in apis_to_serve:
|
||||
continue
|
||||
|
||||
if source_api.value not in run_config.routing_tables:
|
||||
raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
|
||||
|
||||
routing_table = run_config.routing_tables[source_api.value]
|
||||
|
||||
providers = all_providers[info.router_api]
|
||||
|
||||
inner_specs = []
|
||||
for rt_entry in routing_table.entries:
|
||||
if rt_entry.provider_id not in providers:
|
||||
raise ValueError(
|
||||
f"Unknown provider `{provider_id}` is not available for API `{api}`"
|
||||
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
|
||||
)
|
||||
specs[api] = providers[provider_id]
|
||||
inner_specs.append(providers[rt_entry.provider_id])
|
||||
|
||||
specs[source_api] = RoutingTableProviderSpec(
|
||||
api=source_api,
|
||||
module="llama_stack.distribution.routers",
|
||||
api_dependencies=[],
|
||||
inner_specs=inner_specs,
|
||||
)
|
||||
configs[source_api] = routing_table
|
||||
|
||||
specs[info.router_api] = AutoRoutedProviderSpec(
|
||||
api=info.router_api,
|
||||
module="llama_stack.distribution.routers",
|
||||
routing_table_api=source_api,
|
||||
api_dependencies=[source_api],
|
||||
)
|
||||
configs[info.router_api] = {}
|
||||
|
||||
sorted_specs = topological_sort(specs.values())
|
||||
|
||||
print(f"Resolved {len(sorted_specs)} providers in topological order")
|
||||
for spec in sorted_specs:
|
||||
print(f" {spec.api}: {spec.provider_id}")
|
||||
print("")
|
||||
impls = {}
|
||||
for spec in sorted_specs:
|
||||
api = spec.api
|
||||
deps = {api: impls[api] for api in spec.api_dependencies}
|
||||
if api.value in stack_run_config.provider_map:
|
||||
provider_config = stack_run_config.provider_map[api.value]
|
||||
impl = await instantiate_provider(spec, deps, provider_config)
|
||||
elif api.value in stack_run_config.provider_routing_table:
|
||||
impl = await instantiate_router(
|
||||
spec, api.value, stack_run_config.provider_routing_table
|
||||
)
|
||||
else:
|
||||
impl = await instantiate_builtin_provider(spec, stack_run_config)
|
||||
impl = await instantiate_provider(spec, deps, configs[api])
|
||||
|
||||
impls[api] = impl
|
||||
|
||||
return impls, specs
|
||||
|
@ -355,16 +384,23 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
|||
|
||||
app = FastAPI()
|
||||
|
||||
# impls, specs = asyncio.run(resolve_impls(config.provider_map))
|
||||
impls, specs = asyncio.run(resolve_impls_with_routing(config))
|
||||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
|
||||
all_endpoints = api_endpoints()
|
||||
|
||||
apis_to_serve = config.apis_to_serve or list(config.provider_map.keys())
|
||||
if config.apis_to_serve:
|
||||
apis_to_serve = set(config.apis_to_serve)
|
||||
for inf in builtin_automatically_routed_apis():
|
||||
if inf.router_api.value in apis_to_serve:
|
||||
apis_to_serve.add(inf.routing_table_api)
|
||||
else:
|
||||
apis_to_serve = set(impls.keys())
|
||||
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
|
||||
endpoints = all_endpoints[api]
|
||||
impl = impls[api]
|
||||
|
||||
|
@ -391,7 +427,11 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
|||
create_dynamic_typed_route(
|
||||
impl_method,
|
||||
endpoint.method,
|
||||
provider_spec.provider_data_validator,
|
||||
(
|
||||
provider_spec.provider_data_validator
|
||||
if not isinstance(provider_spec, RoutingTableProviderSpec)
|
||||
else None
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue