Add a RoutableProvider protocol, support for multiple routing keys (#163)

* Update configure.py to use multiple routing keys for safety
* Refactor distribution/datatypes into providers/datatypes
* Cleanup
Ashwin Bharambe, 2024-09-30 17:30:21 -07:00, committed by GitHub
parent 73decb3781
commit eb2d8a31a5
24 changed files with 600 additions and 577 deletions
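
The `RoutableProvider` protocol itself presumably lives in the new `llama_stack.providers.datatypes` module (the module newly imported in the datatypes hunk below) and does not appear in the excerpts here; only its call site, `validate_routing_keys` in `CommonRoutingTableImpl`, and the widened `routing_key` field are visible. A minimal sketch of what the diff implies, with the exact names assumed rather than confirmed:

from typing import List, Protocol, Union

# Sketch only; the real definitions are expected in llama_stack.providers.datatypes
# and may differ in detail.
RoutingKey = Union[str, List[str]]

class RoutableProvider(Protocol):
    """A provider implementation that can sit behind a routing table."""

    async def validate_routing_keys(self, routing_keys: List[str]) -> None:
        """Invoked by the routing table at initialize() time with every key
        (shield type, model name, memory bank id, ...) routed to this provider."""
        ...

Correspondingly, `RoutableProviderConfig.routing_key` now accepts either a single key or a list of keys, which is what lets the run.yaml templates below collapse four meta-reference safety entries into one.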


@@ -117,18 +117,18 @@ def configure_api_providers(
if api_str == "safety":
# TODO: add support for other safety providers, and simplify safety provider config
if p == "meta-reference":
for shield_type in MetaReferenceShieldType:
routing_entries.append(
RoutableProviderConfig(
routing_key=shield_type.value,
provider_id=p,
config=cfg.dict(),
)
routing_entries.append(
RoutableProviderConfig(
routing_key=[s.value for s in MetaReferenceShieldType],
provider_id=p,
config=cfg.dict(),
)
)
else:
cprint(
f"[WARN] Interactive configuration of safety provider {p} is not supported, please manually configure safety shields types in routing_table of run.yaml",
f"[WARN] Interactive configuration of safety provider {p} is not supported. Please look for `{routing_key}` in run.yaml and replace it appropriately.",
"yellow",
attrs=["bold"],
)
routing_entries.append(
RoutableProviderConfig(


@@ -5,228 +5,16 @@
# the root directory of this source tree.
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type
from typing import Dict, List, Optional, Union
from pydantic import BaseModel, Field
@json_schema_type
class Api(Enum):
inference = "inference"
safety = "safety"
agents = "agents"
memory = "memory"
telemetry = "telemetry"
models = "models"
shields = "shields"
memory_banks = "memory_banks"
from llama_stack.providers.datatypes import * # noqa: F403
@json_schema_type
class ApiEndpoint(BaseModel):
route: str
method: str
name: str
@json_schema_type
class ProviderSpec(BaseModel):
api: Api
provider_id: str
config_class: str = Field(
...,
description="Fully-qualified classname of the config for this provider",
)
api_dependencies: List[Api] = Field(
default_factory=list,
description="Higher-level API surfaces may depend on other providers to provide their functionality",
)
class RoutingTable(Protocol):
def get_routing_keys(self) -> List[str]: ...
def get_provider_impl(self, routing_key: str) -> Any: ...
class GenericProviderConfig(BaseModel):
provider_id: str
config: Dict[str, Any]
class PlaceholderProviderConfig(BaseModel):
"""Placeholder provider config for API whose provider are defined in routing_table"""
providers: List[str]
class RoutableProviderConfig(GenericProviderConfig):
routing_key: str
# Example: /inference, /safety
@json_schema_type
class AutoRoutedProviderSpec(ProviderSpec):
provider_id: str = "router"
config_class: str = ""
docker_image: Optional[str] = None
routing_table_api: Api
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
""",
)
provider_data_validator: Optional[str] = Field(
default=None,
)
@property
def pip_packages(self) -> List[str]:
raise AssertionError("Should not be called on AutoRoutedProviderSpec")
# Example: /models, /shields
@json_schema_type
class RoutingTableProviderSpec(ProviderSpec):
provider_id: str = "routing_table"
config_class: str = ""
docker_image: Optional[str] = None
inner_specs: List[ProviderSpec]
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
""",
)
pip_packages: List[str] = Field(default_factory=list)
@json_schema_type
class AdapterSpec(BaseModel):
adapter_id: str = Field(
...,
description="Unique identifier for this adapter",
)
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
""",
)
pip_packages: List[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
config_class: Optional[str] = Field(
default=None,
description="Fully-qualified classname of the config for this provider",
)
provider_data_validator: Optional[str] = Field(
default=None,
)
@json_schema_type
class InlineProviderSpec(ProviderSpec):
pip_packages: List[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
docker_image: Optional[str] = Field(
default=None,
description="""
The docker image to use for this implementation. If one is provided, pip_packages will be ignored.
If a provider depends on other providers, the dependencies MUST NOT specify a docker image.
""",
)
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_provider_impl(config, deps)`: returns the local implementation
""",
)
provider_data_validator: Optional[str] = Field(
default=None,
)
class RemoteProviderConfig(BaseModel):
host: str = "localhost"
port: int
@property
def url(self) -> str:
return f"http://{self.host}:{self.port}"
def remote_provider_id(adapter_id: str) -> str:
return f"remote::{adapter_id}"
@json_schema_type
class RemoteProviderSpec(ProviderSpec):
adapter: Optional[AdapterSpec] = Field(
default=None,
description="""
If some code is needed to convert the remote responses into Llama Stack compatible
API responses, specify the adapter here. If not specified, it indicates the remote
as being "Llama Stack compatible"
""",
)
@property
def docker_image(self) -> Optional[str]:
return None
@property
def module(self) -> str:
if self.adapter:
return self.adapter.module
return f"llama_stack.apis.{self.api.value}.client"
@property
def pip_packages(self) -> List[str]:
if self.adapter:
return self.adapter.pip_packages
return []
@property
def provider_data_validator(self) -> Optional[str]:
if self.adapter:
return self.adapter.provider_data_validator
return None
# Can avoid this by using Pydantic computed_field
def remote_provider_spec(
api: Api, adapter: Optional[AdapterSpec] = None
) -> RemoteProviderSpec:
config_class = (
adapter.config_class
if adapter and adapter.config_class
else "llama_stack.distribution.datatypes.RemoteProviderConfig"
)
provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote"
return RemoteProviderSpec(
api=api, provider_id=provider_id, config_class=config_class, adapter=adapter
)
LLAMA_STACK_BUILD_CONFIG_VERSION = "v1"
LLAMA_STACK_RUN_CONFIG_VERSION = "v1"
@json_schema_type
@@ -247,6 +35,7 @@ in the runtime configuration to help route to the correct provider.""",
@json_schema_type
class StackRunConfig(BaseModel):
version: str = LLAMA_STACK_RUN_CONFIG_VERSION
built_at: datetime
image_name: str = Field(
@@ -295,6 +84,7 @@ Provider configurations for each of the APIs provided by this package.
@json_schema_type
class BuildConfig(BaseModel):
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
name: str
distribution_spec: DistributionSpec = Field(
description="The distribution spec to build including API providers. "


@@ -23,7 +23,7 @@ class NeedsRequestProviderData:
if not validator_class:
raise ValueError(f"Provider {provider_id} does not have a validator")
val = _THREAD_LOCAL.provider_data_header_value
val = getattr(_THREAD_LOCAL, "provider_data_header_value", None)
if not val:
return None


@@ -0,0 +1,129 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Set
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import (
api_providers,
builtin_automatically_routed_apis,
)
from llama_stack.distribution.utils.dynamic import instantiate_provider
async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
"""
Does two things:
- flatmaps, sorts and resolves the providers in dependency order
- for each API, produces either a (local, passthrough or router) implementation
"""
all_providers = api_providers()
specs = {}
configs = {}
for api_str, config in run_config.api_providers.items():
api = Api(api_str)
# TODO: check that these APIs are not in the routing table part of the config
providers = all_providers[api]
# skip checks for API whose provider config is specified in routing_table
if isinstance(config, PlaceholderProviderConfig):
continue
if config.provider_id not in providers:
raise ValueError(
f"Unknown provider `{config.provider_id}` is not available for API `{api}`"
)
specs[api] = providers[config.provider_id]
configs[api] = config
apis_to_serve = run_config.apis_to_serve or set(
list(specs.keys()) + list(run_config.routing_table.keys())
)
for info in builtin_automatically_routed_apis():
source_api = info.routing_table_api
assert (
source_api not in specs
), f"Routing table API {source_api} specified in wrong place?"
assert (
info.router_api not in specs
), f"Auto-routed API {info.router_api} specified in wrong place?"
if info.router_api.value not in apis_to_serve:
continue
print("router_api", info.router_api)
if info.router_api.value not in run_config.routing_table:
raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
routing_table = run_config.routing_table[info.router_api.value]
providers = all_providers[info.router_api]
inner_specs = []
inner_deps = []
for rt_entry in routing_table:
if rt_entry.provider_id not in providers:
raise ValueError(
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
)
inner_specs.append(providers[rt_entry.provider_id])
inner_deps.extend(providers[rt_entry.provider_id].api_dependencies)
specs[source_api] = RoutingTableProviderSpec(
api=source_api,
module="llama_stack.distribution.routers",
api_dependencies=inner_deps,
inner_specs=inner_specs,
)
configs[source_api] = routing_table
specs[info.router_api] = AutoRoutedProviderSpec(
api=info.router_api,
module="llama_stack.distribution.routers",
routing_table_api=source_api,
api_dependencies=[source_api],
)
configs[info.router_api] = {}
sorted_specs = topological_sort(specs.values())
print(f"Resolved {len(sorted_specs)} providers in topological order")
for spec in sorted_specs:
print(f" {spec.api}: {spec.provider_id}")
print("")
impls = {}
for spec in sorted_specs:
api = spec.api
deps = {api: impls[api] for api in spec.api_dependencies}
impl = await instantiate_provider(spec, deps, configs[api])
impls[api] = impl
return impls, specs
def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
by_id = {x.api: x for x in providers}
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
visited.add(a.api)
for api in a.api_dependencies:
if api not in visited:
dfs(by_id[api], visited, stack)
stack.append(a.api)
visited = set()
stack = []
for a in providers:
if a.api not in visited:
dfs(a, visited, stack)
return [by_id[x] for x in stack]
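
This resolver is consumed by the server entrypoint (see the `resolve_impls_with_routing` import in the server hunk below). A rough usage sketch; the yaml loading and the `start` wrapper here are illustrative, not the literal server code:

import asyncio
import yaml

from llama_stack.distribution.datatypes import *  # noqa: F403  (StackRunConfig, Api, ...)
from llama_stack.distribution.resolver import resolve_impls_with_routing

async def start(config_path: str = "llamastack-run.yaml"):
    # parse the run config; the real CLI performs additional validation
    with open(config_path) as f:
        run_config = StackRunConfig(**yaml.safe_load(f))

    # providers come back instantiated in dependency order, keyed by Api;
    # auto-routed APIs get router implementations backed by the routing tables
    impls, specs = await resolve_impls_with_routing(run_config)
    return impls[Api.inference]

# asyncio.run(start())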


@@ -19,18 +19,35 @@ from llama_stack.distribution.datatypes import * # noqa: F403
class CommonRoutingTableImpl(RoutingTable):
def __init__(
self,
inner_impls: List[Tuple[str, Any]],
inner_impls: List[Tuple[RoutingKey, Any]],
routing_table_config: Dict[str, List[RoutableProviderConfig]],
) -> None:
self.providers = {k: v for k, v in inner_impls}
self.routing_keys = list(self.providers.keys())
self.unique_providers = []
self.providers = {}
self.routing_keys = []
for key, impl in inner_impls:
keys = key if isinstance(key, list) else [key]
self.unique_providers.append((keys, impl))
for k in keys:
if k in self.providers:
raise ValueError(f"Duplicate routing key {k}")
self.providers[k] = impl
self.routing_keys.append(k)
self.routing_table_config = routing_table_config
async def initialize(self) -> None:
pass
for keys, p in self.unique_providers:
spec = p.__provider_spec__
if isinstance(spec, RemoteProviderSpec) and spec.adapter is None:
continue
await p.validate_routing_keys(keys)
async def shutdown(self) -> None:
for p in self.providers.values():
for _, p in self.unique_providers:
await p.shutdown()
def get_provider_impl(self, routing_key: str) -> Any:
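
The practical effect of the constructor change is that a single provider instance can be registered under several routing keys. A toy illustration; `ToyShield` is invented for the example, and `get_provider_impl` is assumed to be a plain dictionary lookup (its body is not part of this hunk):

class ToyShield:
    # stand-in provider: real implementations carry their ProviderSpec here,
    # which initialize() inspects to skip client-only remote providers
    __provider_spec__ = None

    async def validate_routing_keys(self, routing_keys):
        print("validating shields:", routing_keys)

    async def shutdown(self):
        pass

table = CommonRoutingTableImpl(
    inner_impls=[(["llama_guard", "jailbreak_shield"], ToyShield())],
    routing_table_config={},
)
# both keys resolve to the same underlying provider; duplicate keys raise ValueError
assert table.get_provider_impl("llama_guard") is table.get_provider_impl("jailbreak_shield")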


@@ -17,16 +17,7 @@ from collections.abc import (
from contextlib import asynccontextmanager
from http import HTTPStatus
from ssl import SSLError
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Dict,
get_type_hints,
List,
Optional,
Set,
)
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
import fire
import httpx
@@ -48,13 +39,9 @@ from llama_stack.providers.utils.telemetry.tracing import (
)
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import (
api_endpoints,
api_providers,
builtin_automatically_routed_apis,
)
from llama_stack.distribution.distribution import api_endpoints
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.utils.dynamic import instantiate_provider
from llama_stack.distribution.resolver import resolve_impls_with_routing
def is_async_iterator_type(typ):
@@ -289,125 +276,6 @@ def create_dynamic_typed_route(func: Any, method: str):
return endpoint
def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
by_id = {x.api: x for x in providers}
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
visited.add(a.api)
for api in a.api_dependencies:
if api not in visited:
dfs(by_id[api], visited, stack)
stack.append(a.api)
visited = set()
stack = []
for a in providers:
if a.api not in visited:
dfs(a, visited, stack)
return [by_id[x] for x in stack]
def snake_to_camel(snake_str):
return "".join(word.capitalize() for word in snake_str.split("_"))
async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
"""
Does two things:
- flatmaps, sorts and resolves the providers in dependency order
- for each API, produces either a (local, passthrough or router) implementation
"""
all_providers = api_providers()
specs = {}
configs = {}
for api_str, config in run_config.api_providers.items():
api = Api(api_str)
# TODO: check that these APIs are not in the routing table part of the config
providers = all_providers[api]
# skip checks for API whose provider config is specified in routing_table
if isinstance(config, PlaceholderProviderConfig):
continue
if config.provider_id not in providers:
raise ValueError(
f"Unknown provider `{config.provider_id}` is not available for API `{api}`"
)
specs[api] = providers[config.provider_id]
configs[api] = config
apis_to_serve = run_config.apis_to_serve or set(
list(specs.keys()) + list(run_config.routing_table.keys())
)
for info in builtin_automatically_routed_apis():
source_api = info.routing_table_api
assert (
source_api not in specs
), f"Routing table API {source_api} specified in wrong place?"
assert (
info.router_api not in specs
), f"Auto-routed API {info.router_api} specified in wrong place?"
if info.router_api.value not in apis_to_serve:
continue
print("router_api", info.router_api)
if info.router_api.value not in run_config.routing_table:
raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
routing_table = run_config.routing_table[info.router_api.value]
providers = all_providers[info.router_api]
inner_specs = []
inner_deps = []
for rt_entry in routing_table:
if rt_entry.provider_id not in providers:
raise ValueError(
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
)
inner_specs.append(providers[rt_entry.provider_id])
inner_deps.extend(providers[rt_entry.provider_id].api_dependencies)
specs[source_api] = RoutingTableProviderSpec(
api=source_api,
module="llama_stack.distribution.routers",
api_dependencies=inner_deps,
inner_specs=inner_specs,
)
configs[source_api] = routing_table
specs[info.router_api] = AutoRoutedProviderSpec(
api=info.router_api,
module="llama_stack.distribution.routers",
routing_table_api=source_api,
api_dependencies=[source_api],
)
configs[info.router_api] = {}
sorted_specs = topological_sort(specs.values())
print(f"Resolved {len(sorted_specs)} providers in topological order")
for spec in sorted_specs:
print(f" {spec.api}: {spec.provider_id}")
print("")
impls = {}
for spec in sorted_specs:
api = spec.api
deps = {api: impls[api] for api in spec.api_dependencies}
impl = await instantiate_provider(spec, deps, configs[api])
impls[api] = impl
return impls, specs
def main(
yaml_config: str = "llamastack-run.yaml",
port: int = 5000,


@@ -42,22 +42,7 @@ routing_table:
config:
llama_guard_shield: null
prompt_guard_shield: null
routing_key: llama_guard
- provider_id: meta-reference
config:
llama_guard_shield: null
prompt_guard_shield: null
routing_key: code_scanner_guard
- provider_id: meta-reference
config:
llama_guard_shield: null
prompt_guard_shield: null
routing_key: injection_shield
- provider_id: meta-reference
config:
llama_guard_shield: null
prompt_guard_shield: null
routing_key: jailbreak_shield
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
memory:
- provider_id: meta-reference
config: {}


@@ -45,22 +45,7 @@ routing_table:
config:
llama_guard_shield: null
prompt_guard_shield: null
routing_key: llama_guard
- provider_id: meta-reference
config:
llama_guard_shield: null
prompt_guard_shield: null
routing_key: code_scanner_guard
- provider_id: meta-reference
config:
llama_guard_shield: null
prompt_guard_shield: null
routing_key: injection_shield
- provider_id: meta-reference
config:
llama_guard_shield: null
prompt_guard_shield: null
routing_key: jailbreak_shield
routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
memory:
- provider_id: meta-reference
config: {}