Bring agentic system api to toolchain

Add adapter dependencies and resolve adapters using a topological sort
This commit is contained in:
Ashwin Bharambe 2024-08-04 10:53:38 -07:00
parent b0e5340645
commit be19b22391
31 changed files with 2740 additions and 25 deletions

View file

@ -12,7 +12,16 @@ from collections.abc import (
AsyncIterator as AsyncIteratorABC,
)
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Dict,
get_type_hints,
List,
Optional,
Set,
)
import fire
import httpx
@ -27,9 +36,9 @@ from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from .datatypes import PassthroughApiAdapter
from .datatypes import Adapter, ApiSurface, PassthroughApiAdapter
from .distribution import api_surface_endpoints
from .dynamic import instantiate_adapter
from .dynamic import instantiate_adapter, instantiate_client
from .registry import resolve_distribution
@ -213,6 +222,29 @@ def create_dynamic_typed_route(func: Any):
return endpoint
def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
by_id = {x.api_surface: x for x in adapters}
def dfs(a: Adapter, visited: Set[ApiSurface], stack: List[ApiSurface]):
visited.add(a.api_surface)
for surface in a.adapter_dependencies:
if surface not in visited:
dfs(by_id[surface], visited, stack)
stack.append(a.api_surface)
visited = set()
stack = []
for a in adapters:
if a.api_surface not in visited:
dfs(a, visited, stack)
return [by_id[x] for x in stack]
def main(
dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False
):
@ -228,7 +260,13 @@ def main(
all_endpoints = api_surface_endpoints()
adapter_configs = config["adapters"]
for surface, adapter in dist.adapters.items():
adapters = topological_sort(dist.adapters.values())
# TODO: split this into two parts, first you resolve all impls
# and then you create the routes.
impls = {}
for adapter in adapters:
surface = adapter.api_surface
if surface.value not in adapter_configs:
raise ValueError(
f"Could not find adapter config for {surface}. Please add it to the config"
@ -242,8 +280,11 @@ def main(
getattr(app, endpoint.method)(endpoint.route)(
create_dynamic_passthrough(url)
)
impls[surface] = instantiate_client(adapter, adapter.base_url.rstrip("/"))
else:
impl = instantiate_adapter(adapter, adapter_config)
deps = {surface: impls[surface] for surface in adapter.adapter_dependencies}
impl = instantiate_adapter(adapter, adapter_config, deps)
impls[surface] = impl
for endpoint in endpoints:
if not hasattr(impl, endpoint.name):
# ideally this should be a typing violation already