ApiSurface -> Api

This commit is contained in:
Ashwin Bharambe 2024-08-05 12:44:56 -07:00
parent 7890921e5c
commit 125fdb1b2a
13 changed files with 76 additions and 87 deletions

View file

@ -12,14 +12,14 @@ from strong_typing.schema import json_schema_type
@json_schema_type
class ApiSurface(Enum):
class Api(Enum):
inference = "inference"
safety = "safety"
agentic_system = "agentic_system"
@json_schema_type
class ApiSurfaceEndpoint(BaseModel):
class ApiEndpoint(BaseModel):
route: str
method: str
name: str
@ -27,7 +27,7 @@ class ApiSurfaceEndpoint(BaseModel):
@json_schema_type
class Adapter(BaseModel):
api_surface: ApiSurface
api: Api
adapter_id: str
@ -49,7 +49,7 @@ Fully-qualified name of the module to import. The module is expected to have:
...,
description="Fully-qualified classname of the config for this adapter",
)
adapter_dependencies: List[ApiSurface] = Field(
adapter_dependencies: List[Api] = Field(
default_factory=list,
description="Higher-level API surfaces may depend on other adapters to provide their functionality",
)
@ -75,7 +75,7 @@ class Distribution(BaseModel):
name: str
description: str
adapters: Dict[ApiSurface, Adapter] = Field(
adapters: Dict[Api, Adapter] = Field(
default_factory=dict,
description="The API surfaces provided by this distribution",
)

View file

@ -11,7 +11,7 @@ from llama_toolchain.agentic_system.api.endpoints import AgenticSystem
from llama_toolchain.inference.api.endpoints import Inference
from llama_toolchain.safety.api.endpoints import Safety
from .datatypes import ApiSurface, ApiSurfaceEndpoint, Distribution, SourceAdapter
from .datatypes import Api, ApiEndpoint, Distribution, SourceAdapter
def distribution_dependencies(distribution: Distribution) -> List[str]:
@ -24,16 +24,16 @@ def distribution_dependencies(distribution: Distribution) -> List[str]:
] + distribution.additional_pip_packages
def api_surface_endpoints() -> Dict[ApiSurface, List[ApiSurfaceEndpoint]]:
surfaces = {}
def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
apis = {}
protocols = {
ApiSurface.inference: Inference,
ApiSurface.safety: Safety,
ApiSurface.agentic_system: AgenticSystem,
Api.inference: Inference,
Api.safety: Safety,
Api.agentic_system: AgenticSystem,
}
for surface, protocol in protocols.items():
for api, protocol in protocols.items():
endpoints = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
@ -46,8 +46,8 @@ def api_surface_endpoints() -> Dict[ApiSurface, List[ApiSurfaceEndpoint]]:
# use `post` for all methods right now until we fix up the `webmethod` openapi
# annotation and write our own openapi generator
endpoints.append(ApiSurfaceEndpoint(route=route, method="post", name=name))
endpoints.append(ApiEndpoint(route=route, method="post", name=name))
surfaces[surface] = endpoints
apis[api] = endpoints
return surfaces
return apis

View file

@ -17,7 +17,7 @@ def instantiate_class_type(fully_qualified_name):
return getattr(module, class_name)
# returns a class implementing the protocol corresponding to the ApiSurface
# returns a class implementing the protocol corresponding to the Api
def instantiate_adapter(
adapter: SourceAdapter, adapter_config: Dict[str, Any], deps: Dict[str, Adapter]
):

View file

@ -12,7 +12,7 @@ from llama_toolchain.agentic_system.adapters import available_agentic_system_ada
from llama_toolchain.inference.adapters import available_inference_adapters
from llama_toolchain.safety.adapters import available_safety_adapters
from .datatypes import ApiSurface, Distribution, PassthroughApiAdapter
from .datatypes import Api, Distribution, PassthroughApiAdapter
# This is currently duplicated from `requirements.txt` with a few minor changes
# dev-dependencies like "ufmt" etc. are nuked. A few specialized dependencies
@ -45,16 +45,16 @@ COMMON_DEPENDENCIES = [
]
def client_module(api_surface: ApiSurface) -> str:
return f"llama_toolchain.{api_surface.value}.client"
def client_module(api: Api) -> str:
return f"llama_toolchain.{api.value}.client"
def passthrough(api_surface: ApiSurface, port: int) -> PassthroughApiAdapter:
def passthrough(api: Api, port: int) -> PassthroughApiAdapter:
return PassthroughApiAdapter(
api_surface=api_surface,
adapter_id=f"{api_surface.value}-passthrough",
api=api,
adapter_id=f"{api.value}-passthrough",
base_url=f"http://localhost:{port}",
module=client_module(api_surface),
module=client_module(api),
)
@ -72,11 +72,9 @@ def available_distributions() -> List[Distribution]:
description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
additional_pip_packages=COMMON_DEPENDENCIES,
adapters={
ApiSurface.inference: inference_adapters_by_id["meta-reference"],
ApiSurface.safety: safety_adapters_by_id["meta-reference"],
ApiSurface.agentic_system: agentic_system_adapters_by_id[
"meta-reference"
],
Api.inference: inference_adapters_by_id["meta-reference"],
Api.safety: safety_adapters_by_id["meta-reference"],
Api.agentic_system: agentic_system_adapters_by_id["meta-reference"],
},
),
Distribution(
@ -97,9 +95,9 @@ def available_distributions() -> List[Distribution]:
"uvicorn",
],
adapters={
ApiSurface.inference: passthrough(ApiSurface.inference, 5001),
ApiSurface.safety: passthrough(ApiSurface.safety, 5001),
ApiSurface.agentic_system: passthrough(ApiSurface.agentic_system, 5001),
Api.inference: passthrough(Api.inference, 5001),
Api.safety: passthrough(Api.safety, 5001),
Api.agentic_system: passthrough(Api.agentic_system, 5001),
},
),
Distribution(
@ -107,11 +105,9 @@ def available_distributions() -> List[Distribution]:
description="Like local-source, but use ollama for running LLM inference",
additional_pip_packages=COMMON_DEPENDENCIES,
adapters={
ApiSurface.inference: inference_adapters_by_id["meta-ollama"],
ApiSurface.safety: safety_adapters_by_id["meta-reference"],
ApiSurface.agentic_system: agentic_system_adapters_by_id[
"meta-reference"
],
Api.inference: inference_adapters_by_id["meta-ollama"],
Api.safety: safety_adapters_by_id["meta-reference"],
Api.agentic_system: agentic_system_adapters_by_id["meta-reference"],
},
),
]

