forked from phoenix-oss/llama-stack-mirror
Refactoring distribution/distribution.py
This file was becoming too large and unclear what it housed. Split it into pieces.
This commit is contained in:
parent
546f05bd3f
commit
df68db644b
9 changed files with 89 additions and 78 deletions
|
@ -175,7 +175,7 @@ class StackBuild(Subcommand):
|
||||||
import yaml
|
import yaml
|
||||||
from llama_stack.distribution.distribution import (
|
from llama_stack.distribution.distribution import (
|
||||||
Api,
|
Api,
|
||||||
api_providers,
|
get_provider_registry,
|
||||||
builtin_automatically_routed_apis,
|
builtin_automatically_routed_apis,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||||
|
@ -245,7 +245,7 @@ class StackBuild(Subcommand):
|
||||||
)
|
)
|
||||||
|
|
||||||
providers = dict()
|
providers = dict()
|
||||||
all_providers = api_providers()
|
all_providers = get_provider_registry()
|
||||||
routing_table_apis = set(
|
routing_table_apis = set(
|
||||||
x.routing_table_api for x in builtin_automatically_routed_apis()
|
x.routing_table_api for x in builtin_automatically_routed_apis()
|
||||||
)
|
)
|
||||||
|
|
|
@ -34,9 +34,9 @@ class StackListProviders(Subcommand):
|
||||||
|
|
||||||
def _run_providers_list_cmd(self, args: argparse.Namespace) -> None:
|
def _run_providers_list_cmd(self, args: argparse.Namespace) -> None:
|
||||||
from llama_stack.cli.table import print_table
|
from llama_stack.cli.table import print_table
|
||||||
from llama_stack.distribution.distribution import Api, api_providers
|
from llama_stack.distribution.distribution import Api, get_provider_registry
|
||||||
|
|
||||||
all_providers = api_providers()
|
all_providers = get_provider_registry()
|
||||||
providers_for_api = all_providers[Api(args.api)]
|
providers_for_api = all_providers[Api(args.api)]
|
||||||
|
|
||||||
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
|
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
|
||||||
|
|
|
@ -17,7 +17,17 @@ from llama_stack.distribution.utils.exec import run_with_pty
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from llama_stack.distribution.distribution import api_providers, SERVER_DEPENDENCIES
|
from llama_stack.distribution.distribution import get_provider_registry
|
||||||
|
|
||||||
|
|
||||||
|
# These are the dependencies needed by the distribution server.
|
||||||
|
# `llama-stack` is automatically installed by the installation script.
|
||||||
|
SERVER_DEPENDENCIES = [
|
||||||
|
"fastapi",
|
||||||
|
"fire",
|
||||||
|
"httpx",
|
||||||
|
"uvicorn",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class ImageType(Enum):
|
class ImageType(Enum):
|
||||||
|
@ -42,7 +52,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
|
||||||
)
|
)
|
||||||
|
|
||||||
# extend package dependencies based on providers spec
|
# extend package dependencies based on providers spec
|
||||||
all_providers = api_providers()
|
all_providers = get_provider_registry()
|
||||||
for (
|
for (
|
||||||
api_str,
|
api_str,
|
||||||
provider_or_providers,
|
provider_or_providers,
|
||||||
|
|
|
@ -15,8 +15,8 @@ from termcolor import cprint
|
||||||
|
|
||||||
from llama_stack.apis.memory.memory import MemoryBankType
|
from llama_stack.apis.memory.memory import MemoryBankType
|
||||||
from llama_stack.distribution.distribution import (
|
from llama_stack.distribution.distribution import (
|
||||||
api_providers,
|
|
||||||
builtin_automatically_routed_apis,
|
builtin_automatically_routed_apis,
|
||||||
|
get_provider_registry,
|
||||||
stack_apis,
|
stack_apis,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||||
|
@ -62,7 +62,7 @@ def configure_api_providers(
|
||||||
config.apis_to_serve = list(set([a for a in apis if a != "telemetry"]))
|
config.apis_to_serve = list(set([a for a in apis if a != "telemetry"]))
|
||||||
|
|
||||||
apis = [v.value for v in stack_apis()]
|
apis = [v.value for v in stack_apis()]
|
||||||
all_providers = api_providers()
|
all_providers = get_provider_registry()
|
||||||
|
|
||||||
# configure simple case for with non-routing providers to api_providers
|
# configure simple case for with non-routing providers to api_providers
|
||||||
for api_str in spec.providers.keys():
|
for api_str in spec.providers.keys():
|
||||||
|
|
|
@ -5,30 +5,11 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.apis.agents import Agents
|
from llama_stack.providers.datatypes import Api, ProviderSpec, remote_provider_spec
|
||||||
from llama_stack.apis.inference import Inference
|
|
||||||
from llama_stack.apis.memory import Memory
|
|
||||||
from llama_stack.apis.memory_banks import MemoryBanks
|
|
||||||
from llama_stack.apis.models import Models
|
|
||||||
from llama_stack.apis.safety import Safety
|
|
||||||
from llama_stack.apis.shields import Shields
|
|
||||||
from llama_stack.apis.telemetry import Telemetry
|
|
||||||
|
|
||||||
from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec
|
|
||||||
|
|
||||||
# These are the dependencies needed by the distribution server.
|
|
||||||
# `llama-stack` is automatically installed by the installation script.
|
|
||||||
SERVER_DEPENDENCIES = [
|
|
||||||
"fastapi",
|
|
||||||
"fire",
|
|
||||||
"httpx",
|
|
||||||
"uvicorn",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def stack_apis() -> List[Api]:
|
def stack_apis() -> List[Api]:
|
||||||
|
@ -57,45 +38,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
|
||||||
apis = {}
|
|
||||||
|
|
||||||
protocols = {
|
|
||||||
Api.inference: Inference,
|
|
||||||
Api.safety: Safety,
|
|
||||||
Api.agents: Agents,
|
|
||||||
Api.memory: Memory,
|
|
||||||
Api.telemetry: Telemetry,
|
|
||||||
Api.models: Models,
|
|
||||||
Api.shields: Shields,
|
|
||||||
Api.memory_banks: MemoryBanks,
|
|
||||||
}
|
|
||||||
|
|
||||||
for api, protocol in protocols.items():
|
|
||||||
endpoints = []
|
|
||||||
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
|
|
||||||
|
|
||||||
for name, method in protocol_methods:
|
|
||||||
if not hasattr(method, "__webmethod__"):
|
|
||||||
continue
|
|
||||||
|
|
||||||
webmethod = method.__webmethod__
|
|
||||||
route = webmethod.route
|
|
||||||
|
|
||||||
if webmethod.method == "GET":
|
|
||||||
method = "get"
|
|
||||||
elif webmethod.method == "DELETE":
|
|
||||||
method = "delete"
|
|
||||||
else:
|
|
||||||
method = "post"
|
|
||||||
endpoints.append(ApiEndpoint(route=route, method=method, name=name))
|
|
||||||
|
|
||||||
apis[api] = endpoints
|
|
||||||
|
|
||||||
return apis
|
|
||||||
|
|
||||||
|
|
||||||
def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
|
|
||||||
ret = {}
|
ret = {}
|
||||||
routing_table_apis = set(
|
routing_table_apis = set(
|
||||||
x.routing_table_api for x in builtin_automatically_routed_apis()
|
x.routing_table_api for x in builtin_automatically_routed_apis()
|
||||||
|
|
|
@ -8,8 +8,8 @@ from typing import Any, Dict, List, Set
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
from llama_stack.distribution.distribution import (
|
from llama_stack.distribution.distribution import (
|
||||||
api_providers,
|
|
||||||
builtin_automatically_routed_apis,
|
builtin_automatically_routed_apis,
|
||||||
|
get_provider_registry,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.utils.dynamic import instantiate_provider
|
from llama_stack.distribution.utils.dynamic import instantiate_provider
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
|
||||||
- flatmaps, sorts and resolves the providers in dependency order
|
- flatmaps, sorts and resolves the providers in dependency order
|
||||||
- for each API, produces either a (local, passthrough or router) implementation
|
- for each API, produces either a (local, passthrough or router) implementation
|
||||||
"""
|
"""
|
||||||
all_providers = api_providers()
|
all_providers = get_provider_registry()
|
||||||
specs = {}
|
specs = {}
|
||||||
configs = {}
|
configs = {}
|
||||||
|
|
||||||
|
|
64
llama_stack/distribution/server/endpoints.py
Normal file
64
llama_stack/distribution/server/endpoints.py
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from llama_stack.apis.agents import Agents
|
||||||
|
from llama_stack.apis.inference import Inference
|
||||||
|
from llama_stack.apis.memory import Memory
|
||||||
|
from llama_stack.apis.memory_banks import MemoryBanks
|
||||||
|
from llama_stack.apis.models import Models
|
||||||
|
from llama_stack.apis.safety import Safety
|
||||||
|
from llama_stack.apis.shields import Shields
|
||||||
|
from llama_stack.apis.telemetry import Telemetry
|
||||||
|
from llama_stack.providers.datatypes import Api
|
||||||
|
|
||||||
|
|
||||||
|
class ApiEndpoint(BaseModel):
|
||||||
|
route: str
|
||||||
|
method: str
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
||||||
|
apis = {}
|
||||||
|
|
||||||
|
protocols = {
|
||||||
|
Api.inference: Inference,
|
||||||
|
Api.safety: Safety,
|
||||||
|
Api.agents: Agents,
|
||||||
|
Api.memory: Memory,
|
||||||
|
Api.telemetry: Telemetry,
|
||||||
|
Api.models: Models,
|
||||||
|
Api.shields: Shields,
|
||||||
|
Api.memory_banks: MemoryBanks,
|
||||||
|
}
|
||||||
|
|
||||||
|
for api, protocol in protocols.items():
|
||||||
|
endpoints = []
|
||||||
|
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
|
||||||
|
|
||||||
|
for name, method in protocol_methods:
|
||||||
|
if not hasattr(method, "__webmethod__"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
webmethod = method.__webmethod__
|
||||||
|
route = webmethod.route
|
||||||
|
|
||||||
|
if webmethod.method == "GET":
|
||||||
|
method = "get"
|
||||||
|
elif webmethod.method == "DELETE":
|
||||||
|
method = "delete"
|
||||||
|
else:
|
||||||
|
method = "post"
|
||||||
|
endpoints.append(ApiEndpoint(route=route, method=method, name=name))
|
||||||
|
|
||||||
|
apis[api] = endpoints
|
||||||
|
|
||||||
|
return apis
|
|
@ -39,10 +39,11 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
|
|
||||||
from llama_stack.distribution.distribution import api_endpoints
|
|
||||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||||
from llama_stack.distribution.resolver import resolve_impls_with_routing
|
from llama_stack.distribution.resolver import resolve_impls_with_routing
|
||||||
|
|
||||||
|
from .endpoints import get_all_api_endpoints
|
||||||
|
|
||||||
|
|
||||||
def is_async_iterator_type(typ):
|
def is_async_iterator_type(typ):
|
||||||
if hasattr(typ, "__origin__"):
|
if hasattr(typ, "__origin__"):
|
||||||
|
@ -299,7 +300,7 @@ def main(
|
||||||
if Api.telemetry in impls:
|
if Api.telemetry in impls:
|
||||||
setup_logger(impls[Api.telemetry])
|
setup_logger(impls[Api.telemetry])
|
||||||
|
|
||||||
all_endpoints = api_endpoints()
|
all_endpoints = get_all_api_endpoints()
|
||||||
|
|
||||||
if config.apis_to_serve:
|
if config.apis_to_serve:
|
||||||
apis_to_serve = set(config.apis_to_serve)
|
apis_to_serve = set(config.apis_to_serve)
|
||||||
|
|
|
@ -25,13 +25,6 @@ class Api(Enum):
|
||||||
memory_banks = "memory_banks"
|
memory_banks = "memory_banks"
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ApiEndpoint(BaseModel):
|
|
||||||
route: str
|
|
||||||
method: str
|
|
||||||
name: str
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ProviderSpec(BaseModel):
|
class ProviderSpec(BaseModel):
|
||||||
api: Api
|
api: Api
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue