diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py
index ef1f1a807..b7c25fa1b 100644
--- a/llama_stack/cli/stack/build.py
+++ b/llama_stack/cli/stack/build.py
@@ -175,7 +175,7 @@ class StackBuild(Subcommand):
         import yaml
         from llama_stack.distribution.distribution import (
             Api,
-            api_providers,
+            get_provider_registry,
             builtin_automatically_routed_apis,
         )
         from llama_stack.distribution.utils.dynamic import instantiate_class_type
@@ -245,7 +245,7 @@ class StackBuild(Subcommand):
         )
 
         providers = dict()
-        all_providers = api_providers()
+        all_providers = get_provider_registry()
         routing_table_apis = set(
             x.routing_table_api for x in builtin_automatically_routed_apis()
         )
diff --git a/llama_stack/cli/stack/list_providers.py b/llama_stack/cli/stack/list_providers.py
index 18c4de201..25875ecbf 100644
--- a/llama_stack/cli/stack/list_providers.py
+++ b/llama_stack/cli/stack/list_providers.py
@@ -34,9 +34,9 @@ class StackListProviders(Subcommand):
 
     def _run_providers_list_cmd(self, args: argparse.Namespace) -> None:
         from llama_stack.cli.table import print_table
-        from llama_stack.distribution.distribution import Api, api_providers
+        from llama_stack.distribution.distribution import Api, get_provider_registry
 
-        all_providers = api_providers()
+        all_providers = get_provider_registry()
         providers_for_api = all_providers[Api(args.api)]
 
         # eventually, this should query a registry at llama.meta.com/llamastack/distributions
diff --git a/llama_stack/distribution/build.py b/llama_stack/distribution/build.py
index dabcad2a6..fe778bdb8 100644
--- a/llama_stack/distribution/build.py
+++ b/llama_stack/distribution/build.py
@@ -17,7 +17,17 @@ from llama_stack.distribution.utils.exec import run_with_pty
 from llama_stack.distribution.datatypes import *  # noqa: F403
 from pathlib import Path
 
-from llama_stack.distribution.distribution import api_providers, SERVER_DEPENDENCIES
+from llama_stack.distribution.distribution import get_provider_registry
+
+
+# These are the dependencies needed by the distribution server.
+# `llama-stack` is automatically installed by the installation script.
+SERVER_DEPENDENCIES = [
+    "fastapi",
+    "fire",
+    "httpx",
+    "uvicorn",
+]
 
 
 class ImageType(Enum):
@@ -42,7 +52,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
     )
 
     # extend package dependencies based on providers spec
-    all_providers = api_providers()
+    all_providers = get_provider_registry()
     for (
         api_str,
         provider_or_providers,
diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py
index d3b807d4a..e9b682dc0 100644
--- a/llama_stack/distribution/configure.py
+++ b/llama_stack/distribution/configure.py
@@ -15,8 +15,8 @@ from termcolor import cprint
 
 from llama_stack.apis.memory.memory import MemoryBankType
 from llama_stack.distribution.distribution import (
-    api_providers,
     builtin_automatically_routed_apis,
+    get_provider_registry,
     stack_apis,
 )
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
@@ -62,7 +62,7 @@ def configure_api_providers(
     config.apis_to_serve = list(set([a for a in apis if a != "telemetry"]))
 
     apis = [v.value for v in stack_apis()]
-    all_providers = api_providers()
+    all_providers = get_provider_registry()
 
     # configure simple case for with non-routing providers to api_providers
     for api_str in spec.providers.keys():
diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py
index 035febb80..0c47fd750 100644
--- a/llama_stack/distribution/distribution.py
+++ b/llama_stack/distribution/distribution.py
@@ -5,30 +5,11 @@
 # the root directory of this source tree.
 
 import importlib
-import inspect
 from typing import Dict, List
 
 from pydantic import BaseModel
 
-from llama_stack.apis.agents import Agents
-from llama_stack.apis.inference import Inference
-from llama_stack.apis.memory import Memory
-from llama_stack.apis.memory_banks import MemoryBanks
-from llama_stack.apis.models import Models
-from llama_stack.apis.safety import Safety
-from llama_stack.apis.shields import Shields
-from llama_stack.apis.telemetry import Telemetry
-
-from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec
-
-# These are the dependencies needed by the distribution server.
-# `llama-stack` is automatically installed by the installation script.
-SERVER_DEPENDENCIES = [
-    "fastapi",
-    "fire",
-    "httpx",
-    "uvicorn",
-]
+from llama_stack.providers.datatypes import Api, ProviderSpec, remote_provider_spec
 
 
 def stack_apis() -> List[Api]:
@@ -57,45 +38,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
     ]
 
 
-def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
-    apis = {}
-
-    protocols = {
-        Api.inference: Inference,
-        Api.safety: Safety,
-        Api.agents: Agents,
-        Api.memory: Memory,
-        Api.telemetry: Telemetry,
-        Api.models: Models,
-        Api.shields: Shields,
-        Api.memory_banks: MemoryBanks,
-    }
-
-    for api, protocol in protocols.items():
-        endpoints = []
-        protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
-
-        for name, method in protocol_methods:
-            if not hasattr(method, "__webmethod__"):
-                continue
-
-            webmethod = method.__webmethod__
-            route = webmethod.route
-
-            if webmethod.method == "GET":
-                method = "get"
-            elif webmethod.method == "DELETE":
-                method = "delete"
-            else:
-                method = "post"
-            endpoints.append(ApiEndpoint(route=route, method=method, name=name))
-
-        apis[api] = endpoints
-
-    return apis
-
-
-def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
+def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
     ret = {}
     routing_table_apis = set(
         x.routing_table_api for x in builtin_automatically_routed_apis()
     )
diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index f7d51c64a..8c8084969 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -8,8 +8,8 @@ from typing import Any, Dict, List, Set
 
 from llama_stack.distribution.datatypes import *  # noqa: F403
 from llama_stack.distribution.distribution import (
-    api_providers,
     builtin_automatically_routed_apis,
+    get_provider_registry,
 )
 from llama_stack.distribution.utils.dynamic import instantiate_provider
@@ -20,7 +20,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
     - flatmaps, sorts and resolves the providers in dependency order
     - for each API, produces either a (local, passthrough or router) implementation
     """
-    all_providers = api_providers()
+    all_providers = get_provider_registry()
 
     specs = {}
     configs = {}
diff --git a/llama_stack/distribution/server/endpoints.py b/llama_stack/distribution/server/endpoints.py
new file mode 100644
index 000000000..96de31c4b
--- /dev/null
+++ b/llama_stack/distribution/server/endpoints.py
@@ -0,0 +1,64 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import inspect
+from typing import Dict, List
+
+from pydantic import BaseModel
+
+from llama_stack.apis.agents import Agents
+from llama_stack.apis.inference import Inference
+from llama_stack.apis.memory import Memory
+from llama_stack.apis.memory_banks import MemoryBanks
+from llama_stack.apis.models import Models
+from llama_stack.apis.safety import Safety
+from llama_stack.apis.shields import Shields
+from llama_stack.apis.telemetry import Telemetry
+from llama_stack.providers.datatypes import Api
+
+
+class ApiEndpoint(BaseModel):
+    route: str
+    method: str
+    name: str
+
+
+def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
+    apis = {}
+
+    protocols = {
+        Api.inference: Inference,
+        Api.safety: Safety,
+        Api.agents: Agents,
+        Api.memory: Memory,
+        Api.telemetry: Telemetry,
+        Api.models: Models,
+        Api.shields: Shields,
+        Api.memory_banks: MemoryBanks,
+    }
+
+    for api, protocol in protocols.items():
+        endpoints = []
+        protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
+
+        for name, method in protocol_methods:
+            if not hasattr(method, "__webmethod__"):
+                continue
+
+            webmethod = method.__webmethod__
+            route = webmethod.route
+
+            if webmethod.method == "GET":
+                method = "get"
+            elif webmethod.method == "DELETE":
+                method = "delete"
+            else:
+                method = "post"
+            endpoints.append(ApiEndpoint(route=route, method=method, name=name))
+
+        apis[api] = endpoints
+
+    return apis
diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index 16b1fb619..1ac1a1b16 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -39,10 +39,11 @@ from llama_stack.providers.utils.telemetry.tracing import (
 )
 
 from llama_stack.distribution.datatypes import *  # noqa: F403
-from llama_stack.distribution.distribution import api_endpoints
 from llama_stack.distribution.request_headers import set_request_provider_data
 from llama_stack.distribution.resolver import resolve_impls_with_routing
 
+from .endpoints import get_all_api_endpoints
+
 
 def is_async_iterator_type(typ):
     if hasattr(typ, "__origin__"):
@@ -299,7 +300,7 @@ def main(
     if Api.telemetry in impls:
         setup_logger(impls[Api.telemetry])
 
-    all_endpoints = api_endpoints()
+    all_endpoints = get_all_api_endpoints()
 
     if config.apis_to_serve:
         apis_to_serve = set(config.apis_to_serve)
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index a9a3d86e9..d661b6649 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -25,13 +25,6 @@ class Api(Enum):
     memory_banks = "memory_banks"
 
 
-@json_schema_type
-class ApiEndpoint(BaseModel):
-    route: str
-    method: str
-    name: str
-
-
 @json_schema_type
 class ProviderSpec(BaseModel):
     api: Api
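
For reviewers, a minimal usage sketch (not part of the patch) of the two helpers this change renames and relocates, relying only on the signatures visible above; the printed summary is an illustrative assumption, not code from the repository:

# sketch.py -- hypothetical caller, assuming only the public signatures shown in this diff
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints

# Dict[Api, Dict[str, ProviderSpec]]: provider specs keyed by API, then by provider id
registry = get_provider_registry()
for api, providers in registry.items():
    print(api.value, "->", sorted(providers.keys()))

# Dict[Api, List[ApiEndpoint]]: routes derived from the @webmethod-decorated protocol methods
endpoints = get_all_api_endpoints()
for api, eps in endpoints.items():
    for ep in eps:
        # ApiEndpoint carries route, method ("get"/"post"/"delete"), and the protocol method name
        print(f"{ep.method.upper():6s} {ep.route}  ({api.value}.{ep.name})")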