Refactoring distribution/distribution.py

This file was becoming too large and unclear what it housed. Split it
into pieces.
This commit is contained in:
Ashwin Bharambe 2024-10-02 13:20:17 -07:00
parent 546f05bd3f
commit df68db644b
9 changed files with 89 additions and 78 deletions

View file

@ -175,7 +175,7 @@ class StackBuild(Subcommand):
import yaml
from llama_stack.distribution.distribution import (
Api,
api_providers,
get_provider_registry,
builtin_automatically_routed_apis,
)
from llama_stack.distribution.utils.dynamic import instantiate_class_type
@ -245,7 +245,7 @@ class StackBuild(Subcommand):
)
providers = dict()
all_providers = api_providers()
all_providers = get_provider_registry()
routing_table_apis = set(
x.routing_table_api for x in builtin_automatically_routed_apis()
)

View file

@ -34,9 +34,9 @@ class StackListProviders(Subcommand):
def _run_providers_list_cmd(self, args: argparse.Namespace) -> None:
from llama_stack.cli.table import print_table
from llama_stack.distribution.distribution import Api, api_providers
from llama_stack.distribution.distribution import Api, get_provider_registry
all_providers = api_providers()
all_providers = get_provider_registry()
providers_for_api = all_providers[Api(args.api)]
# eventually, this should query a registry at llama.meta.com/llamastack/distributions

View file

@ -17,7 +17,17 @@ from llama_stack.distribution.utils.exec import run_with_pty
from llama_stack.distribution.datatypes import * # noqa: F403
from pathlib import Path
from llama_stack.distribution.distribution import api_providers, SERVER_DEPENDENCIES
from llama_stack.distribution.distribution import get_provider_registry
# These are the dependencies needed by the distribution server.
# `llama-stack` is automatically installed by the installation script.
SERVER_DEPENDENCIES = [
"fastapi",
"fire",
"httpx",
"uvicorn",
]
class ImageType(Enum):
@ -42,7 +52,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
)
# extend package dependencies based on providers spec
all_providers = api_providers()
all_providers = get_provider_registry()
for (
api_str,
provider_or_providers,

View file

@ -15,8 +15,8 @@ from termcolor import cprint
from llama_stack.apis.memory.memory import MemoryBankType
from llama_stack.distribution.distribution import (
api_providers,
builtin_automatically_routed_apis,
get_provider_registry,
stack_apis,
)
from llama_stack.distribution.utils.dynamic import instantiate_class_type
@ -62,7 +62,7 @@ def configure_api_providers(
config.apis_to_serve = list(set([a for a in apis if a != "telemetry"]))
apis = [v.value for v in stack_apis()]
all_providers = api_providers()
all_providers = get_provider_registry()
# configure simple case for with non-routing providers to api_providers
for api_str in spec.providers.keys():

View file

@ -5,30 +5,11 @@
# the root directory of this source tree.
import importlib
import inspect
from typing import Dict, List
from pydantic import BaseModel
from llama_stack.apis.agents import Agents
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models
from llama_stack.apis.safety import Safety
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec
# These are the dependencies needed by the distribution server.
# `llama-stack` is automatically installed by the installation script.
SERVER_DEPENDENCIES = [
"fastapi",
"fire",
"httpx",
"uvicorn",
]
from llama_stack.providers.datatypes import Api, ProviderSpec, remote_provider_spec
def stack_apis() -> List[Api]:
@ -57,45 +38,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
]
def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
apis = {}
protocols = {
Api.inference: Inference,
Api.safety: Safety,
Api.agents: Agents,
Api.memory: Memory,
Api.telemetry: Telemetry,
Api.models: Models,
Api.shields: Shields,
Api.memory_banks: MemoryBanks,
}
for api, protocol in protocols.items():
endpoints = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
for name, method in protocol_methods:
if not hasattr(method, "__webmethod__"):
continue
webmethod = method.__webmethod__
route = webmethod.route
if webmethod.method == "GET":
method = "get"
elif webmethod.method == "DELETE":
method = "delete"
else:
method = "post"
endpoints.append(ApiEndpoint(route=route, method=method, name=name))
apis[api] = endpoints
return apis
def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
ret = {}
routing_table_apis = set(
x.routing_table_api for x in builtin_automatically_routed_apis()

View file

@ -8,8 +8,8 @@ from typing import Any, Dict, List, Set
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import (
api_providers,
builtin_automatically_routed_apis,
get_provider_registry,
)
from llama_stack.distribution.utils.dynamic import instantiate_provider
@ -20,7 +20,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
- flatmaps, sorts and resolves the providers in dependency order
- for each API, produces either a (local, passthrough or router) implementation
"""
all_providers = api_providers()
all_providers = get_provider_registry()
specs = {}
configs = {}

View file

@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import inspect
from typing import Dict, List
from pydantic import BaseModel
from llama_stack.apis.agents import Agents
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models
from llama_stack.apis.safety import Safety
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.providers.datatypes import Api
class ApiEndpoint(BaseModel):
route: str
method: str
name: str
def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
apis = {}
protocols = {
Api.inference: Inference,
Api.safety: Safety,
Api.agents: Agents,
Api.memory: Memory,
Api.telemetry: Telemetry,
Api.models: Models,
Api.shields: Shields,
Api.memory_banks: MemoryBanks,
}
for api, protocol in protocols.items():
endpoints = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
for name, method in protocol_methods:
if not hasattr(method, "__webmethod__"):
continue
webmethod = method.__webmethod__
route = webmethod.route
if webmethod.method == "GET":
method = "get"
elif webmethod.method == "DELETE":
method = "delete"
else:
method = "post"
endpoints.append(ApiEndpoint(route=route, method=method, name=name))
apis[api] = endpoints
return apis

View file

@ -39,10 +39,11 @@ from llama_stack.providers.utils.telemetry.tracing import (
)
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import api_endpoints
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls_with_routing
from .endpoints import get_all_api_endpoints
def is_async_iterator_type(typ):
if hasattr(typ, "__origin__"):
@ -299,7 +300,7 @@ def main(
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
all_endpoints = api_endpoints()
all_endpoints = get_all_api_endpoints()
if config.apis_to_serve:
apis_to_serve = set(config.apis_to_serve)

View file

@ -25,13 +25,6 @@ class Api(Enum):
memory_banks = "memory_banks"
@json_schema_type
class ApiEndpoint(BaseModel):
route: str
method: str
name: str
@json_schema_type
class ProviderSpec(BaseModel):
api: Api