mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
Bring agentic system api to toolchain
Add adapter dependencies and resolve adapters using a topological sort
This commit is contained in:
parent
b0e5340645
commit
be19b22391
31 changed files with 2740 additions and 25 deletions
|
@ -12,7 +12,16 @@ from collections.abc import (
|
|||
AsyncIterator as AsyncIteratorABC,
|
||||
)
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
get_type_hints,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
)
|
||||
|
||||
import fire
|
||||
import httpx
|
||||
|
@ -27,9 +36,9 @@ from fastapi.routing import APIRoute
|
|||
from pydantic import BaseModel, ValidationError
|
||||
from termcolor import cprint
|
||||
|
||||
from .datatypes import PassthroughApiAdapter
|
||||
from .datatypes import Adapter, ApiSurface, PassthroughApiAdapter
|
||||
from .distribution import api_surface_endpoints
|
||||
from .dynamic import instantiate_adapter
|
||||
from .dynamic import instantiate_adapter, instantiate_client
|
||||
|
||||
from .registry import resolve_distribution
|
||||
|
||||
|
@ -213,6 +222,29 @@ def create_dynamic_typed_route(func: Any):
|
|||
return endpoint
|
||||
|
||||
|
||||
def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
|
||||
|
||||
by_id = {x.api_surface: x for x in adapters}
|
||||
|
||||
def dfs(a: Adapter, visited: Set[ApiSurface], stack: List[ApiSurface]):
|
||||
visited.add(a.api_surface)
|
||||
|
||||
for surface in a.adapter_dependencies:
|
||||
if surface not in visited:
|
||||
dfs(by_id[surface], visited, stack)
|
||||
|
||||
stack.append(a.api_surface)
|
||||
|
||||
visited = set()
|
||||
stack = []
|
||||
|
||||
for a in adapters:
|
||||
if a.api_surface not in visited:
|
||||
dfs(a, visited, stack)
|
||||
|
||||
return [by_id[x] for x in stack]
|
||||
|
||||
|
||||
def main(
|
||||
dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False
|
||||
):
|
||||
|
@ -228,7 +260,13 @@ def main(
|
|||
all_endpoints = api_surface_endpoints()
|
||||
|
||||
adapter_configs = config["adapters"]
|
||||
for surface, adapter in dist.adapters.items():
|
||||
adapters = topological_sort(dist.adapters.values())
|
||||
|
||||
# TODO: split this into two parts, first you resolve all impls
|
||||
# and then you create the routes.
|
||||
impls = {}
|
||||
for adapter in adapters:
|
||||
surface = adapter.api_surface
|
||||
if surface.value not in adapter_configs:
|
||||
raise ValueError(
|
||||
f"Could not find adapter config for {surface}. Please add it to the config"
|
||||
|
@ -242,8 +280,11 @@ def main(
|
|||
getattr(app, endpoint.method)(endpoint.route)(
|
||||
create_dynamic_passthrough(url)
|
||||
)
|
||||
impls[surface] = instantiate_client(adapter, adapter.base_url.rstrip("/"))
|
||||
else:
|
||||
impl = instantiate_adapter(adapter, adapter_config)
|
||||
deps = {surface: impls[surface] for surface in adapter.adapter_dependencies}
|
||||
impl = instantiate_adapter(adapter, adapter_config, deps)
|
||||
impls[surface] = impl
|
||||
for endpoint in endpoints:
|
||||
if not hasattr(impl, endpoint.name):
|
||||
# ideally this should be a typing violation already
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue