Bring agentic system api to toolchain

Add adapter dependencies and resolve adapters using a topological sort
This commit is contained in:
Ashwin Bharambe 2024-08-04 10:53:38 -07:00
parent b0e5340645
commit be19b22391
31 changed files with 2740 additions and 25 deletions

View file

@ -15,6 +15,7 @@ from strong_typing.schema import json_schema_type
class ApiSurface(Enum):
inference = "inference"
safety = "safety"
agentic_system = "agentic_system"
@json_schema_type
@ -39,14 +40,19 @@ class SourceAdapter(Adapter):
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have
a `get_adapter_instance()` method which will be passed a validated config object
of type `config_class`.""",
Fully-qualified name of the module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the local implementation
""",
)
config_class: str = Field(
...,
description="Fully-qualified classname of the config for this adapter",
)
adapter_dependencies: List[ApiSurface] = Field(
default_factory=list,
description="Higher-level API surfaces may depend on other adapters to provide their functionality",
)
@json_schema_type
@ -56,6 +62,13 @@ class PassthroughApiAdapter(Adapter):
default_factory=dict,
description="Headers (e.g., authorization) to send with the request",
)
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_client_impl(base_url)`: returns a client which can be used to call the remote implementation
""",
)
class Distribution(BaseModel):

View file

@ -7,6 +7,7 @@
import inspect
from typing import Dict, List
from llama_toolchain.agentic_system.api.endpoints import AgenticSystem
from llama_toolchain.inference.api.endpoints import Inference
from llama_toolchain.safety.api.endpoints import Safety
@ -29,6 +30,7 @@ def api_surface_endpoints() -> Dict[ApiSurface, List[ApiSurfaceEndpoint]]:
protocols = {
ApiSurface.inference: Inference,
ApiSurface.safety: Safety,
ApiSurface.agentic_system: AgenticSystem,
}
for surface, protocol in protocols.items():

View file

@ -8,7 +8,7 @@ import asyncio
import importlib
from typing import Any, Dict
from .datatypes import SourceAdapter
from .datatypes import Adapter, PassthroughApiAdapter, SourceAdapter
def instantiate_class_type(fully_qualified_name):
@ -18,9 +18,17 @@ def instantiate_class_type(fully_qualified_name):
# returns a class implementing the protocol corresponding to the ApiSurface
def instantiate_adapter(adapter: SourceAdapter, adapter_config: Dict[str, Any]):
def instantiate_adapter(
adapter: SourceAdapter, adapter_config: Dict[str, Any], deps: Dict[str, Adapter]
):
module = importlib.import_module(adapter.module)
config_type = instantiate_class_type(adapter.config_class)
config = config_type(**adapter_config)
return asyncio.run(module.get_adapter_impl(config))
return asyncio.run(module.get_adapter_impl(config, deps))
def instantiate_client(adapter: PassthroughApiAdapter, base_url: str):
module = importlib.import_module(adapter.module)
return asyncio.run(module.get_client_impl(base_url))

View file

@ -7,6 +7,8 @@
from functools import lru_cache
from typing import List, Optional
from llama_toolchain.agentic_system.adapters import available_agentic_system_adapters
from llama_toolchain.inference.adapters import available_inference_adapters
from llama_toolchain.safety.adapters import available_safety_adapters
@ -43,10 +45,26 @@ COMMON_DEPENDENCIES = [
]
def client_module(api_surface: ApiSurface) -> str:
return f"llama_toolchain.{api_surface.value}.client"
def passthrough(api_surface: ApiSurface, port: int) -> PassthroughApiAdapter:
return PassthroughApiAdapter(
api_surface=api_surface,
adapter_id=f"{api_surface.value}-passthrough",
base_url=f"http://localhost:{port}",
module=client_module(api_surface),
)
@lru_cache()
def available_distributions() -> List[Distribution]:
inference_adapters_by_id = {a.adapter_id: a for a in available_inference_adapters()}
safety_adapters_by_id = {a.adapter_id: a for a in available_safety_adapters()}
agentic_system_adapters_by_id = {
a.adapter_id: a for a in available_agentic_system_adapters()
}
return [
Distribution(
@ -56,6 +74,9 @@ def available_distributions() -> List[Distribution]:
adapters={
ApiSurface.inference: inference_adapters_by_id["meta-reference"],
ApiSurface.safety: safety_adapters_by_id["meta-reference"],
ApiSurface.agentic_system: agentic_system_adapters_by_id[
"meta-reference"
],
},
),
Distribution(
@ -76,16 +97,9 @@ def available_distributions() -> List[Distribution]:
"uvicorn",
],
adapters={
ApiSurface.inference: PassthroughApiAdapter(
api_surface=ApiSurface.inference,
adapter_id="inference-passthrough",
base_url="http://localhost:5001",
),
ApiSurface.safety: PassthroughApiAdapter(
api_surface=ApiSurface.safety,
adapter_id="safety-passthrough",
base_url="http://localhost:5001",
),
ApiSurface.inference: passthrough(ApiSurface.inference, 5001),
ApiSurface.safety: passthrough(ApiSurface.safety, 5001),
ApiSurface.agentic_system: passthrough(ApiSurface.agentic_system, 5001),
},
),
Distribution(
@ -95,6 +109,9 @@ def available_distributions() -> List[Distribution]:
adapters={
ApiSurface.inference: inference_adapters_by_id["meta-ollama"],
ApiSurface.safety: safety_adapters_by_id["meta-reference"],
ApiSurface.agentic_system: agentic_system_adapters_by_id[
"meta-reference"
],
},
),
]

View file

@ -12,7 +12,16 @@ from collections.abc import (
AsyncIterator as AsyncIteratorABC,
)
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Dict,
get_type_hints,
List,
Optional,
Set,
)
import fire
import httpx
@ -27,9 +36,9 @@ from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from .datatypes import PassthroughApiAdapter
from .datatypes import Adapter, ApiSurface, PassthroughApiAdapter
from .distribution import api_surface_endpoints
from .dynamic import instantiate_adapter
from .dynamic import instantiate_adapter, instantiate_client
from .registry import resolve_distribution
@ -213,6 +222,29 @@ def create_dynamic_typed_route(func: Any):
return endpoint
def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
by_id = {x.api_surface: x for x in adapters}
def dfs(a: Adapter, visited: Set[ApiSurface], stack: List[ApiSurface]):
visited.add(a.api_surface)
for surface in a.adapter_dependencies:
if surface not in visited:
dfs(by_id[surface], visited, stack)
stack.append(a.api_surface)
visited = set()
stack = []
for a in adapters:
if a.api_surface not in visited:
dfs(a, visited, stack)
return [by_id[x] for x in stack]
def main(
dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False
):
@ -228,7 +260,13 @@ def main(
all_endpoints = api_surface_endpoints()
adapter_configs = config["adapters"]
for surface, adapter in dist.adapters.items():
adapters = topological_sort(dist.adapters.values())
# TODO: split this into two parts, first you resolve all impls
# and then you create the routes.
impls = {}
for adapter in adapters:
surface = adapter.api_surface
if surface.value not in adapter_configs:
raise ValueError(
f"Could not find adapter config for {surface}. Please add it to the config"
@ -242,8 +280,11 @@ def main(
getattr(app, endpoint.method)(endpoint.route)(
create_dynamic_passthrough(url)
)
impls[surface] = instantiate_client(adapter, adapter.base_url.rstrip("/"))
else:
impl = instantiate_adapter(adapter, adapter_config)
deps = {surface: impls[surface] for surface in adapter.adapter_dependencies}
impl = instantiate_adapter(adapter, adapter_config, deps)
impls[surface] = impl
for endpoint in endpoints:
if not hasattr(impl, endpoint.name):
# ideally this should be a typing violation already