refactor a method out

This commit is contained in:
Ashwin Bharambe 2024-08-05 13:14:15 -07:00
parent 125fdb1b2a
commit db3e6dda07

View file

@ -36,7 +36,7 @@ from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from termcolor import cprint from termcolor import cprint
from .datatypes import Adapter, Api, PassthroughApiAdapter from .datatypes import Adapter, Api, Distribution, PassthroughApiAdapter
from .distribution import api_endpoints from .distribution import api_endpoints
from .dynamic import instantiate_adapter, instantiate_client from .dynamic import instantiate_adapter, instantiate_client
@ -250,6 +250,29 @@ def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
return [by_id[x] for x in stack] return [by_id[x] for x in stack]
def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]:
adapter_configs = config["adapters"]
adapters = topological_sort(dist.adapters.values())
impls = {}
for adapter in adapters:
api = adapter.api
if api.value not in adapter_configs:
raise ValueError(
f"Could not find adapter config for {api}. Please add it to the config"
)
adapter_config = adapter_configs[api.value]
if isinstance(adapter, PassthroughApiAdapter):
impls[api] = instantiate_client(adapter, adapter.base_url.rstrip("/"))
else:
deps = {api: impls[api] for api in adapter.adapter_dependencies}
impl = instantiate_adapter(adapter, adapter_config, deps)
impls[api] = impl
return impls
def main( def main(
dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False
): ):
@ -263,21 +286,10 @@ def main(
app = FastAPI() app = FastAPI()
all_endpoints = api_endpoints() all_endpoints = api_endpoints()
impls = resolve_impls(dist, config)
adapter_configs = config["adapters"] for adapter in dist.adapters.values():
adapters = topological_sort(dist.adapters.values())
# TODO: split this into two parts, first you resolve all impls
# and then you create the routes.
impls = {}
for adapter in adapters:
api = adapter.api api = adapter.api
if api.value not in adapter_configs:
raise ValueError(
f"Could not find adapter config for {api}. Please add it to the config"
)
adapter_config = adapter_configs[api.value]
endpoints = all_endpoints[api] endpoints = all_endpoints[api]
if isinstance(adapter, PassthroughApiAdapter): if isinstance(adapter, PassthroughApiAdapter):
for endpoint in endpoints: for endpoint in endpoints:
@ -285,11 +297,8 @@ def main(
getattr(app, endpoint.method)(endpoint.route)( getattr(app, endpoint.method)(endpoint.route)(
create_dynamic_passthrough(url) create_dynamic_passthrough(url)
) )
impls[api] = instantiate_client(adapter, adapter.base_url.rstrip("/"))
else: else:
deps = {api: impls[api] for api in adapter.adapter_dependencies} impl = impls[api]
impl = instantiate_adapter(adapter, adapter_config, deps)
impls[api] = impl
for endpoint in endpoints: for endpoint in endpoints:
if not hasattr(impl, endpoint.name): if not hasattr(impl, endpoint.name):
# ideally this should be a typing violation already # ideally this should be a typing violation already