mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-09 13:14:39 +00:00
Rename Distribution -> DistributionSpec, simplify RemoteProviders
This commit is contained in:
parent
0a67f3d3e6
commit
7cc0445517
9 changed files with 181 additions and 147 deletions
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from strong_typing.schema import json_schema_type
|
||||
|
@ -29,6 +29,10 @@ class ApiEndpoint(BaseModel):
|
|||
class ProviderSpec(BaseModel):
|
||||
api: Api
|
||||
provider_id: str
|
||||
config_class: str = Field(
|
||||
...,
|
||||
description="Fully-qualified classname of the config for this provider",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -45,23 +49,21 @@ Fully-qualified name of the module to import. The module is expected to have:
|
|||
- `get_provider_impl(config, deps)`: returns the local implementation
|
||||
""",
|
||||
)
|
||||
config_class: str = Field(
|
||||
...,
|
||||
description="Fully-qualified classname of the config for this provider",
|
||||
)
|
||||
api_dependencies: List[Api] = Field(
|
||||
default_factory=list,
|
||||
description="Higher-level API surfaces may depend on other providers to provide their functionality",
|
||||
)
|
||||
|
||||
|
||||
class RemoteProviderConfig(BaseModel):
|
||||
base_url: str = Field(..., description="The base URL for the llama stack provider")
|
||||
api_key: Optional[str] = Field(
|
||||
..., description="API key, if needed, for the provider"
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RemoteProviderSpec(ProviderSpec):
|
||||
base_url: str = Field(..., description="The base URL for the llama stack provider")
|
||||
headers: Dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="Headers (e.g., authorization) to send with the request",
|
||||
)
|
||||
module: str = Field(
|
||||
...,
|
||||
description="""
|
||||
|
@ -69,10 +71,12 @@ Fully-qualified name of the module to import. The module is expected to have:
|
|||
- `get_client_impl(base_url)`: returns a client which can be used to call the remote implementation
|
||||
""",
|
||||
)
|
||||
config_class: str = "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
|
||||
|
||||
|
||||
class Distribution(BaseModel):
|
||||
name: str
|
||||
@json_schema_type
|
||||
class DistributionSpec(BaseModel):
|
||||
spec_id: str
|
||||
description: str
|
||||
|
||||
provider_specs: Dict[Api, ProviderSpec] = Field(
|
||||
|
@ -84,3 +88,12 @@ class Distribution(BaseModel):
|
|||
default_factory=list,
|
||||
description="Additional pip packages beyond those required by the providers",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class InstalledDistribution(BaseModel):
|
||||
"""References to a installed / configured DistributionSpec"""
|
||||
|
||||
name: str
|
||||
spec_id: str
|
||||
# This is the class which represents the configs written by `configure`
|
||||
|
|
|
@ -8,13 +8,22 @@ import inspect
|
|||
from typing import Dict, List
|
||||
|
||||
from llama_toolchain.agentic_system.api.endpoints import AgenticSystem
|
||||
from llama_toolchain.agentic_system.providers import available_agentic_system_providers
|
||||
from llama_toolchain.inference.api.endpoints import Inference
|
||||
from llama_toolchain.inference.providers import available_inference_providers
|
||||
from llama_toolchain.safety.api.endpoints import Safety
|
||||
from llama_toolchain.safety.providers import available_safety_providers
|
||||
|
||||
from .datatypes import Api, ApiEndpoint, Distribution, InlineProviderSpec
|
||||
from .datatypes import (
|
||||
Api,
|
||||
ApiEndpoint,
|
||||
DistributionSpec,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
)
|
||||
|
||||
|
||||
def distribution_dependencies(distribution: Distribution) -> List[str]:
|
||||
def distribution_dependencies(distribution: DistributionSpec) -> List[str]:
|
||||
# only consider InlineProviderSpecs when calculating dependencies
|
||||
return [
|
||||
dep
|
||||
|
@ -51,3 +60,19 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
|||
apis[api] = endpoints
|
||||
|
||||
return apis
|
||||
|
||||
|
||||
def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
|
||||
inference_providers_by_id = {
|
||||
a.provider_id: a for a in available_inference_providers()
|
||||
}
|
||||
safety_providers_by_id = {a.provider_id: a for a in available_safety_providers()}
|
||||
agentic_system_providers_by_id = {
|
||||
a.provider_id: a for a in available_agentic_system_providers()
|
||||
}
|
||||
|
||||
return {
|
||||
Api.inference: inference_providers_by_id,
|
||||
Api.safety: safety_providers_by_id,
|
||||
Api.agentic_system: agentic_system_providers_by_id,
|
||||
}
|
||||
|
|
|
@ -7,12 +7,8 @@
|
|||
from functools import lru_cache
|
||||
from typing import List, Optional
|
||||
|
||||
from llama_toolchain.agentic_system.providers import available_agentic_system_providers
|
||||
|
||||
from llama_toolchain.inference.providers import available_inference_providers
|
||||
from llama_toolchain.safety.providers import available_safety_providers
|
||||
|
||||
from .datatypes import Api, Distribution, RemoteProviderSpec
|
||||
from .datatypes import Api, DistributionSpec, RemoteProviderSpec
|
||||
from .distribution import api_providers
|
||||
|
||||
# This is currently duplicated from `requirements.txt` with a few minor changes
|
||||
# dev-dependencies like "ufmt" etc. are nuked. A few specialized dependencies
|
||||
|
@ -49,39 +45,30 @@ def client_module(api: Api) -> str:
|
|||
return f"llama_toolchain.{api.value}.client"
|
||||
|
||||
|
||||
def remote(api: Api, port: int) -> RemoteProviderSpec:
|
||||
def remote_spec(api: Api) -> RemoteProviderSpec:
|
||||
return RemoteProviderSpec(
|
||||
api=api,
|
||||
provider_id=f"{api.value}-remote",
|
||||
base_url=f"http://localhost:{port}",
|
||||
module=client_module(api),
|
||||
)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def available_distributions() -> List[Distribution]:
|
||||
inference_providers_by_id = {
|
||||
a.provider_id: a for a in available_inference_providers()
|
||||
}
|
||||
safety_providers_by_id = {a.provider_id: a for a in available_safety_providers()}
|
||||
agentic_system_providers_by_id = {
|
||||
a.provider_id: a for a in available_agentic_system_providers()
|
||||
}
|
||||
|
||||
def available_distribution_specs() -> List[DistributionSpec]:
|
||||
providers = api_providers()
|
||||
return [
|
||||
Distribution(
|
||||
name="local-inline",
|
||||
DistributionSpec(
|
||||
spec_id="inline",
|
||||
description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
|
||||
additional_pip_packages=COMMON_DEPENDENCIES,
|
||||
provider_specs={
|
||||
Api.inference: inference_providers_by_id["meta-reference"],
|
||||
Api.safety: safety_providers_by_id["meta-reference"],
|
||||
Api.agentic_system: agentic_system_providers_by_id["meta-reference"],
|
||||
Api.inference: providers[Api.inference]["meta-reference"],
|
||||
Api.safety: providers[Api.safety]["meta-reference"],
|
||||
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
|
||||
},
|
||||
),
|
||||
# NOTE: this hardcodes the ports to which things point to
|
||||
Distribution(
|
||||
name="full-remote",
|
||||
DistributionSpec(
|
||||
spec_id="remote",
|
||||
description="Point to remote services for all llama stack APIs",
|
||||
additional_pip_packages=[
|
||||
"python-dotenv",
|
||||
|
@ -97,28 +84,24 @@ def available_distributions() -> List[Distribution]:
|
|||
"pydantic_core==2.18.2",
|
||||
"uvicorn",
|
||||
],
|
||||
provider_specs={
|
||||
Api.inference: remote(Api.inference, 5001),
|
||||
Api.safety: remote(Api.safety, 5001),
|
||||
Api.agentic_system: remote(Api.agentic_system, 5001),
|
||||
},
|
||||
provider_specs={x: remote_spec(x) for x in providers},
|
||||
),
|
||||
Distribution(
|
||||
name="local-ollama",
|
||||
DistributionSpec(
|
||||
spec_id="ollama-inline",
|
||||
description="Like local-source, but use ollama for running LLM inference",
|
||||
additional_pip_packages=COMMON_DEPENDENCIES,
|
||||
provider_specs={
|
||||
Api.inference: inference_providers_by_id["meta-ollama"],
|
||||
Api.safety: safety_providers_by_id["meta-reference"],
|
||||
Api.agentic_system: agentic_system_providers_by_id["meta-reference"],
|
||||
Api.inference: providers[Api.inference]["meta-ollama"],
|
||||
Api.safety: providers[Api.safety]["meta-reference"],
|
||||
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def resolve_distribution(name: str) -> Optional[Distribution]:
|
||||
for dist in available_distributions():
|
||||
if dist.name == name:
|
||||
return dist
|
||||
def resolve_distribution_spec(spec_id: str) -> Optional[DistributionSpec]:
|
||||
for spec in available_distribution_specs():
|
||||
if spec.spec_id == spec_id:
|
||||
return spec
|
||||
return None
|
||||
|
|
|
@ -36,11 +36,11 @@ from fastapi.routing import APIRoute
|
|||
from pydantic import BaseModel, ValidationError
|
||||
from termcolor import cprint
|
||||
|
||||
from .datatypes import Api, Distribution, ProviderSpec, RemoteProviderSpec
|
||||
from .datatypes import Api, DistributionSpec, ProviderSpec, RemoteProviderSpec
|
||||
from .distribution import api_endpoints
|
||||
from .dynamic import instantiate_client, instantiate_provider
|
||||
|
||||
from .registry import resolve_distribution
|
||||
from .registry import resolve_distribution_spec
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
@ -250,7 +250,7 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
|
|||
return [by_id[x] for x in stack]
|
||||
|
||||
|
||||
def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]:
|
||||
def resolve_impls(dist: DistributionSpec, config: Dict[str, Any]) -> Dict[Api, Any]:
|
||||
provider_configs = config["providers"]
|
||||
provider_specs = topological_sort(dist.provider_specs.values())
|
||||
|
||||
|
@ -265,7 +265,7 @@ def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]:
|
|||
provider_config = provider_configs[api.value]
|
||||
if isinstance(provider_spec, RemoteProviderSpec):
|
||||
impls[api] = instantiate_client(
|
||||
provider_spec, provider_spec.base_url.rstrip("/")
|
||||
provider_spec, provider_config.base_url.rstrip("/")
|
||||
)
|
||||
else:
|
||||
deps = {api: impls[api] for api in provider_spec.api_dependencies}
|
||||
|
@ -275,16 +275,15 @@ def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]:
|
|||
return impls
|
||||
|
||||
|
||||
def main(
|
||||
dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False
|
||||
):
|
||||
dist = resolve_distribution(dist_name)
|
||||
if dist is None:
|
||||
raise ValueError(f"Could not find distribution {dist_name}")
|
||||
|
||||
def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
||||
with open(yaml_config, "r") as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
|
||||
spec = config["spec"]
|
||||
dist = resolve_distribution_spec(spec)
|
||||
if dist is None:
|
||||
raise ValueError(f"Could not find distribution specification `{spec}`")
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
all_endpoints = api_endpoints()
|
||||
|
@ -293,14 +292,15 @@ def main(
|
|||
for provider_spec in dist.provider_specs.values():
|
||||
api = provider_spec.api
|
||||
endpoints = all_endpoints[api]
|
||||
impl = impls[api]
|
||||
|
||||
if isinstance(provider_spec, RemoteProviderSpec):
|
||||
for endpoint in endpoints:
|
||||
url = provider_spec.base_url.rstrip("/") + endpoint.route
|
||||
url = impl.base_url + endpoint.route
|
||||
getattr(app, endpoint.method)(endpoint.route)(
|
||||
create_dynamic_passthrough(url)
|
||||
)
|
||||
else:
|
||||
impl = impls[api]
|
||||
for endpoint in endpoints:
|
||||
if not hasattr(impl, endpoint.name):
|
||||
# ideally this should be a typing violation already
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue