ApiSurface -> Api

This commit is contained in:
Ashwin Bharambe 2024-08-05 12:44:56 -07:00
parent 7890921e5c
commit 125fdb1b2a
13 changed files with 76 additions and 87 deletions

View file

@@ -6,13 +6,13 @@
from typing import List from typing import List
from llama_toolchain.distribution.datatypes import Adapter, ApiSurface, SourceAdapter from llama_toolchain.distribution.datatypes import Adapter, Api, SourceAdapter
def available_agentic_system_adapters() -> List[Adapter]: def available_agentic_system_adapters() -> List[Adapter]:
return [ return [
SourceAdapter( SourceAdapter(
api_surface=ApiSurface.agentic_system, api=Api.agentic_system,
adapter_id="meta-reference", adapter_id="meta-reference",
pip_packages=[ pip_packages=[
"codeshield", "codeshield",
@@ -22,8 +22,8 @@ def available_agentic_system_adapters() -> List[Adapter]:
module="llama_toolchain.agentic_system.agentic_system", module="llama_toolchain.agentic_system.agentic_system",
config_class="llama_toolchain.agentic_system.config.AgenticSystemConfig", config_class="llama_toolchain.agentic_system.config.AgenticSystemConfig",
adapter_dependencies=[ adapter_dependencies=[
ApiSurface.inference, Api.inference,
ApiSurface.safety, Api.safety,
], ],
), ),
] ]

View file

@@ -7,7 +7,7 @@
from llama_toolchain.agentic_system.api import AgenticSystem from llama_toolchain.agentic_system.api import AgenticSystem
from llama_toolchain.distribution.datatypes import Adapter, ApiSurface from llama_toolchain.distribution.datatypes import Adapter, Api
from llama_toolchain.inference.api import Inference from llama_toolchain.inference.api import Inference
from llama_toolchain.safety.api import Safety from llama_toolchain.safety.api import Safety
@@ -44,16 +44,14 @@ logger = logging.getLogger()
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
async def get_adapter_impl( async def get_adapter_impl(config: AgenticSystemConfig, deps: Dict[Api, Adapter]):
config: AgenticSystemConfig, deps: Dict[ApiSurface, Adapter]
):
assert isinstance( assert isinstance(
config, AgenticSystemConfig config, AgenticSystemConfig
), f"Unexpected config type: {type(config)}" ), f"Unexpected config type: {type(config)}"
impl = MetaReferenceAgenticSystemImpl( impl = MetaReferenceAgenticSystemImpl(
deps[ApiSurface.inference], deps[Api.inference],
deps[ApiSurface.safety], deps[Api.safety],
) )
await impl.initialize() await impl.initialize()
return impl return impl

View file

@@ -85,24 +85,21 @@ def configure_llama_distribution(dist: "Distribution", conda_env: str):
existing_config = yaml.safe_load(fp) existing_config = yaml.safe_load(fp)
adapter_configs = {} adapter_configs = {}
for api_surface, adapter in dist.adapters.items(): for api, adapter in dist.adapters.items():
if isinstance(adapter, PassthroughApiAdapter): if isinstance(adapter, PassthroughApiAdapter):
adapter_configs[api_surface.value] = adapter.dict() adapter_configs[api.value] = adapter.dict()
else: else:
cprint( cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"])
f"Configuring API surface: {api_surface.value}", "white", attrs=["bold"]
)
config_type = instantiate_class_type(adapter.config_class) config_type = instantiate_class_type(adapter.config_class)
config = prompt_for_config( config = prompt_for_config(
config_type, config_type,
( (
config_type(**existing_config["adapters"][api_surface.value]) config_type(**existing_config["adapters"][api.value])
if existing_config if existing_config and api.value in existing_config["adapters"]
and api_surface.value in existing_config["adapters"]
else None else None
), ),
) )
adapter_configs[api_surface.value] = { adapter_configs[api.value] = {
"adapter_id": adapter.adapter_id, "adapter_id": adapter.adapter_id,
**config.dict(), **config.dict(),
} }

View file

@@ -29,7 +29,7 @@ class DistributionCreate(Subcommand):
help="Name of the distribution to create", help="Name of the distribution to create",
required=True, required=True,
) )
# for each ApiSurface the user wants to support, we should # for each Api the user wants to support, we should
# get the list of available adapters, ask which one the user # get the list of available adapters, ask which one the user
# wants to pick and then ask for their configuration. # wants to pick and then ask for their configuration.

View file

@@ -12,14 +12,14 @@ from strong_typing.schema import json_schema_type
@json_schema_type @json_schema_type
class ApiSurface(Enum): class Api(Enum):
inference = "inference" inference = "inference"
safety = "safety" safety = "safety"
agentic_system = "agentic_system" agentic_system = "agentic_system"
@json_schema_type @json_schema_type
class ApiSurfaceEndpoint(BaseModel): class ApiEndpoint(BaseModel):
route: str route: str
method: str method: str
name: str name: str
@@ -27,7 +27,7 @@ class ApiSurfaceEndpoint(BaseModel):
@json_schema_type @json_schema_type
class Adapter(BaseModel): class Adapter(BaseModel):
api_surface: ApiSurface api: Api
adapter_id: str adapter_id: str
@@ -49,7 +49,7 @@ Fully-qualified name of the module to import. The module is expected to have:
..., ...,
description="Fully-qualified classname of the config for this adapter", description="Fully-qualified classname of the config for this adapter",
) )
adapter_dependencies: List[ApiSurface] = Field( adapter_dependencies: List[Api] = Field(
default_factory=list, default_factory=list,
description="Higher-level API surfaces may depend on other adapters to provide their functionality", description="Higher-level API surfaces may depend on other adapters to provide their functionality",
) )
@@ -75,7 +75,7 @@ class Distribution(BaseModel):
name: str name: str
description: str description: str
adapters: Dict[ApiSurface, Adapter] = Field( adapters: Dict[Api, Adapter] = Field(
default_factory=dict, default_factory=dict,
description="The API surfaces provided by this distribution", description="The API surfaces provided by this distribution",
) )

View file

@@ -11,7 +11,7 @@ from llama_toolchain.agentic_system.api.endpoints import AgenticSystem
from llama_toolchain.inference.api.endpoints import Inference from llama_toolchain.inference.api.endpoints import Inference
from llama_toolchain.safety.api.endpoints import Safety from llama_toolchain.safety.api.endpoints import Safety
from .datatypes import ApiSurface, ApiSurfaceEndpoint, Distribution, SourceAdapter from .datatypes import Api, ApiEndpoint, Distribution, SourceAdapter
def distribution_dependencies(distribution: Distribution) -> List[str]: def distribution_dependencies(distribution: Distribution) -> List[str]:
@@ -24,16 +24,16 @@ def distribution_dependencies(distribution: Distribution) -> List[str]:
] + distribution.additional_pip_packages ] + distribution.additional_pip_packages
def api_surface_endpoints() -> Dict[ApiSurface, List[ApiSurfaceEndpoint]]: def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
surfaces = {} apis = {}
protocols = { protocols = {
ApiSurface.inference: Inference, Api.inference: Inference,
ApiSurface.safety: Safety, Api.safety: Safety,
ApiSurface.agentic_system: AgenticSystem, Api.agentic_system: AgenticSystem,
} }
for surface, protocol in protocols.items(): for api, protocol in protocols.items():
endpoints = [] endpoints = []
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction) protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
@@ -46,8 +46,8 @@ def api_surface_endpoints() -> Dict[ApiSurface, List[ApiSurfaceEndpoint]]:
# use `post` for all methods right now until we fix up the `webmethod` openapi # use `post` for all methods right now until we fix up the `webmethod` openapi
# annotation and write our own openapi generator # annotation and write our own openapi generator
endpoints.append(ApiSurfaceEndpoint(route=route, method="post", name=name)) endpoints.append(ApiEndpoint(route=route, method="post", name=name))
surfaces[surface] = endpoints apis[api] = endpoints
return surfaces return apis

View file

@@ -17,7 +17,7 @@ def instantiate_class_type(fully_qualified_name):
return getattr(module, class_name) return getattr(module, class_name)
# returns a class implementing the protocol corresponding to the ApiSurface # returns a class implementing the protocol corresponding to the Api
def instantiate_adapter( def instantiate_adapter(
adapter: SourceAdapter, adapter_config: Dict[str, Any], deps: Dict[str, Adapter] adapter: SourceAdapter, adapter_config: Dict[str, Any], deps: Dict[str, Adapter]
): ):

View file

@@ -12,7 +12,7 @@ from llama_toolchain.agentic_system.adapters import available_agentic_system_ada
from llama_toolchain.inference.adapters import available_inference_adapters from llama_toolchain.inference.adapters import available_inference_adapters
from llama_toolchain.safety.adapters import available_safety_adapters from llama_toolchain.safety.adapters import available_safety_adapters
from .datatypes import ApiSurface, Distribution, PassthroughApiAdapter from .datatypes import Api, Distribution, PassthroughApiAdapter
# This is currently duplicated from `requirements.txt` with a few minor changes # This is currently duplicated from `requirements.txt` with a few minor changes
# dev-dependencies like "ufmt" etc. are nuked. A few specialized dependencies # dev-dependencies like "ufmt" etc. are nuked. A few specialized dependencies
@@ -45,16 +45,16 @@ COMMON_DEPENDENCIES = [
] ]
def client_module(api_surface: ApiSurface) -> str: def client_module(api: Api) -> str:
return f"llama_toolchain.{api_surface.value}.client" return f"llama_toolchain.{api.value}.client"
def passthrough(api_surface: ApiSurface, port: int) -> PassthroughApiAdapter: def passthrough(api: Api, port: int) -> PassthroughApiAdapter:
return PassthroughApiAdapter( return PassthroughApiAdapter(
api_surface=api_surface, api=api,
adapter_id=f"{api_surface.value}-passthrough", adapter_id=f"{api.value}-passthrough",
base_url=f"http://localhost:{port}", base_url=f"http://localhost:{port}",
module=client_module(api_surface), module=client_module(api),
) )
@@ -72,11 +72,9 @@ def available_distributions() -> List[Distribution]:
description="Use code from `llama_toolchain` itself to serve all llama stack APIs", description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
additional_pip_packages=COMMON_DEPENDENCIES, additional_pip_packages=COMMON_DEPENDENCIES,
adapters={ adapters={
ApiSurface.inference: inference_adapters_by_id["meta-reference"], Api.inference: inference_adapters_by_id["meta-reference"],
ApiSurface.safety: safety_adapters_by_id["meta-reference"], Api.safety: safety_adapters_by_id["meta-reference"],
ApiSurface.agentic_system: agentic_system_adapters_by_id[ Api.agentic_system: agentic_system_adapters_by_id["meta-reference"],
"meta-reference"
],
}, },
), ),
Distribution( Distribution(
@@ -97,9 +95,9 @@ def available_distributions() -> List[Distribution]:
"uvicorn", "uvicorn",
], ],
adapters={ adapters={
ApiSurface.inference: passthrough(ApiSurface.inference, 5001), Api.inference: passthrough(Api.inference, 5001),
ApiSurface.safety: passthrough(ApiSurface.safety, 5001), Api.safety: passthrough(Api.safety, 5001),
ApiSurface.agentic_system: passthrough(ApiSurface.agentic_system, 5001), Api.agentic_system: passthrough(Api.agentic_system, 5001),
}, },
), ),
Distribution( Distribution(
@@ -107,11 +105,9 @@ def available_distributions() -> List[Distribution]:
description="Like local-source, but use ollama for running LLM inference", description="Like local-source, but use ollama for running LLM inference",
additional_pip_packages=COMMON_DEPENDENCIES, additional_pip_packages=COMMON_DEPENDENCIES,
adapters={ adapters={
ApiSurface.inference: inference_adapters_by_id["meta-ollama"], Api.inference: inference_adapters_by_id["meta-ollama"],
ApiSurface.safety: safety_adapters_by_id["meta-reference"], Api.safety: safety_adapters_by_id["meta-reference"],
ApiSurface.agentic_system: agentic_system_adapters_by_id[ Api.agentic_system: agentic_system_adapters_by_id["meta-reference"],
"meta-reference"
],
}, },
), ),
] ]