View file

@ -36,8 +36,8 @@ from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from .datatypes import Adapter, ApiSurface, PassthroughApiAdapter
from .distribution import api_surface_endpoints
from .datatypes import Adapter, Api, PassthroughApiAdapter
from .distribution import api_endpoints
from .dynamic import instantiate_adapter, instantiate_client
from .registry import resolve_distribution
@ -173,7 +173,7 @@ def create_dynamic_typed_route(func: Any):
request_model = next(iter(hints.values()))
response_model = hints["return"]
# NOTE: I think it is better to just add a method within each ApiSurface
# NOTE: I think it is better to just add a method within each Api
# "Protocol" / adapter-impl to tell what sort of a response this request
# is going to produce. /chat_completion can produce a streaming or
# non-streaming response depending on if request.stream is True / False.
@ -228,23 +228,23 @@ def create_dynamic_typed_route(func: Any):
def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
by_id = {x.api_surface: x for x in adapters}
by_id = {x.api: x for x in adapters}
def dfs(a: Adapter, visited: Set[ApiSurface], stack: List[ApiSurface]):
visited.add(a.api_surface)
def dfs(a: Adapter, visited: Set[Api], stack: List[Api]):
visited.add(a.api)
if not isinstance(a, PassthroughApiAdapter):
for surface in a.adapter_dependencies:
if surface not in visited:
dfs(by_id[surface], visited, stack)
for api in a.adapter_dependencies:
if api not in visited:
dfs(by_id[api], visited, stack)
stack.append(a.api_surface)
stack.append(a.api)
visited = set()
stack = []
for a in adapters:
if a.api_surface not in visited:
if a.api not in visited:
dfs(a, visited, stack)
return [by_id[x] for x in stack]
@ -262,7 +262,7 @@ def main(
app = FastAPI()
all_endpoints = api_surface_endpoints()
all_endpoints = api_endpoints()
adapter_configs = config["adapters"]
adapters = topological_sort(dist.adapters.values())
@ -271,25 +271,25 @@ def main(
# and then you create the routes.
impls = {}
for adapter in adapters:
surface = adapter.api_surface
if surface.value not in adapter_configs:
api = adapter.api
if api.value not in adapter_configs:
raise ValueError(
f"Could not find adapter config for {surface}. Please add it to the config"
f"Could not find adapter config for {api}. Please add it to the config"
)
adapter_config = adapter_configs[surface.value]
endpoints = all_endpoints[surface]
adapter_config = adapter_configs[api.value]
endpoints = all_endpoints[api]
if isinstance(adapter, PassthroughApiAdapter):
for endpoint in endpoints:
url = adapter.base_url.rstrip("/") + endpoint.route
getattr(app, endpoint.method)(endpoint.route)(
create_dynamic_passthrough(url)
)
impls[surface] = instantiate_client(adapter, adapter.base_url.rstrip("/"))
impls[api] = instantiate_client(adapter, adapter.base_url.rstrip("/"))
else:
deps = {surface: impls[surface] for surface in adapter.adapter_dependencies}
deps = {api: impls[api] for api in adapter.adapter_dependencies}
impl = instantiate_adapter(adapter, adapter_config, deps)
impls[surface] = impl
impls[api] = impl
for endpoint in endpoints:
if not hasattr(impl, endpoint.name):
# ideally this should be a typing violation already