Adapter -> Provider

This commit is contained in:
Ashwin Bharambe 2024-08-05 13:26:29 -07:00
parent db3e6dda07
commit 65a9e40174
15 changed files with 119 additions and 110 deletions

View file

@ -26,13 +26,13 @@ class ApiEndpoint(BaseModel):
@json_schema_type
class Adapter(BaseModel):
class ProviderSpec(BaseModel):
api: Api
adapter_id: str
provider_id: str
@json_schema_type
class SourceAdapter(Adapter):
class InlineProviderSpec(ProviderSpec):
pip_packages: List[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
@ -42,21 +42,21 @@ class SourceAdapter(Adapter):
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the local implementation
- `get_provider_impl(config, deps)`: returns the local implementation
""",
)
config_class: str = Field(
...,
description="Fully-qualified classname of the config for this adapter",
description="Fully-qualified classname of the config for this provider",
)
adapter_dependencies: List[Api] = Field(
api_dependencies: List[Api] = Field(
default_factory=list,
description="Higher-level API surfaces may depend on other adapters to provide their functionality",
description="Higher-level API surfaces may depend on other providers to provide their functionality",
)
@json_schema_type
class PassthroughApiAdapter(Adapter):
class RemoteProviderSpec(ProviderSpec):
base_url: str = Field(..., description="The base URL for the llama stack provider")
headers: Dict[str, str] = Field(
default_factory=dict,
@ -75,12 +75,12 @@ class Distribution(BaseModel):
name: str
description: str
adapters: Dict[Api, Adapter] = Field(
provider_specs: Dict[Api, ProviderSpec] = Field(
default_factory=dict,
description="The API surfaces provided by this distribution",
description="Provider specifications for each of the APIs provided by this distribution",
)
additional_pip_packages: List[str] = Field(
default_factory=list,
description="Additional pip packages beyond those required by the adapters",
description="Additional pip packages beyond those required by the providers",
)

View file

@ -11,16 +11,16 @@ from llama_toolchain.agentic_system.api.endpoints import AgenticSystem
from llama_toolchain.inference.api.endpoints import Inference
from llama_toolchain.safety.api.endpoints import Safety
from .datatypes import Api, ApiEndpoint, Distribution, SourceAdapter
from .datatypes import Api, ApiEndpoint, Distribution, InlineProviderSpec
def distribution_dependencies(distribution: Distribution) -> List[str]:
# only consider SourceAdapters when calculating dependencies
# only consider InlineProviderSpecs when calculating dependencies
return [
dep
for adapter in distribution.adapters.values()
if isinstance(adapter, SourceAdapter)
for dep in adapter.pip_packages
for provider_spec in distribution.provider_specs.values()
if isinstance(provider_spec, InlineProviderSpec)
for dep in provider_spec.pip_packages
] + distribution.additional_pip_packages

View file

@ -8,7 +8,7 @@ import asyncio
import importlib
from typing import Any, Dict
from .datatypes import Adapter, PassthroughApiAdapter, SourceAdapter
from .datatypes import InlineProviderSpec, ProviderSpec, RemoteProviderSpec
def instantiate_class_type(fully_qualified_name):
@ -18,17 +18,19 @@ def instantiate_class_type(fully_qualified_name):
# returns a class implementing the protocol corresponding to the Api
def instantiate_adapter(
adapter: SourceAdapter, adapter_config: Dict[str, Any], deps: Dict[str, Adapter]
def instantiate_provider(
provider_spec: InlineProviderSpec,
provider_config: Dict[str, Any],
deps: Dict[str, ProviderSpec],
):
module = importlib.import_module(adapter.module)
module = importlib.import_module(provider_spec.module)
config_type = instantiate_class_type(adapter.config_class)
config = config_type(**adapter_config)
return asyncio.run(module.get_adapter_impl(config, deps))
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config)
return asyncio.run(module.get_provider_impl(config, deps))
def instantiate_client(adapter: PassthroughApiAdapter, base_url: str):
module = importlib.import_module(adapter.module)
def instantiate_client(provider_spec: RemoteProviderSpec, base_url: str):
module = importlib.import_module(provider_spec.module)
return asyncio.run(module.get_client_impl(base_url))

View file

@ -7,12 +7,12 @@
from functools import lru_cache
from typing import List, Optional
from llama_toolchain.agentic_system.adapters import available_agentic_system_adapters
from llama_toolchain.agentic_system.providers import available_agentic_system_providers
from llama_toolchain.inference.adapters import available_inference_adapters
from llama_toolchain.safety.adapters import available_safety_adapters
from llama_toolchain.inference.providers import available_inference_providers
from llama_toolchain.safety.providers import available_safety_providers
from .datatypes import Api, Distribution, PassthroughApiAdapter
from .datatypes import Api, Distribution, RemoteProviderSpec
# This is currently duplicated from `requirements.txt` with a few minor changes
# dev-dependencies like "ufmt" etc. are nuked. A few specialized dependencies
@ -49,10 +49,10 @@ def client_module(api: Api) -> str:
return f"llama_toolchain.{api.value}.client"
def passthrough(api: Api, port: int) -> PassthroughApiAdapter:
return PassthroughApiAdapter(
def remote(api: Api, port: int) -> RemoteProviderSpec:
return RemoteProviderSpec(
api=api,
adapter_id=f"{api.value}-passthrough",
provider_id=f"{api.value}-remote",
base_url=f"http://localhost:{port}",
module=client_module(api),
)
@ -60,25 +60,28 @@ def passthrough(api: Api, port: int) -> PassthroughApiAdapter:
@lru_cache()
def available_distributions() -> List[Distribution]:
inference_adapters_by_id = {a.adapter_id: a for a in available_inference_adapters()}
safety_adapters_by_id = {a.adapter_id: a for a in available_safety_adapters()}
agentic_system_adapters_by_id = {
a.adapter_id: a for a in available_agentic_system_adapters()
inference_providers_by_id = {
a.provider_id: a for a in available_inference_providers()
}
safety_providers_by_id = {a.provider_id: a for a in available_safety_providers()}
agentic_system_providers_by_id = {
a.provider_id: a for a in available_agentic_system_providers()
}
return [
Distribution(
name="local-source",
name="local-inline",
description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
additional_pip_packages=COMMON_DEPENDENCIES,
adapters={
Api.inference: inference_adapters_by_id["meta-reference"],
Api.safety: safety_adapters_by_id["meta-reference"],
Api.agentic_system: agentic_system_adapters_by_id["meta-reference"],
provider_specs={
Api.inference: inference_providers_by_id["meta-reference"],
Api.safety: safety_providers_by_id["meta-reference"],
Api.agentic_system: agentic_system_providers_by_id["meta-reference"],
},
),
# NOTE: this hardcodes the ports to which things point to
Distribution(
name="full-passthrough",
name="full-remote",
description="Point to remote services for all llama stack APIs",
additional_pip_packages=[
"python-dotenv",
@ -94,20 +97,20 @@ def available_distributions() -> List[Distribution]:
"pydantic_core==2.18.2",
"uvicorn",
],
adapters={
Api.inference: passthrough(Api.inference, 5001),
Api.safety: passthrough(Api.safety, 5001),
Api.agentic_system: passthrough(Api.agentic_system, 5001),
provider_specs={
Api.inference: remote(Api.inference, 5001),
Api.safety: remote(Api.safety, 5001),
Api.agentic_system: remote(Api.agentic_system, 5001),
},
),
Distribution(
name="local-ollama",
description="Like local-source, but use ollama for running LLM inference",
additional_pip_packages=COMMON_DEPENDENCIES,
adapters={
Api.inference: inference_adapters_by_id["meta-ollama"],
Api.safety: safety_adapters_by_id["meta-reference"],
Api.agentic_system: agentic_system_adapters_by_id["meta-reference"],
provider_specs={
Api.inference: inference_providers_by_id["meta-ollama"],
Api.safety: safety_providers_by_id["meta-reference"],
Api.agentic_system: agentic_system_providers_by_id["meta-reference"],
},
),
]

View file

@ -36,9 +36,9 @@ from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from .datatypes import Adapter, Api, Distribution, PassthroughApiAdapter
from .datatypes import Api, Distribution, ProviderSpec, RemoteProviderSpec
from .distribution import api_endpoints
from .dynamic import instantiate_adapter, instantiate_client
from .dynamic import instantiate_client, instantiate_provider
from .registry import resolve_distribution
@ -226,15 +226,15 @@ def create_dynamic_typed_route(func: Any):
return endpoint
def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
by_id = {x.api: x for x in adapters}
by_id = {x.api: x for x in providers}
def dfs(a: Adapter, visited: Set[Api], stack: List[Api]):
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
visited.add(a.api)
if not isinstance(a, PassthroughApiAdapter):
for api in a.adapter_dependencies:
if not isinstance(a, RemoteProviderSpec):
for api in a.api_dependencies:
if api not in visited:
dfs(by_id[api], visited, stack)
@ -243,7 +243,7 @@ def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
visited = set()
stack = []
for a in adapters:
for a in providers:
if a.api not in visited:
dfs(a, visited, stack)
@ -251,23 +251,25 @@ def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]:
adapter_configs = config["adapters"]
adapters = topological_sort(dist.adapters.values())
provider_configs = config["providers"]
provider_specs = topological_sort(dist.provider_specs.values())
impls = {}
for adapter in adapters:
api = adapter.api
if api.value not in adapter_configs:
for provider_spec in provider_specs:
api = provider_spec.api
if api.value not in provider_configs:
raise ValueError(
f"Could not find adapter config for {api}. Please add it to the config"
f"Could not find provider_spec config for {api}. Please add it to the config"
)
adapter_config = adapter_configs[api.value]
if isinstance(adapter, PassthroughApiAdapter):
impls[api] = instantiate_client(adapter, adapter.base_url.rstrip("/"))
provider_config = provider_configs[api.value]
if isinstance(provider_spec, RemoteProviderSpec):
impls[api] = instantiate_client(
provider_spec, provider_spec.base_url.rstrip("/")
)
else:
deps = {api: impls[api] for api in adapter.adapter_dependencies}
impl = instantiate_adapter(adapter, adapter_config, deps)
deps = {api: impls[api] for api in provider_spec.api_dependencies}
impl = instantiate_provider(provider_spec, provider_config, deps)
impls[api] = impl
return impls
@ -288,12 +290,12 @@ def main(
all_endpoints = api_endpoints()
impls = resolve_impls(dist, config)
for adapter in dist.adapters.values():
api = adapter.api
for provider_spec in dist.provider_specs.values():
api = provider_spec.api
endpoints = all_endpoints[api]
if isinstance(adapter, PassthroughApiAdapter):
if isinstance(provider_spec, RemoteProviderSpec):
for endpoint in endpoints:
url = adapter.base_url.rstrip("/") + endpoint.route
url = provider_spec.base_url.rstrip("/") + endpoint.route
getattr(app, endpoint.method)(endpoint.route)(
create_dynamic_passthrough(url)
)