View file

@@ -36,8 +36,8 @@ from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from termcolor import cprint from termcolor import cprint
from .datatypes import Adapter, ApiSurface, PassthroughApiAdapter from .datatypes import Adapter, Api, PassthroughApiAdapter
from .distribution import api_surface_endpoints from .distribution import api_endpoints
from .dynamic import instantiate_adapter, instantiate_client from .dynamic import instantiate_adapter, instantiate_client
from .registry import resolve_distribution from .registry import resolve_distribution
@@ -173,7 +173,7 @@ def create_dynamic_typed_route(func: Any):
request_model = next(iter(hints.values())) request_model = next(iter(hints.values()))
response_model = hints["return"] response_model = hints["return"]
# NOTE: I think it is better to just add a method within each ApiSurface # NOTE: I think it is better to just add a method within each Api
# "Protocol" / adapter-impl to tell what sort of a response this request # "Protocol" / adapter-impl to tell what sort of a response this request
# is going to produce. /chat_completion can produce a streaming or # is going to produce. /chat_completion can produce a streaming or
# non-streaming response depending on if request.stream is True / False. # non-streaming response depending on if request.stream is True / False.
@@ -228,23 +228,23 @@ def create_dynamic_typed_route(func: Any):
def topological_sort(adapters: List[Adapter]) -> List[Adapter]: def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
by_id = {x.api_surface: x for x in adapters} by_id = {x.api: x for x in adapters}
def dfs(a: Adapter, visited: Set[ApiSurface], stack: List[ApiSurface]): def dfs(a: Adapter, visited: Set[Api], stack: List[Api]):
visited.add(a.api_surface) visited.add(a.api)
if not isinstance(a, PassthroughApiAdapter): if not isinstance(a, PassthroughApiAdapter):
for surface in a.adapter_dependencies: for api in a.adapter_dependencies:
if surface not in visited: if api not in visited:
dfs(by_id[surface], visited, stack) dfs(by_id[api], visited, stack)
stack.append(a.api_surface) stack.append(a.api)
visited = set() visited = set()
stack = [] stack = []
for a in adapters: for a in adapters:
if a.api_surface not in visited: if a.api not in visited:
dfs(a, visited, stack) dfs(a, visited, stack)
return [by_id[x] for x in stack] return [by_id[x] for x in stack]
@@ -262,7 +262,7 @@ def main(
app = FastAPI() app = FastAPI()
all_endpoints = api_surface_endpoints() all_endpoints = api_endpoints()
adapter_configs = config["adapters"] adapter_configs = config["adapters"]
adapters = topological_sort(dist.adapters.values()) adapters = topological_sort(dist.adapters.values())
@@ -271,25 +271,25 @@ def main(
# and then you create the routes. # and then you create the routes.
impls = {} impls = {}
for adapter in adapters: for adapter in adapters:
surface = adapter.api_surface api = adapter.api
if surface.value not in adapter_configs: if api.value not in adapter_configs:
raise ValueError( raise ValueError(
f"Could not find adapter config for {surface}. Please add it to the config" f"Could not find adapter config for {api}. Please add it to the config"
) )
adapter_config = adapter_configs[surface.value] adapter_config = adapter_configs[api.value]
endpoints = all_endpoints[surface] endpoints = all_endpoints[api]
if isinstance(adapter, PassthroughApiAdapter): if isinstance(adapter, PassthroughApiAdapter):
for endpoint in endpoints: for endpoint in endpoints:
url = adapter.base_url.rstrip("/") + endpoint.route url = adapter.base_url.rstrip("/") + endpoint.route
getattr(app, endpoint.method)(endpoint.route)( getattr(app, endpoint.method)(endpoint.route)(
create_dynamic_passthrough(url) create_dynamic_passthrough(url)
) )
impls[surface] = instantiate_client(adapter, adapter.base_url.rstrip("/")) impls[api] = instantiate_client(adapter, adapter.base_url.rstrip("/"))
else: else:
deps = {surface: impls[surface] for surface in adapter.adapter_dependencies} deps = {api: impls[api] for api in adapter.adapter_dependencies}
impl = instantiate_adapter(adapter, adapter_config, deps) impl = instantiate_adapter(adapter, adapter_config, deps)
impls[surface] = impl impls[api] = impl
for endpoint in endpoints: for endpoint in endpoints:
if not hasattr(impl, endpoint.name): if not hasattr(impl, endpoint.name):
# ideally this should be a typing violation already # ideally this should be a typing violation already

View file

@@ -6,13 +6,13 @@
from typing import List from typing import List
from llama_toolchain.distribution.datatypes import Adapter, ApiSurface, SourceAdapter from llama_toolchain.distribution.datatypes import Adapter, Api, SourceAdapter
def available_inference_adapters() -> List[Adapter]: def available_inference_adapters() -> List[Adapter]:
return [ return [
SourceAdapter( SourceAdapter(
api_surface=ApiSurface.inference, api=Api.inference,
adapter_id="meta-reference", adapter_id="meta-reference",
pip_packages=[ pip_packages=[
"torch", "torch",
@@ -22,7 +22,7 @@ def available_inference_adapters() -> List[Adapter]:
config_class="llama_toolchain.inference.inference.MetaReferenceImplConfig", config_class="llama_toolchain.inference.inference.MetaReferenceImplConfig",
), ),
SourceAdapter( SourceAdapter(
api_surface=ApiSurface.inference, api=Api.inference,
adapter_id="meta-ollama", adapter_id="meta-ollama",
pip_packages=[ pip_packages=[
"ollama", "ollama",

View file

@@ -11,7 +11,7 @@ from typing import AsyncIterator, Dict, Union
from llama_models.llama3_1.api.datatypes import StopReason from llama_models.llama3_1.api.datatypes import StopReason
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model
from llama_toolchain.distribution.datatypes import Adapter, ApiSurface from llama_toolchain.distribution.datatypes import Adapter, Api
from .api.config import MetaReferenceImplConfig from .api.config import MetaReferenceImplConfig
from .api.datatypes import ( from .api.datatypes import (
@@ -29,9 +29,7 @@ from .api.endpoints import (
from .model_parallel import LlamaModelParallelGenerator from .model_parallel import LlamaModelParallelGenerator
async def get_adapter_impl( async def get_adapter_impl(config: MetaReferenceImplConfig, _deps: Dict[Api, Adapter]):
config: MetaReferenceImplConfig, _deps: Dict[ApiSurface, Adapter]
):
assert isinstance( assert isinstance(
config, MetaReferenceImplConfig config, MetaReferenceImplConfig
), f"Unexpected config type: {type(config)}" ), f"Unexpected config type: {type(config)}"

View file

@@ -6,13 +6,13 @@
from typing import List from typing import List
from llama_toolchain.distribution.datatypes import Adapter, ApiSurface, SourceAdapter from llama_toolchain.distribution.datatypes import Adapter, Api, SourceAdapter
def available_safety_adapters() -> List[Adapter]: def available_safety_adapters() -> List[Adapter]:
return [ return [
SourceAdapter( SourceAdapter(
api_surface=ApiSurface.safety, api=Api.safety,
adapter_id="meta-reference", adapter_id="meta-reference",
pip_packages=[ pip_packages=[
"codeshield", "codeshield",

View file

@@ -8,7 +8,7 @@ import asyncio
from typing import Dict from typing import Dict
from llama_toolchain.distribution.datatypes import Adapter, ApiSurface from llama_toolchain.distribution.datatypes import Adapter, Api
from .config import SafetyConfig from .config import SafetyConfig
from .api.endpoints import * # noqa from .api.endpoints import * # noqa
@@ -23,7 +23,7 @@ from .shields import (
) )
async def get_adapter_impl(config: SafetyConfig, _deps: Dict[ApiSurface, Adapter]): async def get_adapter_impl(config: SafetyConfig, _deps: Dict[Api, Adapter]):
assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}" assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
impl = MetaReferenceSafetyImpl(config) impl = MetaReferenceSafetyImpl(config)