mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-05 12:21:52 +00:00
Introduce a "Router" layer for providers
Some providers need to be factorized and considered as thin routing layers on top of other providers. Consider two examples: - The inference API should be a routing layer over inference providers, routed using the "model" key - The memory banks API is another instance where various memory bank types will be provided by independent providers (e.g., a vector store is served by Chroma while a keyvalue memory can be served by Redis or PGVector) This commit introduces a generalized routing layer for this purpose.
This commit is contained in:
parent
5c1f2616b5
commit
b6a3ef51da
12 changed files with 384 additions and 118 deletions
|
@ -9,6 +9,7 @@ import inspect
|
|||
import json
|
||||
import signal
|
||||
import traceback
|
||||
|
||||
from collections.abc import (
|
||||
AsyncGenerator as AsyncGeneratorABC,
|
||||
AsyncIterator as AsyncIteratorABC,
|
||||
|
@ -44,8 +45,8 @@ from llama_toolchain.telemetry.tracing import (
|
|||
SpanStatus,
|
||||
start_trace,
|
||||
)
|
||||
from llama_toolchain.core.datatypes import * # noqa: F403
|
||||
|
||||
from .datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
|
||||
from .distribution import api_endpoints, api_providers
|
||||
from .dynamic import instantiate_provider
|
||||
|
||||
|
@ -271,61 +272,80 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
|
|||
return [by_id[x] for x in stack]
|
||||
|
||||
|
||||
def resolve_impls(
|
||||
provider_specs: Dict[str, ProviderSpec], config: Dict[str, Any]
|
||||
) -> Dict[Api, Any]:
|
||||
provider_configs = config["providers"]
|
||||
provider_specs = topological_sort(provider_specs.values())
|
||||
def snake_to_camel(snake_str):
|
||||
return "".join(word.capitalize() for word in snake_str.split("_"))
|
||||
|
||||
impls = {}
|
||||
for provider_spec in provider_specs:
|
||||
api = provider_spec.api
|
||||
if api.value not in provider_configs:
|
||||
raise ValueError(
|
||||
f"Could not find provider_spec config for {api}. Please add it to the config"
|
||||
|
||||
async def resolve_impls(
|
||||
provider_map: Dict[str, ProviderMapEntry],
|
||||
) -> Dict[Api, Any]:
|
||||
"""
|
||||
Does two things:
|
||||
- flatmaps, sorts and resolves the providers in dependency order
|
||||
- for each API, produces either a (local, passthrough or router) implementation
|
||||
"""
|
||||
all_providers = api_providers()
|
||||
|
||||
specs = {}
|
||||
for api_str, item in provider_map.items():
|
||||
api = Api(api_str)
|
||||
providers = all_providers[api]
|
||||
|
||||
if isinstance(item, GenericProviderConfig):
|
||||
if item.provider_id not in providers:
|
||||
raise ValueError(
|
||||
f"Unknown provider `{provider_id}` is not available for API `{api}`"
|
||||
)
|
||||
specs[api] = providers[item.provider_id]
|
||||
else:
|
||||
assert isinstance(item, list)
|
||||
inner_specs = []
|
||||
for rt_entry in item:
|
||||
if rt_entry.provider_id not in providers:
|
||||
raise ValueError(
|
||||
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
|
||||
)
|
||||
inner_specs.append(providers[rt_entry.provider_id])
|
||||
|
||||
specs[api] = RouterProviderSpec(
|
||||
api=api,
|
||||
module=f"llama_toolchain.{api.value.lower()}.router",
|
||||
api_dependencies=[],
|
||||
inner_specs=inner_specs,
|
||||
)
|
||||
|
||||
if isinstance(provider_spec, InlineProviderSpec):
|
||||
deps = {api: impls[api] for api in provider_spec.api_dependencies}
|
||||
else:
|
||||
deps = {}
|
||||
provider_config = provider_configs[api.value]
|
||||
impl = instantiate_provider(provider_spec, provider_config, deps)
|
||||
sorted_specs = topological_sort(specs.values())
|
||||
|
||||
impls = {}
|
||||
for spec in sorted_specs:
|
||||
api = spec.api
|
||||
|
||||
deps = {api: impls[api] for api in spec.api_dependencies}
|
||||
impl = await instantiate_provider(spec, deps, provider_map[api.value])
|
||||
impls[api] = impl
|
||||
|
||||
return impls
|
||||
return impls, specs
|
||||
|
||||
|
||||
def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
||||
with open(yaml_config, "r") as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
config = StackRunConfig(**yaml.safe_load(fp))
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
all_endpoints = api_endpoints()
|
||||
all_providers = api_providers()
|
||||
|
||||
provider_specs = {}
|
||||
for api_str, provider_config in config["providers"].items():
|
||||
api = Api(api_str)
|
||||
providers = all_providers[api]
|
||||
provider_id = provider_config["provider_id"]
|
||||
if provider_id not in providers:
|
||||
raise ValueError(
|
||||
f"Unknown provider `{provider_id}` is not available for API `{api}`"
|
||||
)
|
||||
|
||||
provider_specs[api] = providers[provider_id]
|
||||
|
||||
impls = resolve_impls(provider_specs, config)
|
||||
impls, specs = asyncio.run(resolve_impls(config.provider_map))
|
||||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
|
||||
for provider_spec in provider_specs.values():
|
||||
api = provider_spec.api
|
||||
all_endpoints = api_endpoints()
|
||||
|
||||
apis_to_serve = config.apis_to_serve or list(config.provider_map.keys())
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
endpoints = all_endpoints[api]
|
||||
impl = impls[api]
|
||||
|
||||
provider_spec = specs[api]
|
||||
if (
|
||||
isinstance(provider_spec, RemoteProviderSpec)
|
||||
and provider_spec.adapter is None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue