mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
Bring agentic system api to toolchain
Add adapter dependencies and resolve adapters using a topological sort
This commit is contained in:
parent
b0e5340645
commit
be19b22391
31 changed files with 2740 additions and 25 deletions
|
@ -15,6 +15,7 @@ from strong_typing.schema import json_schema_type
|
|||
class ApiSurface(Enum):
|
||||
inference = "inference"
|
||||
safety = "safety"
|
||||
agentic_system = "agentic_system"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -39,14 +40,19 @@ class SourceAdapter(Adapter):
|
|||
module: str = Field(
|
||||
...,
|
||||
description="""
|
||||
Fully-qualified name of the module to import. The module is expected to have
|
||||
a `get_adapter_instance()` method which will be passed a validated config object
|
||||
of type `config_class`.""",
|
||||
Fully-qualified name of the module to import. The module is expected to have:
|
||||
|
||||
- `get_adapter_impl(config, deps)`: returns the local implementation
|
||||
""",
|
||||
)
|
||||
config_class: str = Field(
|
||||
...,
|
||||
description="Fully-qualified classname of the config for this adapter",
|
||||
)
|
||||
adapter_dependencies: List[ApiSurface] = Field(
|
||||
default_factory=list,
|
||||
description="Higher-level API surfaces may depend on other adapters to provide their functionality",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -56,6 +62,13 @@ class PassthroughApiAdapter(Adapter):
|
|||
default_factory=dict,
|
||||
description="Headers (e.g., authorization) to send with the request",
|
||||
)
|
||||
module: str = Field(
|
||||
...,
|
||||
description="""
|
||||
Fully-qualified name of the module to import. The module is expected to have:
|
||||
- `get_client_impl(base_url)`: returns a client which can be used to call the remote implementation
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
class Distribution(BaseModel):
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
import inspect
|
||||
from typing import Dict, List
|
||||
|
||||
from llama_toolchain.agentic_system.api.endpoints import AgenticSystem
|
||||
from llama_toolchain.inference.api.endpoints import Inference
|
||||
from llama_toolchain.safety.api.endpoints import Safety
|
||||
|
||||
|
@ -29,6 +30,7 @@ def api_surface_endpoints() -> Dict[ApiSurface, List[ApiSurfaceEndpoint]]:
|
|||
protocols = {
|
||||
ApiSurface.inference: Inference,
|
||||
ApiSurface.safety: Safety,
|
||||
ApiSurface.agentic_system: AgenticSystem,
|
||||
}
|
||||
|
||||
for surface, protocol in protocols.items():
|
||||
|
|
|
@ -8,7 +8,7 @@ import asyncio
|
|||
import importlib
|
||||
from typing import Any, Dict
|
||||
|
||||
from .datatypes import SourceAdapter
|
||||
from .datatypes import Adapter, PassthroughApiAdapter, SourceAdapter
|
||||
|
||||
|
||||
def instantiate_class_type(fully_qualified_name):
|
||||
|
@ -18,9 +18,17 @@ def instantiate_class_type(fully_qualified_name):
|
|||
|
||||
|
||||
# returns a class implementing the protocol corresponding to the ApiSurface
|
||||
def instantiate_adapter(adapter: SourceAdapter, adapter_config: Dict[str, Any]):
|
||||
def instantiate_adapter(
|
||||
adapter: SourceAdapter, adapter_config: Dict[str, Any], deps: Dict[str, Adapter]
|
||||
):
|
||||
module = importlib.import_module(adapter.module)
|
||||
|
||||
config_type = instantiate_class_type(adapter.config_class)
|
||||
config = config_type(**adapter_config)
|
||||
return asyncio.run(module.get_adapter_impl(config))
|
||||
return asyncio.run(module.get_adapter_impl(config, deps))
|
||||
|
||||
|
||||
def instantiate_client(adapter: PassthroughApiAdapter, base_url: str):
|
||||
module = importlib.import_module(adapter.module)
|
||||
|
||||
return asyncio.run(module.get_client_impl(base_url))
|
||||
|
|
|
@ -7,6 +7,8 @@
|
|||
from functools import lru_cache
|
||||
from typing import List, Optional
|
||||
|
||||
from llama_toolchain.agentic_system.adapters import available_agentic_system_adapters
|
||||
|
||||
from llama_toolchain.inference.adapters import available_inference_adapters
|
||||
from llama_toolchain.safety.adapters import available_safety_adapters
|
||||
|
||||
|
@ -43,10 +45,26 @@ COMMON_DEPENDENCIES = [
|
|||
]
|
||||
|
||||
|
||||
def client_module(api_surface: ApiSurface) -> str:
|
||||
return f"llama_toolchain.{api_surface.value}.client"
|
||||
|
||||
|
||||
def passthrough(api_surface: ApiSurface, port: int) -> PassthroughApiAdapter:
|
||||
return PassthroughApiAdapter(
|
||||
api_surface=api_surface,
|
||||
adapter_id=f"{api_surface.value}-passthrough",
|
||||
base_url=f"http://localhost:{port}",
|
||||
module=client_module(api_surface),
|
||||
)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def available_distributions() -> List[Distribution]:
|
||||
inference_adapters_by_id = {a.adapter_id: a for a in available_inference_adapters()}
|
||||
safety_adapters_by_id = {a.adapter_id: a for a in available_safety_adapters()}
|
||||
agentic_system_adapters_by_id = {
|
||||
a.adapter_id: a for a in available_agentic_system_adapters()
|
||||
}
|
||||
|
||||
return [
|
||||
Distribution(
|
||||
|
@ -56,6 +74,9 @@ def available_distributions() -> List[Distribution]:
|
|||
adapters={
|
||||
ApiSurface.inference: inference_adapters_by_id["meta-reference"],
|
||||
ApiSurface.safety: safety_adapters_by_id["meta-reference"],
|
||||
ApiSurface.agentic_system: agentic_system_adapters_by_id[
|
||||
"meta-reference"
|
||||
],
|
||||
},
|
||||
),
|
||||
Distribution(
|
||||
|
@ -76,16 +97,9 @@ def available_distributions() -> List[Distribution]:
|
|||
"uvicorn",
|
||||
],
|
||||
adapters={
|
||||
ApiSurface.inference: PassthroughApiAdapter(
|
||||
api_surface=ApiSurface.inference,
|
||||
adapter_id="inference-passthrough",
|
||||
base_url="http://localhost:5001",
|
||||
),
|
||||
ApiSurface.safety: PassthroughApiAdapter(
|
||||
api_surface=ApiSurface.safety,
|
||||
adapter_id="safety-passthrough",
|
||||
base_url="http://localhost:5001",
|
||||
),
|
||||
ApiSurface.inference: passthrough(ApiSurface.inference, 5001),
|
||||
ApiSurface.safety: passthrough(ApiSurface.safety, 5001),
|
||||
ApiSurface.agentic_system: passthrough(ApiSurface.agentic_system, 5001),
|
||||
},
|
||||
),
|
||||
Distribution(
|
||||
|
@ -95,6 +109,9 @@ def available_distributions() -> List[Distribution]:
|
|||
adapters={
|
||||
ApiSurface.inference: inference_adapters_by_id["meta-ollama"],
|
||||
ApiSurface.safety: safety_adapters_by_id["meta-reference"],
|
||||
ApiSurface.agentic_system: agentic_system_adapters_by_id[
|
||||
"meta-reference"
|
||||
],
|
||||
},
|
||||
),
|
||||
]
|
||||
|
|
|
@ -12,7 +12,16 @@ from collections.abc import (
|
|||
AsyncIterator as AsyncIteratorABC,
|
||||
)
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
get_type_hints,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
)
|
||||
|
||||
import fire
|
||||
import httpx
|
||||
|
@ -27,9 +36,9 @@ from fastapi.routing import APIRoute
|
|||
from pydantic import BaseModel, ValidationError
|
||||
from termcolor import cprint
|
||||
|
||||
from .datatypes import PassthroughApiAdapter
|
||||
from .datatypes import Adapter, ApiSurface, PassthroughApiAdapter
|
||||
from .distribution import api_surface_endpoints
|
||||
from .dynamic import instantiate_adapter
|
||||
from .dynamic import instantiate_adapter, instantiate_client
|
||||
|
||||
from .registry import resolve_distribution
|
||||
|
||||
|
@ -213,6 +222,29 @@ def create_dynamic_typed_route(func: Any):
|
|||
return endpoint
|
||||
|
||||
|
||||
def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
|
||||
|
||||
by_id = {x.api_surface: x for x in adapters}
|
||||
|
||||
def dfs(a: Adapter, visited: Set[ApiSurface], stack: List[ApiSurface]):
|
||||
visited.add(a.api_surface)
|
||||
|
||||
for surface in a.adapter_dependencies:
|
||||
if surface not in visited:
|
||||
dfs(by_id[surface], visited, stack)
|
||||
|
||||
stack.append(a.api_surface)
|
||||
|
||||
visited = set()
|
||||
stack = []
|
||||
|
||||
for a in adapters:
|
||||
if a.api_surface not in visited:
|
||||
dfs(a, visited, stack)
|
||||
|
||||
return [by_id[x] for x in stack]
|
||||
|
||||
|
||||
def main(
|
||||
dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False
|
||||
):
|
||||
|
@ -228,7 +260,13 @@ def main(
|
|||
all_endpoints = api_surface_endpoints()
|
||||
|
||||
adapter_configs = config["adapters"]
|
||||
for surface, adapter in dist.adapters.items():
|
||||
adapters = topological_sort(dist.adapters.values())
|
||||
|
||||
# TODO: split this into two parts, first you resolve all impls
|
||||
# and then you create the routes.
|
||||
impls = {}
|
||||
for adapter in adapters:
|
||||
surface = adapter.api_surface
|
||||
if surface.value not in adapter_configs:
|
||||
raise ValueError(
|
||||
f"Could not find adapter config for {surface}. Please add it to the config"
|
||||
|
@ -242,8 +280,11 @@ def main(
|
|||
getattr(app, endpoint.method)(endpoint.route)(
|
||||
create_dynamic_passthrough(url)
|
||||
)
|
||||
impls[surface] = instantiate_client(adapter, adapter.base_url.rstrip("/"))
|
||||
else:
|
||||
impl = instantiate_adapter(adapter, adapter_config)
|
||||
deps = {surface: impls[surface] for surface in adapter.adapter_dependencies}
|
||||
impl = instantiate_adapter(adapter, adapter_config, deps)
|
||||
impls[surface] = impl
|
||||
for endpoint in endpoints:
|
||||
if not hasattr(impl, endpoint.name):
|
||||
# ideally this should be a typing violation already
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue