Rename Distribution -> DistributionSpec, simplify RemoteProviders

This commit is contained in:
Ashwin Bharambe 2024-08-06 10:45:06 -07:00
parent 0a67f3d3e6
commit 7cc0445517
9 changed files with 181 additions and 147 deletions

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Dict, List
from typing import Dict, List, Optional
from pydantic import BaseModel, Field
from strong_typing.schema import json_schema_type
@ -29,6 +29,10 @@ class ApiEndpoint(BaseModel):
class ProviderSpec(BaseModel):
api: Api
provider_id: str
config_class: str = Field(
...,
description="Fully-qualified classname of the config for this provider",
)
@json_schema_type
@ -45,23 +49,21 @@ Fully-qualified name of the module to import. The module is expected to have:
- `get_provider_impl(config, deps)`: returns the local implementation
""",
)
config_class: str = Field(
...,
description="Fully-qualified classname of the config for this provider",
)
api_dependencies: List[Api] = Field(
default_factory=list,
description="Higher-level API surfaces may depend on other providers to provide their functionality",
)
class RemoteProviderConfig(BaseModel):
    """Connection settings for a remotely-hosted provider.

    This is the config written out by `configure` for remote providers and
    read back by the server when instantiating a client (it supplies the
    `base_url` the passthrough endpoints are built from).
    """

    base_url: str = Field(..., description="The base URL for the llama stack provider")
    # Was `Field(...)` (required), which contradicts the Optional annotation:
    # pydantic would reject configs that omit the key even for providers that
    # don't need one. Default to None so the field is genuinely optional.
    api_key: Optional[str] = Field(
        default=None, description="API key, if needed, for the provider"
    )
@json_schema_type
class RemoteProviderSpec(ProviderSpec):
base_url: str = Field(..., description="The base URL for the llama stack provider")
headers: Dict[str, str] = Field(
default_factory=dict,
description="Headers (e.g., authorization) to send with the request",
)
module: str = Field(
...,
description="""
@ -69,10 +71,12 @@ Fully-qualified name of the module to import. The module is expected to have:
- `get_client_impl(base_url)`: returns a client which can be used to call the remote implementation
""",
)
config_class: str = "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
class Distribution(BaseModel):
name: str
@json_schema_type
class DistributionSpec(BaseModel):
spec_id: str
description: str
provider_specs: Dict[Api, ProviderSpec] = Field(
@ -84,3 +88,12 @@ class Distribution(BaseModel):
default_factory=list,
description="Additional pip packages beyond those required by the providers",
)
@json_schema_type
class InstalledDistribution(BaseModel):
    """Reference to an installed / configured DistributionSpec."""

    # user-chosen name for this installation
    name: str
    # id of the DistributionSpec this installation was created from
    spec_id: str
    # NOTE(review): per the original trailing comment, this class is meant to
    # represent the configs written by `configure`; the config payload field
    # appears not to be defined yet — confirm against the `configure` code path.

View file

@ -8,13 +8,22 @@ import inspect
from typing import Dict, List
from llama_toolchain.agentic_system.api.endpoints import AgenticSystem
from llama_toolchain.agentic_system.providers import available_agentic_system_providers
from llama_toolchain.inference.api.endpoints import Inference
from llama_toolchain.inference.providers import available_inference_providers
from llama_toolchain.safety.api.endpoints import Safety
from llama_toolchain.safety.providers import available_safety_providers
from .datatypes import Api, ApiEndpoint, Distribution, InlineProviderSpec
from .datatypes import (
Api,
ApiEndpoint,
DistributionSpec,
InlineProviderSpec,
ProviderSpec,
)
def distribution_dependencies(distribution: Distribution) -> List[str]:
def distribution_dependencies(distribution: DistributionSpec) -> List[str]:
# only consider InlineProviderSpecs when calculating dependencies
return [
dep
@ -51,3 +60,19 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
apis[api] = endpoints
return apis
def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
    """Return every known provider, grouped by API and keyed by provider_id."""
    return {
        Api.inference: {
            spec.provider_id: spec for spec in available_inference_providers()
        },
        Api.safety: {
            spec.provider_id: spec for spec in available_safety_providers()
        },
        Api.agentic_system: {
            spec.provider_id: spec for spec in available_agentic_system_providers()
        },
    }

View file

@ -7,12 +7,8 @@
from functools import lru_cache
from typing import List, Optional
from llama_toolchain.agentic_system.providers import available_agentic_system_providers
from llama_toolchain.inference.providers import available_inference_providers
from llama_toolchain.safety.providers import available_safety_providers
from .datatypes import Api, Distribution, RemoteProviderSpec
from .datatypes import Api, DistributionSpec, RemoteProviderSpec
from .distribution import api_providers
# This is currently duplicated from `requirements.txt` with a few minor changes
# dev-dependencies like "ufmt" etc. are nuked. A few specialized dependencies
@ -49,39 +45,30 @@ def client_module(api: Api) -> str:
return f"llama_toolchain.{api.value}.client"
def remote(api: Api, port: int) -> RemoteProviderSpec:
def remote_spec(api: Api) -> RemoteProviderSpec:
    """Build the client-side (passthrough) provider spec for `api`.

    The connection details (base_url, api_key) are not part of the spec;
    they come from the RemoteProviderConfig written by `configure`.
    """
    return RemoteProviderSpec(
        api=api,
        provider_id=f"{api.value}-remote",
        module=client_module(api),
    )
@lru_cache()
def available_distributions() -> List[Distribution]:
inference_providers_by_id = {
a.provider_id: a for a in available_inference_providers()
}
safety_providers_by_id = {a.provider_id: a for a in available_safety_providers()}
agentic_system_providers_by_id = {
a.provider_id: a for a in available_agentic_system_providers()
}
def available_distribution_specs() -> List[DistributionSpec]:
providers = api_providers()
return [
Distribution(
name="local-inline",
DistributionSpec(
spec_id="inline",
description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
additional_pip_packages=COMMON_DEPENDENCIES,
provider_specs={
Api.inference: inference_providers_by_id["meta-reference"],
Api.safety: safety_providers_by_id["meta-reference"],
Api.agentic_system: agentic_system_providers_by_id["meta-reference"],
Api.inference: providers[Api.inference]["meta-reference"],
Api.safety: providers[Api.safety]["meta-reference"],
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
},
),
# NOTE: this hardcodes the ports to which things point to
Distribution(
name="full-remote",
DistributionSpec(
spec_id="remote",
description="Point to remote services for all llama stack APIs",
additional_pip_packages=[
"python-dotenv",
@ -97,28 +84,24 @@ def available_distributions() -> List[Distribution]:
"pydantic_core==2.18.2",
"uvicorn",
],
provider_specs={
Api.inference: remote(Api.inference, 5001),
Api.safety: remote(Api.safety, 5001),
Api.agentic_system: remote(Api.agentic_system, 5001),
},
provider_specs={x: remote_spec(x) for x in providers},
),
Distribution(
name="local-ollama",
DistributionSpec(
spec_id="ollama-inline",
description="Like local-source, but use ollama for running LLM inference",
additional_pip_packages=COMMON_DEPENDENCIES,
provider_specs={
Api.inference: inference_providers_by_id["meta-ollama"],
Api.safety: safety_providers_by_id["meta-reference"],
Api.agentic_system: agentic_system_providers_by_id["meta-reference"],
Api.inference: providers[Api.inference]["meta-ollama"],
Api.safety: providers[Api.safety]["meta-reference"],
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
},
),
]
@lru_cache()
def resolve_distribution(name: str) -> Optional[Distribution]:
for dist in available_distributions():
if dist.name == name:
return dist
def resolve_distribution_spec(spec_id: str) -> Optional[DistributionSpec]:
    """Return the known DistributionSpec with the given spec_id, or None."""
    candidates = (
        spec for spec in available_distribution_specs() if spec.spec_id == spec_id
    )
    return next(candidates, None)

View file

@ -36,11 +36,11 @@ from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from .datatypes import Api, Distribution, ProviderSpec, RemoteProviderSpec
from .datatypes import Api, DistributionSpec, ProviderSpec, RemoteProviderSpec
from .distribution import api_endpoints
from .dynamic import instantiate_client, instantiate_provider
from .registry import resolve_distribution
from .registry import resolve_distribution_spec
load_dotenv()
@ -250,7 +250,7 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
return [by_id[x] for x in stack]
def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]:
def resolve_impls(dist: DistributionSpec, config: Dict[str, Any]) -> Dict[Api, Any]:
provider_configs = config["providers"]
provider_specs = topological_sort(dist.provider_specs.values())
@ -265,7 +265,7 @@ def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]:
provider_config = provider_configs[api.value]
if isinstance(provider_spec, RemoteProviderSpec):
impls[api] = instantiate_client(
provider_spec, provider_spec.base_url.rstrip("/")
provider_spec, provider_config.base_url.rstrip("/")
)
else:
deps = {api: impls[api] for api in provider_spec.api_dependencies}
@ -275,16 +275,15 @@ def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]:
return impls
def main(
dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False
):
dist = resolve_distribution(dist_name)
if dist is None:
raise ValueError(f"Could not find distribution {dist_name}")
def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
with open(yaml_config, "r") as fp:
config = yaml.safe_load(fp)
spec = config["spec"]
dist = resolve_distribution_spec(spec)
if dist is None:
raise ValueError(f"Could not find distribution specification `{spec}`")
app = FastAPI()
all_endpoints = api_endpoints()
@ -293,14 +292,15 @@ def main(
for provider_spec in dist.provider_specs.values():
api = provider_spec.api
endpoints = all_endpoints[api]
impl = impls[api]
if isinstance(provider_spec, RemoteProviderSpec):
for endpoint in endpoints:
url = provider_spec.base_url.rstrip("/") + endpoint.route
url = impl.base_url + endpoint.route
getattr(app, endpoint.method)(endpoint.route)(
create_dynamic_passthrough(url)
)
else:
impl = impls[api]
for endpoint in endpoints:
if not hasattr(impl, endpoint.name):
# ideally this should be a typing violation already