mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-24 16:57:21 +00:00
Some checks failed
Coverage Badge / unit-tests (push) Failing after 1s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Installer CI / lint (push) Failing after 3s
Integration Tests / discover-tests (push) Successful in 3s
Installer CI / smoke-test-on-dev (push) Failing after 5s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 5s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 5s
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 5s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 6s
Python Package Build Test / build (3.12) (push) Failing after 3s
Python Package Build Test / build (3.13) (push) Failing after 2s
Test Llama Stack Build / generate-matrix (push) Successful in 3s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 10s
Test Llama Stack Build / build-single-provider (push) Failing after 5s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 5s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 13s
Unit Tests / unit-tests (3.13) (push) Failing after 6s
Test External API and Providers / test-external (venv) (push) Failing after 5s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 6s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 5s
Unit Tests / unit-tests (3.12) (push) Failing after 7s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 6s
Update ReadTheDocs / update-readthedocs (push) Failing after 8s
Integration Tests / test-matrix (push) Failing after 8s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 5s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 5s
Test Llama Stack Build / build (push) Failing after 6s
Pre-commit / pre-commit (push) Successful in 57s
# What does this PR do? Prototype of a new feature that allows new APIs to be plugged into Llama Stack. Opened for early feedback on the approach and to test appetite for the functionality. @ashwinb @raghotham open for early feedback, thanks! --------- Signed-off-by: Sébastien Han <seb@redhat.com> Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
141 lines
5.1 KiB
Python
141 lines
5.1 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import inspect
|
|
import re
|
|
from collections.abc import Callable
|
|
from typing import Any
|
|
|
|
from aiohttp import hdrs
|
|
from starlette.routing import Route
|
|
|
|
from llama_stack.apis.datatypes import Api, ExternalApiSpec
|
|
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
|
|
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
|
|
from llama_stack.distribution.resolver import api_protocol_map
|
|
from llama_stack.schema_utils import WebMethod
|
|
|
|
# Type aliases for the route tables built and consumed in this module.

# Values captured from "{param}" placeholders, keyed by placeholder name.
PathParams = dict[str, str]
# Any callable used as an endpoint implementation.
EndpointFunc = Callable[..., Any]
# (implementation, route path template, webmethod metadata)
RouteInfo = tuple[EndpointFunc, str, WebMethod]
# Anchored path regex -> RouteInfo, for a single HTTP method.
PathImpl = dict[str, RouteInfo]
# Lowercased HTTP method -> PathImpl.
RouteImpls = dict[str, PathImpl]
# (implementation, extracted path params, route path template, webmethod metadata)
RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod]
|
|
|
|
|
|
def toolgroup_protocol_map():
    """Return a fresh mapping from each special tool group to the protocol it implements."""
    mapping = {SpecialToolGroup.rag_tool: RAGToolRuntime}
    return mapping
|
|
|
|
|
|
def get_all_api_routes(
    external_apis: dict[Api, ExternalApiSpec] | None = None,
) -> dict[Api, list[tuple[Route, WebMethod]]]:
    """Collect the HTTP routes declared by every API protocol.

    Every protocol function that carries a ``__webmethod__`` attribute (attached by
    the ``@webmethod`` decorator) becomes one route. Its path is prefixed with
    ``/{LLAMA_STACK_API_VERSION}``. For ``Api.tool_runtime``, the functions of each
    special tool-group sub-protocol are also included, named ``"<group value>.<method>"``.

    Args:
        external_apis: Optional external API specs, forwarded to ``api_protocol_map``.

    Returns:
        A mapping from each API to its list of ``(Route, WebMethod)`` pairs. The
        ``Route`` objects have ``endpoint=None``; they only describe path, HTTP
        method and name.
    """
    protocols = api_protocol_map(external_apis)
    toolgroup_protocols = toolgroup_protocol_map()
    collected = {}

    for api, protocol in protocols.items():
        candidates = inspect.getmembers(protocol, predicate=inspect.isfunction)

        # HACK ALERT: tool_runtime also exposes the special tool-group protocols,
        # registered under dotted "<group>.<method>" names.
        if api == Api.tool_runtime:
            for tool_group in SpecialToolGroup:
                sub_protocol = toolgroup_protocols[tool_group]
                candidates.extend(
                    (f"{tool_group.value}.{member_name}", member)
                    for member_name, member in inspect.getmembers(sub_protocol, predicate=inspect.isfunction)
                    if hasattr(member, "__webmethod__")
                )

        api_routes = []
        for member_name, member in candidates:
            if not hasattr(member, "__webmethod__"):
                continue

            # __webmethod__ is attached dynamically by @webmethod, so mypy cannot see it.
            webmethod = member.__webmethod__  # type: ignore[attr-defined]
            route_path = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
            # GET and DELETE are kept as declared; any other method is served as POST.
            http_method = next(
                (verb for verb in (hdrs.METH_GET, hdrs.METH_DELETE) if webmethod.method == verb),
                hdrs.METH_POST,
            )
            # endpoint=None: these Route objects are descriptors only and are never mounted on a Router.
            descriptor = Route(path=route_path, methods=[http_method], name=member_name, endpoint=None)
            api_routes.append((descriptor, webmethod))

        collected[api] = api_routes

    return collected
