Adapter -> Provider

This commit is contained in:
Ashwin Bharambe 2024-08-05 13:26:29 -07:00
parent db3e6dda07
commit 65a9e40174
15 changed files with 119 additions and 110 deletions

View file

@ -36,9 +36,9 @@ from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from .datatypes import Adapter, Api, Distribution, PassthroughApiAdapter
from .datatypes import Api, Distribution, ProviderSpec, RemoteProviderSpec
from .distribution import api_endpoints
from .dynamic import instantiate_adapter, instantiate_client
from .dynamic import instantiate_client, instantiate_provider
from .registry import resolve_distribution
@ -226,15 +226,15 @@ def create_dynamic_typed_route(func: Any):
return endpoint
def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
by_id = {x.api: x for x in adapters}
by_id = {x.api: x for x in providers}
def dfs(a: Adapter, visited: Set[Api], stack: List[Api]):
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
visited.add(a.api)
if not isinstance(a, PassthroughApiAdapter):
for api in a.adapter_dependencies:
if not isinstance(a, RemoteProviderSpec):
for api in a.api_dependencies:
if api not in visited:
dfs(by_id[api], visited, stack)
@ -243,7 +243,7 @@ def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
visited = set()
stack = []
for a in adapters:
for a in providers:
if a.api not in visited:
dfs(a, visited, stack)
@ -251,23 +251,25 @@ def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]:
adapter_configs = config["adapters"]
adapters = topological_sort(dist.adapters.values())
provider_configs = config["providers"]
provider_specs = topological_sort(dist.provider_specs.values())
impls = {}
for adapter in adapters:
api = adapter.api
if api.value not in adapter_configs:
for provider_spec in provider_specs:
api = provider_spec.api
if api.value not in provider_configs:
raise ValueError(
f"Could not find adapter config for {api}. Please add it to the config"
f"Could not find provider_spec config for {api}. Please add it to the config"
)
adapter_config = adapter_configs[api.value]
if isinstance(adapter, PassthroughApiAdapter):
impls[api] = instantiate_client(adapter, adapter.base_url.rstrip("/"))
provider_config = provider_configs[api.value]
if isinstance(provider_spec, RemoteProviderSpec):
impls[api] = instantiate_client(
provider_spec, provider_spec.base_url.rstrip("/")
)
else:
deps = {api: impls[api] for api in adapter.adapter_dependencies}
impl = instantiate_adapter(adapter, adapter_config, deps)
deps = {api: impls[api] for api in provider_spec.api_dependencies}
impl = instantiate_provider(provider_spec, provider_config, deps)
impls[api] = impl
return impls
@ -288,12 +290,12 @@ def main(
all_endpoints = api_endpoints()
impls = resolve_impls(dist, config)
for adapter in dist.adapters.values():
api = adapter.api
for provider_spec in dist.provider_specs.values():
api = provider_spec.api
endpoints = all_endpoints[api]
if isinstance(adapter, PassthroughApiAdapter):
if isinstance(provider_spec, RemoteProviderSpec):
for endpoint in endpoints:
url = adapter.base_url.rstrip("/") + endpoint.route
url = provider_spec.base_url.rstrip("/") + endpoint.route
getattr(app, endpoint.method)(endpoint.route)(
create_dynamic_passthrough(url)
)