mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00

ApiSurface -> Api

This commit is contained in:
parent 7890921e5c
commit 125fdb1b2a

13 changed files with 76 additions and 87 deletions
@@ -6,13 +6,13 @@
 from typing import List

-from llama_toolchain.distribution.datatypes import Adapter, ApiSurface, SourceAdapter
+from llama_toolchain.distribution.datatypes import Adapter, Api, SourceAdapter


 def available_agentic_system_adapters() -> List[Adapter]:
     return [
         SourceAdapter(
-            api_surface=ApiSurface.agentic_system,
+            api=Api.agentic_system,
             adapter_id="meta-reference",
             pip_packages=[
                 "codeshield",
@@ -22,8 +22,8 @@ def available_agentic_system_adapters() -> List[Adapter]:
             module="llama_toolchain.agentic_system.agentic_system",
             config_class="llama_toolchain.agentic_system.config.AgenticSystemConfig",
             adapter_dependencies=[
-                ApiSurface.inference,
-                ApiSurface.safety,
+                Api.inference,
+                Api.safety,
             ],
         ),
     ]
@@ -7,7 +7,7 @@
 from llama_toolchain.agentic_system.api import AgenticSystem

-from llama_toolchain.distribution.datatypes import Adapter, ApiSurface
+from llama_toolchain.distribution.datatypes import Adapter, Api
 from llama_toolchain.inference.api import Inference
 from llama_toolchain.safety.api import Safety
@@ -44,16 +44,14 @@ logger = logging.getLogger()
 logger.setLevel(logging.INFO)


-async def get_adapter_impl(
-    config: AgenticSystemConfig, deps: Dict[ApiSurface, Adapter]
-):
+async def get_adapter_impl(config: AgenticSystemConfig, deps: Dict[Api, Adapter]):
     assert isinstance(
         config, AgenticSystemConfig
     ), f"Unexpected config type: {type(config)}"

     impl = MetaReferenceAgenticSystemImpl(
-        deps[ApiSurface.inference],
-        deps[ApiSurface.safety],
+        deps[Api.inference],
+        deps[Api.safety],
     )
     await impl.initialize()
     return impl
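After the rename, dependency lookups inside get_adapter_impl are keyed by Api members instead of ApiSurface members. A minimal sketch of the caller-side deps mapping, using a stand-in enum and dummy dependency objects; the real keys and implementations are supplied by the distribution server shown later in this diff.

import asyncio
from enum import Enum
from typing import Any, Dict


class Api(Enum):
    # stand-in mirroring the renamed enum in llama_toolchain.distribution.datatypes
    inference = "inference"
    safety = "safety"
    agentic_system = "agentic_system"


async def get_adapter_impl(config: Any, deps: Dict[Api, Any]) -> Any:
    # simplified stand-in: the real function builds MetaReferenceAgenticSystemImpl
    inference = deps[Api.inference]
    safety = deps[Api.safety]
    return {"inference": inference, "safety": safety, "config": config}


# hypothetical dependency objects; in the server these are adapter implementations
deps = {Api.inference: object(), Api.safety: object()}
impl = asyncio.run(get_adapter_impl(config=None, deps=deps))
print(sorted(impl.keys()))  # ['config', 'inference', 'safety']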
@@ -85,24 +85,21 @@ def configure_llama_distribution(dist: "Distribution", conda_env: str):
         existing_config = yaml.safe_load(fp)

     adapter_configs = {}
-    for api_surface, adapter in dist.adapters.items():
+    for api, adapter in dist.adapters.items():
         if isinstance(adapter, PassthroughApiAdapter):
-            adapter_configs[api_surface.value] = adapter.dict()
+            adapter_configs[api.value] = adapter.dict()
         else:
-            cprint(
-                f"Configuring API surface: {api_surface.value}", "white", attrs=["bold"]
-            )
+            cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"])
             config_type = instantiate_class_type(adapter.config_class)
             config = prompt_for_config(
                 config_type,
                 (
-                    config_type(**existing_config["adapters"][api_surface.value])
-                    if existing_config
-                    and api_surface.value in existing_config["adapters"]
+                    config_type(**existing_config["adapters"][api.value])
+                    if existing_config and api.value in existing_config["adapters"]
                     else None
                 ),
             )
-            adapter_configs[api_surface.value] = {
+            adapter_configs[api.value] = {
                 "adapter_id": adapter.adapter_id,
                 **config.dict(),
             }
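The configuration loop keys the written config by api.value (the enum's string value) rather than api_surface.value, so existing config files keyed by "inference", "safety", and "agentic_system" keep working. A rough sketch of the adapter_configs shape that results, assuming one source adapter and one passthrough adapter; the field names inside each entry are illustrative, not the real defaults.

# illustrative only: the dict built above, which ends up under the "adapters"
# key of the distribution's YAML config
adapter_configs = {
    "inference": {
        "adapter_id": "meta-reference",
        # remaining fields come from the adapter's config_class via prompt_for_config
        "model": "<model-name>",
    },
    "safety": {
        # passthrough adapters are serialized whole via adapter.dict()
        "adapter_id": "safety-passthrough",
        "base_url": "http://localhost:5001",
    },
}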
@@ -29,7 +29,7 @@ class DistributionCreate(Subcommand):
             help="Name of the distribution to create",
             required=True,
         )
-        # for each ApiSurface the user wants to support, we should
+        # for each Api the user wants to support, we should
         # get the list of available adapters, ask which one the user
        # wants to pick and then ask for their configuration.
@@ -12,14 +12,14 @@ from strong_typing.schema import json_schema_type


 @json_schema_type
-class ApiSurface(Enum):
+class Api(Enum):
     inference = "inference"
     safety = "safety"
     agentic_system = "agentic_system"


 @json_schema_type
-class ApiSurfaceEndpoint(BaseModel):
+class ApiEndpoint(BaseModel):
     route: str
     method: str
     name: str
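Since Api keeps the same string values as the old ApiSurface enum, only the Python identifier changes; anything keyed by those strings continues to resolve. A small sketch of how the enum round-trips between members and string values, using a stand-in declaration so it runs without pydantic.

from enum import Enum


class Api(Enum):
    inference = "inference"
    safety = "safety"
    agentic_system = "agentic_system"


# members are hashable, so they can key dicts such as Distribution.adapters
impls = {Api.inference: "inference-impl", Api.safety: "safety-impl"}

# string value <-> member round trip, as used when reading/writing YAML configs
assert Api("inference") is Api.inference
assert Api.safety.value == "safety"
print(impls[Api.inference])  # -> "inference-impl"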
@@ -27,7 +27,7 @@ class ApiSurfaceEndpoint(BaseModel):

 @json_schema_type
 class Adapter(BaseModel):
-    api_surface: ApiSurface
+    api: Api
     adapter_id: str
@@ -49,7 +49,7 @@ Fully-qualified name of the module to import. The module is expected to have:
         ...,
         description="Fully-qualified classname of the config for this adapter",
     )
-    adapter_dependencies: List[ApiSurface] = Field(
+    adapter_dependencies: List[Api] = Field(
         default_factory=list,
         description="Higher-level API surfaces may depend on other adapters to provide their functionality",
     )
@@ -75,7 +75,7 @@ class Distribution(BaseModel):
     name: str
     description: str

-    adapters: Dict[ApiSurface, Adapter] = Field(
+    adapters: Dict[Api, Adapter] = Field(
         default_factory=dict,
         description="The API surfaces provided by this distribution",
     )
@@ -11,7 +11,7 @@ from llama_toolchain.agentic_system.api.endpoints import AgenticSystem
 from llama_toolchain.inference.api.endpoints import Inference
 from llama_toolchain.safety.api.endpoints import Safety

-from .datatypes import ApiSurface, ApiSurfaceEndpoint, Distribution, SourceAdapter
+from .datatypes import Api, ApiEndpoint, Distribution, SourceAdapter


 def distribution_dependencies(distribution: Distribution) -> List[str]:
@@ -24,16 +24,16 @@ def distribution_dependencies(distribution: Distribution) -> List[str]:
     ] + distribution.additional_pip_packages


-def api_surface_endpoints() -> Dict[ApiSurface, List[ApiSurfaceEndpoint]]:
-    surfaces = {}
+def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
+    apis = {}

     protocols = {
-        ApiSurface.inference: Inference,
-        ApiSurface.safety: Safety,
-        ApiSurface.agentic_system: AgenticSystem,
+        Api.inference: Inference,
+        Api.safety: Safety,
+        Api.agentic_system: AgenticSystem,
     }

-    for surface, protocol in protocols.items():
+    for api, protocol in protocols.items():
         endpoints = []
         protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)

@@ -46,8 +46,8 @@ def api_surface_endpoints() -> Dict[ApiSurface, List[ApiSurfaceEndpoint]]:

             # use `post` for all methods right now until we fix up the `webmethod` openapi
             # annotation and write our own openapi generator
-            endpoints.append(ApiSurfaceEndpoint(route=route, method="post", name=name))
+            endpoints.append(ApiEndpoint(route=route, method="post", name=name))

-        surfaces[surface] = endpoints
+        apis[api] = endpoints

-    return surfaces
+    return apis
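api_endpoints() walks each API protocol with inspect.getmembers and records one ApiEndpoint per method; the real route comes from the @webmethod annotation, which this hunk does not show. A self-contained approximation with a dummy protocol class; deriving the route from the method name here is an assumption made purely for illustration.

import inspect
from typing import List, NamedTuple


class ApiEndpoint(NamedTuple):
    route: str
    method: str
    name: str


class Inference:
    # dummy protocol: the real one is llama_toolchain.inference.api.endpoints.Inference
    def chat_completion(self, request): ...
    def completion(self, request): ...


def api_endpoints_for(protocol: type) -> List[ApiEndpoint]:
    endpoints = []
    for name, _fn in inspect.getmembers(protocol, predicate=inspect.isfunction):
        # assumption: route derived from the method name; the real code reads it
        # from webmethod metadata and currently forces method="post"
        endpoints.append(ApiEndpoint(route=f"/{name}", method="post", name=name))
    return endpoints


print(api_endpoints_for(Inference))
# [ApiEndpoint(route='/chat_completion', ...), ApiEndpoint(route='/completion', ...)]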
@@ -17,7 +17,7 @@ def instantiate_class_type(fully_qualified_name):
     return getattr(module, class_name)


-# returns a class implementing the protocol corresponding to the ApiSurface
+# returns a class implementing the protocol corresponding to the Api
 def instantiate_adapter(
     adapter: SourceAdapter, adapter_config: Dict[str, Any], deps: Dict[str, Adapter]
 ):
@@ -12,7 +12,7 @@ from llama_toolchain.agentic_system.adapters import available_agentic_system_ada
 from llama_toolchain.inference.adapters import available_inference_adapters
 from llama_toolchain.safety.adapters import available_safety_adapters

-from .datatypes import ApiSurface, Distribution, PassthroughApiAdapter
+from .datatypes import Api, Distribution, PassthroughApiAdapter

 # This is currently duplicated from `requirements.txt` with a few minor changes
 # dev-dependencies like "ufmt" etc. are nuked. A few specialized dependencies
@@ -45,16 +45,16 @@ COMMON_DEPENDENCIES = [
 ]


-def client_module(api_surface: ApiSurface) -> str:
-    return f"llama_toolchain.{api_surface.value}.client"
+def client_module(api: Api) -> str:
+    return f"llama_toolchain.{api.value}.client"


-def passthrough(api_surface: ApiSurface, port: int) -> PassthroughApiAdapter:
+def passthrough(api: Api, port: int) -> PassthroughApiAdapter:
     return PassthroughApiAdapter(
-        api_surface=api_surface,
-        adapter_id=f"{api_surface.value}-passthrough",
+        api=api,
+        adapter_id=f"{api.value}-passthrough",
         base_url=f"http://localhost:{port}",
-        module=client_module(api_surface),
+        module=client_module(api),
     )
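The passthrough helper now derives the adapter id and client module from the Api member's value. A quick sketch of the strings it produces, with PassthroughApiAdapter replaced by a plain dict so the snippet runs stand-alone.

from enum import Enum


class Api(Enum):
    inference = "inference"
    safety = "safety"
    agentic_system = "agentic_system"


def client_module(api: Api) -> str:
    return f"llama_toolchain.{api.value}.client"


def passthrough(api: Api, port: int) -> dict:
    # dict stand-in for PassthroughApiAdapter, showing only the derived fields
    return {
        "api": api,
        "adapter_id": f"{api.value}-passthrough",
        "base_url": f"http://localhost:{port}",
        "module": client_module(api),
    }


print(passthrough(Api.safety, 5001)["adapter_id"])  # -> "safety-passthrough"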
@ -72,11 +72,9 @@ def available_distributions() -> List[Distribution]:
|
|||
description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
|
||||
additional_pip_packages=COMMON_DEPENDENCIES,
|
||||
adapters={
|
||||
ApiSurface.inference: inference_adapters_by_id["meta-reference"],
|
||||
ApiSurface.safety: safety_adapters_by_id["meta-reference"],
|
||||
ApiSurface.agentic_system: agentic_system_adapters_by_id[
|
||||
"meta-reference"
|
||||
],
|
||||
Api.inference: inference_adapters_by_id["meta-reference"],
|
||||
Api.safety: safety_adapters_by_id["meta-reference"],
|
||||
Api.agentic_system: agentic_system_adapters_by_id["meta-reference"],
|
||||
},
|
||||
),
|
||||
Distribution(
|
||||
|
@ -97,9 +95,9 @@ def available_distributions() -> List[Distribution]:
|
|||
"uvicorn",
|
||||
],
|
||||
adapters={
|
||||
ApiSurface.inference: passthrough(ApiSurface.inference, 5001),
|
||||
ApiSurface.safety: passthrough(ApiSurface.safety, 5001),
|
||||
ApiSurface.agentic_system: passthrough(ApiSurface.agentic_system, 5001),
|
||||
Api.inference: passthrough(Api.inference, 5001),
|
||||
Api.safety: passthrough(Api.safety, 5001),
|
||||
Api.agentic_system: passthrough(Api.agentic_system, 5001),
|
||||
},
|
||||
),
|
||||
Distribution(
|
||||
|
@ -107,11 +105,9 @@ def available_distributions() -> List[Distribution]:
|
|||
description="Like local-source, but use ollama for running LLM inference",
|
||||
additional_pip_packages=COMMON_DEPENDENCIES,
|
||||
adapters={
|
||||
ApiSurface.inference: inference_adapters_by_id["meta-ollama"],
|
||||
ApiSurface.safety: safety_adapters_by_id["meta-reference"],
|
||||
ApiSurface.agentic_system: agentic_system_adapters_by_id[
|
||||
"meta-reference"
|
||||
],
|
||||
Api.inference: inference_adapters_by_id["meta-ollama"],
|
||||
Api.safety: safety_adapters_by_id["meta-reference"],
|
||||
Api.agentic_system: agentic_system_adapters_by_id["meta-reference"],
|
||||
},
|
||||
),
|
||||
]
|
||||
|
|
@@ -36,8 +36,8 @@ from fastapi.routing import APIRoute
 from pydantic import BaseModel, ValidationError
 from termcolor import cprint

-from .datatypes import Adapter, ApiSurface, PassthroughApiAdapter
-from .distribution import api_surface_endpoints
+from .datatypes import Adapter, Api, PassthroughApiAdapter
+from .distribution import api_endpoints
 from .dynamic import instantiate_adapter, instantiate_client

 from .registry import resolve_distribution
@@ -173,7 +173,7 @@ def create_dynamic_typed_route(func: Any):
     request_model = next(iter(hints.values()))
     response_model = hints["return"]

-    # NOTE: I think it is better to just add a method within each ApiSurface
+    # NOTE: I think it is better to just add a method within each Api
     # "Protocol" / adapter-impl to tell what sort of a response this request
     # is going to produce. /chat_completion can produce a streaming or
     # non-streaming response depending on if request.stream is True / False.
@@ -228,23 +228,23 @@ def create_dynamic_typed_route(func: Any):

 def topological_sort(adapters: List[Adapter]) -> List[Adapter]:

-    by_id = {x.api_surface: x for x in adapters}
+    by_id = {x.api: x for x in adapters}

-    def dfs(a: Adapter, visited: Set[ApiSurface], stack: List[ApiSurface]):
-        visited.add(a.api_surface)
+    def dfs(a: Adapter, visited: Set[Api], stack: List[Api]):
+        visited.add(a.api)

         if not isinstance(a, PassthroughApiAdapter):
-            for surface in a.adapter_dependencies:
-                if surface not in visited:
-                    dfs(by_id[surface], visited, stack)
+            for api in a.adapter_dependencies:
+                if api not in visited:
+                    dfs(by_id[api], visited, stack)

-        stack.append(a.api_surface)
+        stack.append(a.api)

     visited = set()
     stack = []

     for a in adapters:
-        if a.api_surface not in visited:
+        if a.api not in visited:
             dfs(a, visited, stack)

     return [by_id[x] for x in stack]
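topological_sort orders adapters so that every adapter is instantiated after the APIs listed in its adapter_dependencies; with the rename, the DFS bookkeeping is keyed by Api. A self-contained check of that ordering using lightweight dataclass stand-ins for the pydantic models (the passthrough special case from the real function is omitted here).

from dataclasses import dataclass, field
from enum import Enum
from typing import List


class Api(Enum):
    inference = "inference"
    safety = "safety"
    agentic_system = "agentic_system"


@dataclass
class Adapter:
    # stand-in for SourceAdapter: only the fields the sort needs
    api: Api
    adapter_dependencies: List[Api] = field(default_factory=list)


def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
    by_id = {a.api: a for a in adapters}
    visited, stack = set(), []

    def dfs(a: Adapter) -> None:
        visited.add(a.api)
        for dep in a.adapter_dependencies:
            if dep not in visited:
                dfs(by_id[dep])
        stack.append(a.api)

    for a in adapters:
        if a.api not in visited:
            dfs(a)
    return [by_id[x] for x in stack]


adapters = [
    Adapter(Api.agentic_system, [Api.inference, Api.safety]),
    Adapter(Api.inference),
    Adapter(Api.safety),
]
order = [a.api.value for a in topological_sort(adapters)]
print(order)  # dependencies first: ['inference', 'safety', 'agentic_system']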
@@ -262,7 +262,7 @@ def main(

     app = FastAPI()

-    all_endpoints = api_surface_endpoints()
+    all_endpoints = api_endpoints()

     adapter_configs = config["adapters"]
     adapters = topological_sort(dist.adapters.values())
@@ -271,25 +271,25 @@ def main(
     # and then you create the routes.
     impls = {}
     for adapter in adapters:
-        surface = adapter.api_surface
-        if surface.value not in adapter_configs:
+        api = adapter.api
+        if api.value not in adapter_configs:
             raise ValueError(
-                f"Could not find adapter config for {surface}. Please add it to the config"
+                f"Could not find adapter config for {api}. Please add it to the config"
             )

-        adapter_config = adapter_configs[surface.value]
-        endpoints = all_endpoints[surface]
+        adapter_config = adapter_configs[api.value]
+        endpoints = all_endpoints[api]
         if isinstance(adapter, PassthroughApiAdapter):
             for endpoint in endpoints:
                 url = adapter.base_url.rstrip("/") + endpoint.route
                 getattr(app, endpoint.method)(endpoint.route)(
                     create_dynamic_passthrough(url)
                 )
-            impls[surface] = instantiate_client(adapter, adapter.base_url.rstrip("/"))
+            impls[api] = instantiate_client(adapter, adapter.base_url.rstrip("/"))
         else:
-            deps = {surface: impls[surface] for surface in adapter.adapter_dependencies}
+            deps = {api: impls[api] for api in adapter.adapter_dependencies}
             impl = instantiate_adapter(adapter, adapter_config, deps)
-            impls[surface] = impl
+            impls[api] = impl
             for endpoint in endpoints:
                 if not hasattr(impl, endpoint.name):
                     # ideally this should be a typing violation already
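In main(), every implementation is now registered under its Api member, so the deps passed to instantiate_adapter are assembled by the same key each adapter declares in adapter_dependencies. A compact sketch of that dictionary plumbing with dummy implementations; no FastAPI or dynamic imports, and the values are placeholders.

from enum import Enum
from typing import Dict, List


class Api(Enum):
    inference = "inference"
    safety = "safety"
    agentic_system = "agentic_system"


# impls is filled in dependency order, thanks to topological_sort
impls: Dict[Api, object] = {
    Api.inference: "inference-impl",
    Api.safety: "safety-impl",
}

# hypothetical adapter about to be instantiated
adapter_dependencies: List[Api] = [Api.inference, Api.safety]

# the same comprehension main() uses after the rename
deps = {api: impls[api] for api in adapter_dependencies}
print(deps[Api.safety])  # -> "safety-impl"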
@@ -6,13 +6,13 @@
 from typing import List

-from llama_toolchain.distribution.datatypes import Adapter, ApiSurface, SourceAdapter
+from llama_toolchain.distribution.datatypes import Adapter, Api, SourceAdapter


 def available_inference_adapters() -> List[Adapter]:
     return [
         SourceAdapter(
-            api_surface=ApiSurface.inference,
+            api=Api.inference,
             adapter_id="meta-reference",
             pip_packages=[
                 "torch",
@@ -22,7 +22,7 @@ def available_inference_adapters() -> List[Adapter]:
             config_class="llama_toolchain.inference.inference.MetaReferenceImplConfig",
         ),
         SourceAdapter(
-            api_surface=ApiSurface.inference,
+            api=Api.inference,
             adapter_id="meta-ollama",
             pip_packages=[
                 "ollama",
@@ -11,7 +11,7 @@ from typing import AsyncIterator, Dict, Union
 from llama_models.llama3_1.api.datatypes import StopReason
 from llama_models.sku_list import resolve_model

-from llama_toolchain.distribution.datatypes import Adapter, ApiSurface
+from llama_toolchain.distribution.datatypes import Adapter, Api

 from .api.config import MetaReferenceImplConfig
 from .api.datatypes import (
@@ -29,9 +29,7 @@ from .api.endpoints import (
 from .model_parallel import LlamaModelParallelGenerator


-async def get_adapter_impl(
-    config: MetaReferenceImplConfig, _deps: Dict[ApiSurface, Adapter]
-):
+async def get_adapter_impl(config: MetaReferenceImplConfig, _deps: Dict[Api, Adapter]):
     assert isinstance(
         config, MetaReferenceImplConfig
     ), f"Unexpected config type: {type(config)}"
@@ -6,13 +6,13 @@
 from typing import List

-from llama_toolchain.distribution.datatypes import Adapter, ApiSurface, SourceAdapter
+from llama_toolchain.distribution.datatypes import Adapter, Api, SourceAdapter


 def available_safety_adapters() -> List[Adapter]:
     return [
         SourceAdapter(
-            api_surface=ApiSurface.safety,
+            api=Api.safety,
             adapter_id="meta-reference",
             pip_packages=[
                 "codeshield",
@@ -8,7 +8,7 @@ import asyncio

 from typing import Dict

-from llama_toolchain.distribution.datatypes import Adapter, ApiSurface
+from llama_toolchain.distribution.datatypes import Adapter, Api

 from .config import SafetyConfig
 from .api.endpoints import *  # noqa
@@ -23,7 +23,7 @@ from .shields import (
 )


-async def get_adapter_impl(config: SafetyConfig, _deps: Dict[ApiSurface, Adapter]):
+async def get_adapter_impl(config: SafetyConfig, _deps: Dict[Api, Adapter]):
     assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"

     impl = MetaReferenceSafetyImpl(config)