|
|
|
|
|
|
def initialize_route_impls(impls, external_apis: dict[Api, ExternalApiSpec] | None = None) -> RouteImpls:
    """Build the dispatch table mapping HTTP method and path regex to an implementation.

    Args:
        impls: Mapping from ``Api`` to its implementation object. APIs missing from
            this mapping are skipped. Each route's handler is fetched with
            ``getattr(impl, route.name)``.
        external_apis: Optional external API specs, forwarded to ``get_all_api_routes``.

    Returns:
        ``{lowercased_http_method: {anchored_path_regex: (func, route_path, webmethod)}}``.
        Each ``{param}`` placeholder in a route path becomes a named group matching
        one path segment. Each ``{param:path}`` placeholder becomes a named group
        that may also contain ``/``.
    """
    api_to_routes = get_all_api_routes(external_apis)
    route_impls: RouteImpls = {}

    def _convert_path_to_regex(path: str) -> str:
        # Capture ":path" as its own group. Checking whether the whole match
        # "{param:path}" ends with ":path" never succeeds, because it ends with "}".
        def _placeholder(match: re.Match) -> str:
            name, path_suffix = match.group(1), match.group(2)
            value_pattern = ".+" if path_suffix else "[^/]+"
            return f"(?P<{name}>{value_pattern})"

        pattern = re.sub(r"\{(\w+)(:path)?\}", _placeholder, path)
        return f"^{pattern}$"

    for api, api_routes in api_to_routes.items():
        if api not in impls:
            continue
        impl = impls[api]
        for route, webmethod in api_routes:
            func = getattr(impl, route.name)
            # Starlette adds HEAD alongside GET; dispatch on the declared method only.
            available_methods = [m for m in route.methods if m != "HEAD"]
            if not available_methods:
                continue  # Skip if only HEAD method is available
            method = available_methods[0].lower()
            route_impls.setdefault(method, {})[_convert_path_to_regex(route.path)] = (
                func,
                route.path,
                webmethod,
            )

    return route_impls
|
|
|
|
|
|
def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> RouteMatch:
    """Resolve an HTTP method and URL path to the endpoint implementation serving it.

    The method is lowercased before lookup. The path regexes registered for that
    method are tried in insertion order, and the first one that matches wins.

    Args:
        method: HTTP method (GET, POST, etc.)
        path: URL path to match against
        route_impls: Table of ``{lowercased_method: {path_regex: (func, route_path, webmethod)}}``

    Returns:
        A tuple of (endpoint_function, path_params, route_path, webmethod_metadata),
        where path_params holds the regex's named groups.

    Raises:
        ValueError: If no matching endpoint is found
    """
    candidates = route_impls.get(method.lower()) or {}

    for pattern, (func, route_path, webmethod) in candidates.items():
        found = re.match(pattern, path)
        if found is not None:
            return func, found.groupdict(), route_path, webmethod

    raise ValueError(f"No endpoint found for {path}")
|