diff --git a/llama_toolchain/agentic_system/adapters.py b/llama_toolchain/agentic_system/adapters.py index df8e8c9d6..82d1a0ccb 100644 --- a/llama_toolchain/agentic_system/adapters.py +++ b/llama_toolchain/agentic_system/adapters.py @@ -6,13 +6,13 @@ from typing import List -from llama_toolchain.distribution.datatypes import Adapter, ApiSurface, SourceAdapter +from llama_toolchain.distribution.datatypes import Adapter, Api, SourceAdapter def available_agentic_system_adapters() -> List[Adapter]: return [ SourceAdapter( - api_surface=ApiSurface.agentic_system, + api=Api.agentic_system, adapter_id="meta-reference", pip_packages=[ "codeshield", @@ -22,8 +22,8 @@ def available_agentic_system_adapters() -> List[Adapter]: module="llama_toolchain.agentic_system.agentic_system", config_class="llama_toolchain.agentic_system.config.AgenticSystemConfig", adapter_dependencies=[ - ApiSurface.inference, - ApiSurface.safety, + Api.inference, + Api.safety, ], ), ] diff --git a/llama_toolchain/agentic_system/agentic_system.py b/llama_toolchain/agentic_system/agentic_system.py index 95fcd7e02..8bf74e44f 100644 --- a/llama_toolchain/agentic_system/agentic_system.py +++ b/llama_toolchain/agentic_system/agentic_system.py @@ -7,7 +7,7 @@ from llama_toolchain.agentic_system.api import AgenticSystem -from llama_toolchain.distribution.datatypes import Adapter, ApiSurface +from llama_toolchain.distribution.datatypes import Adapter, Api from llama_toolchain.inference.api import Inference from llama_toolchain.safety.api import Safety @@ -44,16 +44,14 @@ logger = logging.getLogger() logger.setLevel(logging.INFO) -async def get_adapter_impl( - config: AgenticSystemConfig, deps: Dict[ApiSurface, Adapter] -): +async def get_adapter_impl(config: AgenticSystemConfig, deps: Dict[Api, Adapter]): assert isinstance( config, AgenticSystemConfig ), f"Unexpected config type: {type(config)}" impl = MetaReferenceAgenticSystemImpl( - deps[ApiSurface.inference], - deps[ApiSurface.safety], + deps[Api.inference], + deps[Api.safety], ) await impl.initialize() return impl diff --git a/llama_toolchain/cli/distribution/configure.py b/llama_toolchain/cli/distribution/configure.py index bb87d0ddf..2d0341754 100644 --- a/llama_toolchain/cli/distribution/configure.py +++ b/llama_toolchain/cli/distribution/configure.py @@ -85,24 +85,21 @@ def configure_llama_distribution(dist: "Distribution", conda_env: str): existing_config = yaml.safe_load(fp) adapter_configs = {} - for api_surface, adapter in dist.adapters.items(): + for api, adapter in dist.adapters.items(): if isinstance(adapter, PassthroughApiAdapter): - adapter_configs[api_surface.value] = adapter.dict() + adapter_configs[api.value] = adapter.dict() else: - cprint( - f"Configuring API surface: {api_surface.value}", "white", attrs=["bold"] - ) + cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"]) config_type = instantiate_class_type(adapter.config_class) config = prompt_for_config( config_type, ( - config_type(**existing_config["adapters"][api_surface.value]) - if existing_config - and api_surface.value in existing_config["adapters"] + config_type(**existing_config["adapters"][api.value]) + if existing_config and api.value in existing_config["adapters"] else None ), ) - adapter_configs[api_surface.value] = { + adapter_configs[api.value] = { "adapter_id": adapter.adapter_id, **config.dict(), } diff --git a/llama_toolchain/cli/distribution/create.py b/llama_toolchain/cli/distribution/create.py index e1cff1244..e5a835c91 100644 --- a/llama_toolchain/cli/distribution/create.py +++ b/llama_toolchain/cli/distribution/create.py @@ -29,7 +29,7 @@ class DistributionCreate(Subcommand): help="Name of the distribution to create", required=True, ) - # for each ApiSurface the user wants to support, we should + # for each Api the user wants to support, we should # get the list of available adapters, ask which one the user # wants to pick and then ask for their configuration. diff --git a/llama_toolchain/distribution/datatypes.py b/llama_toolchain/distribution/datatypes.py index 762b3d487..7dd197a80 100644 --- a/llama_toolchain/distribution/datatypes.py +++ b/llama_toolchain/distribution/datatypes.py @@ -12,14 +12,14 @@ from strong_typing.schema import json_schema_type @json_schema_type -class ApiSurface(Enum): +class Api(Enum): inference = "inference" safety = "safety" agentic_system = "agentic_system" @json_schema_type -class ApiSurfaceEndpoint(BaseModel): +class ApiEndpoint(BaseModel): route: str method: str name: str @@ -27,7 +27,7 @@ class ApiSurfaceEndpoint(BaseModel): @json_schema_type class Adapter(BaseModel): - api_surface: ApiSurface + api: Api adapter_id: str @@ -49,7 +49,7 @@ Fully-qualified name of the module to import. The module is expected to have: ..., description="Fully-qualified classname of the config for this adapter", ) - adapter_dependencies: List[ApiSurface] = Field( + adapter_dependencies: List[Api] = Field( default_factory=list, description="Higher-level API surfaces may depend on other adapters to provide their functionality", ) @@ -75,7 +75,7 @@ class Distribution(BaseModel): name: str description: str - adapters: Dict[ApiSurface, Adapter] = Field( + adapters: Dict[Api, Adapter] = Field( default_factory=dict, description="The API surfaces provided by this distribution", ) diff --git a/llama_toolchain/distribution/distribution.py b/llama_toolchain/distribution/distribution.py index 03bd5d3a5..27a7e4a5d 100644 --- a/llama_toolchain/distribution/distribution.py +++ b/llama_toolchain/distribution/distribution.py @@ -11,7 +11,7 @@ from llama_toolchain.agentic_system.api.endpoints import AgenticSystem from llama_toolchain.inference.api.endpoints import Inference from llama_toolchain.safety.api.endpoints import Safety -from .datatypes import ApiSurface, ApiSurfaceEndpoint, Distribution, SourceAdapter +from .datatypes import Api, ApiEndpoint, Distribution, SourceAdapter def distribution_dependencies(distribution: Distribution) -> List[str]: @@ -24,16 +24,16 @@ def distribution_dependencies(distribution: Distribution) -> List[str]: ] + distribution.additional_pip_packages -def api_surface_endpoints() -> Dict[ApiSurface, List[ApiSurfaceEndpoint]]: - surfaces = {} +def api_endpoints() -> Dict[Api, List[ApiEndpoint]]: + apis = {} protocols = { - ApiSurface.inference: Inference, - ApiSurface.safety: Safety, - ApiSurface.agentic_system: AgenticSystem, + Api.inference: Inference, + Api.safety: Safety, + Api.agentic_system: AgenticSystem, } - for surface, protocol in protocols.items(): + for api, protocol in protocols.items(): endpoints = [] protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction) @@ -46,8 +46,8 @@ def api_surface_endpoints() -> Dict[ApiSurface, List[ApiSurfaceEndpoint]]: # use `post` for all methods right now until we fix up the `webmethod` openapi # annotation and write our own openapi generator - endpoints.append(ApiSurfaceEndpoint(route=route, method="post", name=name)) + endpoints.append(ApiEndpoint(route=route, method="post", name=name)) - surfaces[surface] = endpoints + apis[api] = endpoints - return surfaces + return apis diff --git a/llama_toolchain/distribution/dynamic.py b/llama_toolchain/distribution/dynamic.py index 62e954d46..ae7075940 100644 --- a/llama_toolchain/distribution/dynamic.py +++ b/llama_toolchain/distribution/dynamic.py @@ -17,7 +17,7 @@ def instantiate_class_type(fully_qualified_name): return getattr(module, class_name) -# returns a class implementing the protocol corresponding to the ApiSurface +# returns a class implementing the protocol corresponding to the Api def instantiate_adapter( adapter: SourceAdapter, adapter_config: Dict[str, Any], deps: Dict[str, Adapter] ): diff --git a/llama_toolchain/distribution/registry.py b/llama_toolchain/distribution/registry.py index 897b8f9d0..17fa4bc93 100644 --- a/llama_toolchain/distribution/registry.py +++ b/llama_toolchain/distribution/registry.py @@ -12,7 +12,7 @@ from llama_toolchain.agentic_system.adapters import available_agentic_system_ada from llama_toolchain.inference.adapters import available_inference_adapters from llama_toolchain.safety.adapters import available_safety_adapters -from .datatypes import ApiSurface, Distribution, PassthroughApiAdapter +from .datatypes import Api, Distribution, PassthroughApiAdapter # This is currently duplicated from `requirements.txt` with a few minor changes # dev-dependencies like "ufmt" etc. are nuked. A few specialized dependencies @@ -45,16 +45,16 @@ COMMON_DEPENDENCIES = [ ] -def client_module(api_surface: ApiSurface) -> str: - return f"llama_toolchain.{api_surface.value}.client" +def client_module(api: Api) -> str: + return f"llama_toolchain.{api.value}.client" -def passthrough(api_surface: ApiSurface, port: int) -> PassthroughApiAdapter: +def passthrough(api: Api, port: int) -> PassthroughApiAdapter: return PassthroughApiAdapter( - api_surface=api_surface, - adapter_id=f"{api_surface.value}-passthrough", + api=api, + adapter_id=f"{api.value}-passthrough", base_url=f"http://localhost:{port}", - module=client_module(api_surface), + module=client_module(api), ) @@ -72,11 +72,9 @@ def available_distributions() -> List[Distribution]: description="Use code from `llama_toolchain` itself to serve all llama stack APIs", additional_pip_packages=COMMON_DEPENDENCIES, adapters={ - ApiSurface.inference: inference_adapters_by_id["meta-reference"], - ApiSurface.safety: safety_adapters_by_id["meta-reference"], - ApiSurface.agentic_system: agentic_system_adapters_by_id[ - "meta-reference" - ], + Api.inference: inference_adapters_by_id["meta-reference"], + Api.safety: safety_adapters_by_id["meta-reference"], + Api.agentic_system: agentic_system_adapters_by_id["meta-reference"], }, ), Distribution( @@ -97,9 +95,9 @@ def available_distributions() -> List[Distribution]: "uvicorn", ], adapters={ - ApiSurface.inference: passthrough(ApiSurface.inference, 5001), - ApiSurface.safety: passthrough(ApiSurface.safety, 5001), - ApiSurface.agentic_system: passthrough(ApiSurface.agentic_system, 5001), + Api.inference: passthrough(Api.inference, 5001), + Api.safety: passthrough(Api.safety, 5001), + Api.agentic_system: passthrough(Api.agentic_system, 5001), }, ), Distribution( @@ -107,11 +105,9 @@ def available_distributions() -> List[Distribution]: description="Like local-source, but use ollama for running LLM inference", additional_pip_packages=COMMON_DEPENDENCIES, adapters={ - ApiSurface.inference: inference_adapters_by_id["meta-ollama"], - ApiSurface.safety: safety_adapters_by_id["meta-reference"], - ApiSurface.agentic_system: agentic_system_adapters_by_id[ - "meta-reference" - ], + Api.inference: inference_adapters_by_id["meta-ollama"], + Api.safety: safety_adapters_by_id["meta-reference"], + Api.agentic_system: agentic_system_adapters_by_id["meta-reference"], }, ), ] diff --git a/llama_toolchain/distribution/server.py b/llama_toolchain/distribution/server.py index 5bcabf343..fd49b7d70 100644 --- a/llama_toolchain/distribution/server.py +++ b/llama_toolchain/distribution/server.py @@ -36,8 +36,8 @@ from fastapi.routing import APIRoute from pydantic import BaseModel, ValidationError from termcolor import cprint -from .datatypes import Adapter, ApiSurface, PassthroughApiAdapter -from .distribution import api_surface_endpoints +from .datatypes import Adapter, Api, PassthroughApiAdapter +from .distribution import api_endpoints from .dynamic import instantiate_adapter, instantiate_client from .registry import resolve_distribution @@ -173,7 +173,7 @@ def create_dynamic_typed_route(func: Any): request_model = next(iter(hints.values())) response_model = hints["return"] - # NOTE: I think it is better to just add a method within each ApiSurface + # NOTE: I think it is better to just add a method within each Api # "Protocol" / adapter-impl to tell what sort of a response this request # is going to produce. /chat_completion can produce a streaming or # non-streaming response depending on if request.stream is True / False. @@ -228,23 +228,23 @@ def create_dynamic_typed_route(func: Any): def topological_sort(adapters: List[Adapter]) -> List[Adapter]: - by_id = {x.api_surface: x for x in adapters} + by_id = {x.api: x for x in adapters} - def dfs(a: Adapter, visited: Set[ApiSurface], stack: List[ApiSurface]): - visited.add(a.api_surface) + def dfs(a: Adapter, visited: Set[Api], stack: List[Api]): + visited.add(a.api) if not isinstance(a, PassthroughApiAdapter): - for surface in a.adapter_dependencies: - if surface not in visited: - dfs(by_id[surface], visited, stack) + for api in a.adapter_dependencies: + if api not in visited: + dfs(by_id[api], visited, stack) - stack.append(a.api_surface) + stack.append(a.api) visited = set() stack = [] for a in adapters: - if a.api_surface not in visited: + if a.api not in visited: dfs(a, visited, stack) return [by_id[x] for x in stack] @@ -262,7 +262,7 @@ def main( app = FastAPI() - all_endpoints = api_surface_endpoints() + all_endpoints = api_endpoints() adapter_configs = config["adapters"] adapters = topological_sort(dist.adapters.values()) @@ -271,25 +271,25 @@ def main( # and then you create the routes. impls = {} for adapter in adapters: - surface = adapter.api_surface - if surface.value not in adapter_configs: + api = adapter.api + if api.value not in adapter_configs: raise ValueError( - f"Could not find adapter config for {surface}. Please add it to the config" + f"Could not find adapter config for {api}. Please add it to the config" ) - adapter_config = adapter_configs[surface.value] - endpoints = all_endpoints[surface] + adapter_config = adapter_configs[api.value] + endpoints = all_endpoints[api] if isinstance(adapter, PassthroughApiAdapter): for endpoint in endpoints: url = adapter.base_url.rstrip("/") + endpoint.route getattr(app, endpoint.method)(endpoint.route)( create_dynamic_passthrough(url) ) - impls[surface] = instantiate_client(adapter, adapter.base_url.rstrip("/")) + impls[api] = instantiate_client(adapter, adapter.base_url.rstrip("/")) else: - deps = {surface: impls[surface] for surface in adapter.adapter_dependencies} + deps = {api: impls[api] for api in adapter.adapter_dependencies} impl = instantiate_adapter(adapter, adapter_config, deps) - impls[surface] = impl + impls[api] = impl for endpoint in endpoints: if not hasattr(impl, endpoint.name): # ideally this should be a typing violation already diff --git a/llama_toolchain/inference/adapters.py b/llama_toolchain/inference/adapters.py index 4ab087221..320bad9a7 100644 --- a/llama_toolchain/inference/adapters.py +++ b/llama_toolchain/inference/adapters.py @@ -6,13 +6,13 @@ from typing import List -from llama_toolchain.distribution.datatypes import Adapter, ApiSurface, SourceAdapter +from llama_toolchain.distribution.datatypes import Adapter, Api, SourceAdapter def available_inference_adapters() -> List[Adapter]: return [ SourceAdapter( - api_surface=ApiSurface.inference, + api=Api.inference, adapter_id="meta-reference", pip_packages=[ "torch", @@ -22,7 +22,7 @@ def available_inference_adapters() -> List[Adapter]: config_class="llama_toolchain.inference.inference.MetaReferenceImplConfig", ), SourceAdapter( - api_surface=ApiSurface.inference, + api=Api.inference, adapter_id="meta-ollama", pip_packages=[ "ollama", diff --git a/llama_toolchain/inference/inference.py b/llama_toolchain/inference/inference.py index 7b54313c4..2dd15317e 100644 --- a/llama_toolchain/inference/inference.py +++ b/llama_toolchain/inference/inference.py @@ -11,7 +11,7 @@ from typing import AsyncIterator, Dict, Union from llama_models.llama3_1.api.datatypes import StopReason from llama_models.sku_list import resolve_model -from llama_toolchain.distribution.datatypes import Adapter, ApiSurface +from llama_toolchain.distribution.datatypes import Adapter, Api from .api.config import MetaReferenceImplConfig from .api.datatypes import ( @@ -29,9 +29,7 @@ from .api.endpoints import ( from .model_parallel import LlamaModelParallelGenerator -async def get_adapter_impl( - config: MetaReferenceImplConfig, _deps: Dict[ApiSurface, Adapter] -): +async def get_adapter_impl(config: MetaReferenceImplConfig, _deps: Dict[Api, Adapter]): assert isinstance( config, MetaReferenceImplConfig ), f"Unexpected config type: {type(config)}" diff --git a/llama_toolchain/safety/adapters.py b/llama_toolchain/safety/adapters.py index ab73ffe19..6411da757 100644 --- a/llama_toolchain/safety/adapters.py +++ b/llama_toolchain/safety/adapters.py @@ -6,13 +6,13 @@ from typing import List -from llama_toolchain.distribution.datatypes import Adapter, ApiSurface, SourceAdapter +from llama_toolchain.distribution.datatypes import Adapter, Api, SourceAdapter def available_safety_adapters() -> List[Adapter]: return [ SourceAdapter( - api_surface=ApiSurface.safety, + api=Api.safety, adapter_id="meta-reference", pip_packages=[ "codeshield", diff --git a/llama_toolchain/safety/safety.py b/llama_toolchain/safety/safety.py index 21b7e6f1f..5a01cc2c0 100644 --- a/llama_toolchain/safety/safety.py +++ b/llama_toolchain/safety/safety.py @@ -8,7 +8,7 @@ import asyncio from typing import Dict -from llama_toolchain.distribution.datatypes import Adapter, ApiSurface +from llama_toolchain.distribution.datatypes import Adapter, Api from .config import SafetyConfig from .api.endpoints import * # noqa @@ -23,7 +23,7 @@ from .shields import ( ) -async def get_adapter_impl(config: SafetyConfig, _deps: Dict[ApiSurface, Adapter]): +async def get_adapter_impl(config: SafetyConfig, _deps: Dict[Api, Adapter]): assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}" impl = MetaReferenceSafetyImpl(config)