refactor a method out

This commit is contained in:
Ashwin Bharambe 2024-08-05 13:14:15 -07:00
parent 125fdb1b2a
commit db3e6dda07

View file

@ -36,7 +36,7 @@ from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from .datatypes import Adapter, Api, PassthroughApiAdapter
from .datatypes import Adapter, Api, Distribution, PassthroughApiAdapter
from .distribution import api_endpoints
from .dynamic import instantiate_adapter, instantiate_client
@ -250,6 +250,29 @@ def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
return [by_id[x] for x in stack]
def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]:
adapter_configs = config["adapters"]
adapters = topological_sort(dist.adapters.values())
impls = {}
for adapter in adapters:
api = adapter.api
if api.value not in adapter_configs:
raise ValueError(
f"Could not find adapter config for {api}. Please add it to the config"
)
adapter_config = adapter_configs[api.value]
if isinstance(adapter, PassthroughApiAdapter):
impls[api] = instantiate_client(adapter, adapter.base_url.rstrip("/"))
else:
deps = {api: impls[api] for api in adapter.adapter_dependencies}
impl = instantiate_adapter(adapter, adapter_config, deps)
impls[api] = impl
return impls
def main(
dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False
):
@ -263,21 +286,10 @@ def main(
app = FastAPI()
all_endpoints = api_endpoints()
impls = resolve_impls(dist, config)
adapter_configs = config["adapters"]
adapters = topological_sort(dist.adapters.values())
# TODO: split this into two parts, first you resolve all impls
# and then you create the routes.
impls = {}
for adapter in adapters:
for adapter in dist.adapters.values():
api = adapter.api
if api.value not in adapter_configs:
raise ValueError(
f"Could not find adapter config for {api}. Please add it to the config"
)
adapter_config = adapter_configs[api.value]
endpoints = all_endpoints[api]
if isinstance(adapter, PassthroughApiAdapter):
for endpoint in endpoints:
@ -285,11 +297,8 @@ def main(
getattr(app, endpoint.method)(endpoint.route)(
create_dynamic_passthrough(url)
)
impls[api] = instantiate_client(adapter, adapter.base_url.rstrip("/"))
else:
deps = {api: impls[api] for api in adapter.adapter_dependencies}
impl = instantiate_adapter(adapter, adapter_config, deps)
impls[api] = impl
impl = impls[api]
for endpoint in endpoints:
if not hasattr(impl, endpoint.name):
# ideally this should be a typing violation already