diff --git a/llama_toolchain/distribution/server.py b/llama_toolchain/distribution/server.py index fd49b7d70..857639551 100644 --- a/llama_toolchain/distribution/server.py +++ b/llama_toolchain/distribution/server.py @@ -36,7 +36,7 @@ from fastapi.routing import APIRoute from pydantic import BaseModel, ValidationError from termcolor import cprint -from .datatypes import Adapter, Api, PassthroughApiAdapter +from .datatypes import Adapter, Api, Distribution, PassthroughApiAdapter from .distribution import api_endpoints from .dynamic import instantiate_adapter, instantiate_client @@ -250,6 +250,29 @@ def topological_sort(adapters: List[Adapter]) -> List[Adapter]: return [by_id[x] for x in stack] +def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]: + adapter_configs = config["adapters"] + adapters = topological_sort(dist.adapters.values()) + + impls = {} + for adapter in adapters: + api = adapter.api + if api.value not in adapter_configs: + raise ValueError( + f"Could not find adapter config for {api}. Please add it to the config" + ) + + adapter_config = adapter_configs[api.value] + if isinstance(adapter, PassthroughApiAdapter): + impls[api] = instantiate_client(adapter, adapter.base_url.rstrip("/")) + else: + deps = {api: impls[api] for api in adapter.adapter_dependencies} + impl = instantiate_adapter(adapter, adapter_config, deps) + impls[api] = impl + + return impls + + def main( dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False ): @@ -263,21 +286,10 @@ def main( app = FastAPI() all_endpoints = api_endpoints() + impls = resolve_impls(dist, config) - adapter_configs = config["adapters"] - adapters = topological_sort(dist.adapters.values()) - - # TODO: split this into two parts, first you resolve all impls - # and then you create the routes. - impls = {} - for adapter in adapters: + for adapter in dist.adapters.values(): api = adapter.api - if api.value not in adapter_configs: - raise ValueError( - f"Could not find adapter config for {api}. Please add it to the config" - ) - - adapter_config = adapter_configs[api.value] endpoints = all_endpoints[api] if isinstance(adapter, PassthroughApiAdapter): for endpoint in endpoints: @@ -285,11 +297,8 @@ def main( getattr(app, endpoint.method)(endpoint.route)( create_dynamic_passthrough(url) ) - impls[api] = instantiate_client(adapter, adapter.base_url.rstrip("/")) else: - deps = {api: impls[api] for api in adapter.adapter_dependencies} - impl = instantiate_adapter(adapter, adapter_config, deps) - impls[api] = impl + impl = impls[api] for endpoint in endpoints: if not hasattr(impl, endpoint.name): # ideally this should be a typing violation already