Refactoring distribution/distribution.py

This file was becoming too large, and it was unclear what it housed. Split it
into pieces.
This commit is contained in:
Ashwin Bharambe 2024-10-02 13:20:17 -07:00
parent 546f05bd3f
commit df68db644b
9 changed files with 89 additions and 78 deletions

View file

@ -175,7 +175,7 @@ class StackBuild(Subcommand):
import yaml import yaml
from llama_stack.distribution.distribution import ( from llama_stack.distribution.distribution import (
Api, Api,
api_providers, get_provider_registry,
builtin_automatically_routed_apis, builtin_automatically_routed_apis,
) )
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
@ -245,7 +245,7 @@ class StackBuild(Subcommand):
) )
providers = dict() providers = dict()
all_providers = api_providers() all_providers = get_provider_registry()
routing_table_apis = set( routing_table_apis = set(
x.routing_table_api for x in builtin_automatically_routed_apis() x.routing_table_api for x in builtin_automatically_routed_apis()
) )

View file

@ -34,9 +34,9 @@ class StackListProviders(Subcommand):
def _run_providers_list_cmd(self, args: argparse.Namespace) -> None: def _run_providers_list_cmd(self, args: argparse.Namespace) -> None:
from llama_stack.cli.table import print_table from llama_stack.cli.table import print_table
from llama_stack.distribution.distribution import Api, api_providers from llama_stack.distribution.distribution import Api, get_provider_registry
all_providers = api_providers() all_providers = get_provider_registry()
providers_for_api = all_providers[Api(args.api)] providers_for_api = all_providers[Api(args.api)]
# eventually, this should query a registry at llama.meta.com/llamastack/distributions # eventually, this should query a registry at llama.meta.com/llamastack/distributions

View file

@ -17,7 +17,17 @@ from llama_stack.distribution.utils.exec import run_with_pty
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
from pathlib import Path from pathlib import Path
from llama_stack.distribution.distribution import api_providers, SERVER_DEPENDENCIES from llama_stack.distribution.distribution import get_provider_registry
# Packages the distribution server itself depends on.
# `llama-stack` is absent on purpose: the installation script installs it
# automatically.
SERVER_DEPENDENCIES = ["fastapi", "fire", "httpx", "uvicorn"]
class ImageType(Enum): class ImageType(Enum):
@ -42,7 +52,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
) )
# extend package dependencies based on providers spec # extend package dependencies based on providers spec
all_providers = api_providers() all_providers = get_provider_registry()
for ( for (
api_str, api_str,
provider_or_providers, provider_or_providers,

View file

@ -15,8 +15,8 @@ from termcolor import cprint
from llama_stack.apis.memory.memory import MemoryBankType from llama_stack.apis.memory.memory import MemoryBankType
from llama_stack.distribution.distribution import ( from llama_stack.distribution.distribution import (
api_providers,
builtin_automatically_routed_apis, builtin_automatically_routed_apis,
get_provider_registry,
stack_apis, stack_apis,
) )
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
@ -62,7 +62,7 @@ def configure_api_providers(
config.apis_to_serve = list(set([a for a in apis if a != "telemetry"])) config.apis_to_serve = list(set([a for a in apis if a != "telemetry"]))
apis = [v.value for v in stack_apis()] apis = [v.value for v in stack_apis()]
all_providers = api_providers() all_providers = get_provider_registry()
# configure simple case for with non-routing providers to api_providers # configure simple case for with non-routing providers to api_providers
for api_str in spec.providers.keys(): for api_str in spec.providers.keys():

View file

@ -5,30 +5,11 @@
# the root directory of this source tree. # the root directory of this source tree.
import importlib import importlib
import inspect
from typing import Dict, List from typing import Dict, List
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.apis.agents import Agents from llama_stack.providers.datatypes import Api, ProviderSpec, remote_provider_spec
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models
from llama_stack.apis.safety import Safety
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec
# Packages the distribution server itself depends on.
# `llama-stack` is absent on purpose: the installation script installs it
# automatically.
SERVER_DEPENDENCIES = ["fastapi", "fire", "httpx", "uvicorn"]
def stack_apis() -> List[Api]: def stack_apis() -> List[Api]:
@ -57,45 +38,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
] ]
def api_endpoints() -> Dict[Api, List[ApiEndpoint]]: def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
apis = {}
protocols = {
Api.inference: Inference,
Api.safety: Safety,
Api.agents: Agents,
Api.memory: Memory,
Api.telemetry: Telemetry,
Api.models: Models,
Api.shields: Shields,
Api.memory_banks: MemoryBanks,
}
for api, protocol in protocols.items():
endpoints = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
for name, method in protocol_methods:
if not hasattr(method, "__webmethod__"):
continue
webmethod = method.__webmethod__
route = webmethod.route
if webmethod.method == "GET":
method = "get"
elif webmethod.method == "DELETE":
method = "delete"
else:
method = "post"
endpoints.append(ApiEndpoint(route=route, method=method, name=name))
apis[api] = endpoints
return apis
def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
ret = {} ret = {}
routing_table_apis = set( routing_table_apis = set(
x.routing_table_api for x in builtin_automatically_routed_apis() x.routing_table_api for x in builtin_automatically_routed_apis()

View file

@ -8,8 +8,8 @@ from typing import Any, Dict, List, Set
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import ( from llama_stack.distribution.distribution import (
api_providers,
builtin_automatically_routed_apis, builtin_automatically_routed_apis,
get_provider_registry,
) )
from llama_stack.distribution.utils.dynamic import instantiate_provider from llama_stack.distribution.utils.dynamic import instantiate_provider
@ -20,7 +20,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
- flatmaps, sorts and resolves the providers in dependency order - flatmaps, sorts and resolves the providers in dependency order
- for each API, produces either a (local, passthrough or router) implementation - for each API, produces either a (local, passthrough or router) implementation
""" """
all_providers = api_providers() all_providers = get_provider_registry()
specs = {} specs = {}
configs = {} configs = {}

View file

@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import inspect
from typing import Dict, List
from pydantic import BaseModel
from llama_stack.apis.agents import Agents
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models
from llama_stack.apis.safety import Safety
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.providers.datatypes import Api
class ApiEndpoint(BaseModel):
    """Description of a single endpoint exposed by an API protocol.

    Instances are built by ``get_all_api_endpoints`` from the
    ``__webmethod__`` metadata attached to protocol functions.
    """

    # Taken from the webmethod's `route`.
    route: str
    # Lowercase verb; the builder emits "get", "delete" or "post".
    method: str
    # Name of the protocol function the endpoint was derived from.
    name: str
def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
    """Build the endpoint listing for every API protocol known here.

    Each protocol class is scanned with ``inspect.getmembers`` for plain
    functions. A function contributes one ``ApiEndpoint`` only if it carries
    a ``__webmethod__`` attribute; every other member is ignored.

    Returns:
        A dict keyed by ``Api`` whose values are the endpoints found on the
        matching protocol, in the order ``inspect.getmembers`` yields them.
    """

    def _collect(protocol) -> List[ApiEndpoint]:
        # Gather the endpoints declared on a single protocol class.
        collected = []
        members = inspect.getmembers(protocol, predicate=inspect.isfunction)
        for member_name, func in members:
            if hasattr(func, "__webmethod__"):
                spec = func.__webmethod__
                path = spec.route
                declared = spec.method
                # Only GET and DELETE keep their own verb; anything else is
                # exposed as "post".
                verb = (
                    "get"
                    if declared == "GET"
                    else ("delete" if declared == "DELETE" else "post")
                )
                collected.append(
                    ApiEndpoint(route=path, method=verb, name=member_name)
                )
        return collected

    protocol_by_api = {
        Api.inference: Inference,
        Api.safety: Safety,
        Api.agents: Agents,
        Api.memory: Memory,
        Api.telemetry: Telemetry,
        Api.models: Models,
        Api.shields: Shields,
        Api.memory_banks: MemoryBanks,
    }
    return {api: _collect(protocol) for api, protocol in protocol_by_api.items()}

View file

@ -39,10 +39,11 @@ from llama_stack.providers.utils.telemetry.tracing import (
) )
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import api_endpoints
from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls_with_routing from llama_stack.distribution.resolver import resolve_impls_with_routing
from .endpoints import get_all_api_endpoints
def is_async_iterator_type(typ): def is_async_iterator_type(typ):
if hasattr(typ, "__origin__"): if hasattr(typ, "__origin__"):
@ -299,7 +300,7 @@ def main(
if Api.telemetry in impls: if Api.telemetry in impls:
setup_logger(impls[Api.telemetry]) setup_logger(impls[Api.telemetry])
all_endpoints = api_endpoints() all_endpoints = get_all_api_endpoints()
if config.apis_to_serve: if config.apis_to_serve:
apis_to_serve = set(config.apis_to_serve) apis_to_serve = set(config.apis_to_serve)

View file

@ -25,13 +25,6 @@ class Api(Enum):
memory_banks = "memory_banks" memory_banks = "memory_banks"
@json_schema_type
class ApiEndpoint(BaseModel):
    """Description of a single API endpoint: its route, verb and name."""

    # Route string for the endpoint.
    route: str
    # Verb used to reach the endpoint, stored as a string.
    method: str
    # Identifier of the endpoint.
    name: str
@json_schema_type @json_schema_type
class ProviderSpec(BaseModel): class ProviderSpec(BaseModel):
api: Api api: Api