mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Adapter -> Provider
This commit is contained in:
parent
db3e6dda07
commit
65a9e40174
15 changed files with 119 additions and 110 deletions
|
@ -7,7 +7,7 @@
|
||||||
|
|
||||||
from llama_toolchain.agentic_system.api import AgenticSystem
|
from llama_toolchain.agentic_system.api import AgenticSystem
|
||||||
|
|
||||||
from llama_toolchain.distribution.datatypes import Adapter, Api
|
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
|
||||||
from llama_toolchain.inference.api import Inference
|
from llama_toolchain.inference.api import Inference
|
||||||
from llama_toolchain.safety.api import Safety
|
from llama_toolchain.safety.api import Safety
|
||||||
|
|
||||||
|
@ -44,7 +44,7 @@ logger = logging.getLogger()
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
async def get_adapter_impl(config: AgenticSystemConfig, deps: Dict[Api, Adapter]):
|
async def get_provider_impl(config: AgenticSystemConfig, deps: Dict[Api, ProviderSpec]):
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
config, AgenticSystemConfig
|
config, AgenticSystemConfig
|
||||||
), f"Unexpected config type: {type(config)}"
|
), f"Unexpected config type: {type(config)}"
|
||||||
|
|
|
@ -6,14 +6,14 @@
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from llama_toolchain.distribution.datatypes import Adapter, Api, SourceAdapter
|
from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
|
||||||
|
|
||||||
|
|
||||||
def available_agentic_system_adapters() -> List[Adapter]:
|
def available_agentic_system_providers() -> List[ProviderSpec]:
|
||||||
return [
|
return [
|
||||||
SourceAdapter(
|
InlineProviderSpec(
|
||||||
api=Api.agentic_system,
|
api=Api.agentic_system,
|
||||||
adapter_id="meta-reference",
|
provider_id="meta-reference",
|
||||||
pip_packages=[
|
pip_packages=[
|
||||||
"codeshield",
|
"codeshield",
|
||||||
"torch",
|
"torch",
|
||||||
|
@ -21,7 +21,7 @@ def available_agentic_system_adapters() -> List[Adapter]:
|
||||||
],
|
],
|
||||||
module="llama_toolchain.agentic_system.agentic_system",
|
module="llama_toolchain.agentic_system.agentic_system",
|
||||||
config_class="llama_toolchain.agentic_system.config.AgenticSystemConfig",
|
config_class="llama_toolchain.agentic_system.config.AgenticSystemConfig",
|
||||||
adapter_dependencies=[
|
api_dependencies=[
|
||||||
Api.inference,
|
Api.inference,
|
||||||
Api.safety,
|
Api.safety,
|
||||||
],
|
],
|
|
@ -63,7 +63,7 @@ def configure_llama_distribution(dist: "Distribution", conda_env: str):
|
||||||
from llama_toolchain.common.exec import run_command
|
from llama_toolchain.common.exec import run_command
|
||||||
from llama_toolchain.common.prompt_for_config import prompt_for_config
|
from llama_toolchain.common.prompt_for_config import prompt_for_config
|
||||||
from llama_toolchain.common.serialize import EnumEncoder
|
from llama_toolchain.common.serialize import EnumEncoder
|
||||||
from llama_toolchain.distribution.datatypes import PassthroughApiAdapter
|
from llama_toolchain.distribution.datatypes import RemoteProviderSpec
|
||||||
from llama_toolchain.distribution.dynamic import instantiate_class_type
|
from llama_toolchain.distribution.dynamic import instantiate_class_type
|
||||||
|
|
||||||
python_exe = run_command(shlex.split("which python"))
|
python_exe = run_command(shlex.split("which python"))
|
||||||
|
@ -84,28 +84,28 @@ def configure_llama_distribution(dist: "Distribution", conda_env: str):
|
||||||
with open(config_path, "r") as fp:
|
with open(config_path, "r") as fp:
|
||||||
existing_config = yaml.safe_load(fp)
|
existing_config = yaml.safe_load(fp)
|
||||||
|
|
||||||
adapter_configs = {}
|
provider_configs = {}
|
||||||
for api, adapter in dist.adapters.items():
|
for api, provider_spec in dist.provider_specs.items():
|
||||||
if isinstance(adapter, PassthroughApiAdapter):
|
if isinstance(provider_spec, RemoteProviderSpec):
|
||||||
adapter_configs[api.value] = adapter.dict()
|
provider_configs[api.value] = provider_spec.dict()
|
||||||
else:
|
else:
|
||||||
cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"])
|
cprint(f"Configuring API surface: {api.value}", "white", attrs=["bold"])
|
||||||
config_type = instantiate_class_type(adapter.config_class)
|
config_type = instantiate_class_type(provider_spec.config_class)
|
||||||
config = prompt_for_config(
|
config = prompt_for_config(
|
||||||
config_type,
|
config_type,
|
||||||
(
|
(
|
||||||
config_type(**existing_config["adapters"][api.value])
|
config_type(**existing_config["providers"][api.value])
|
||||||
if existing_config and api.value in existing_config["adapters"]
|
if existing_config and api.value in existing_config["providers"]
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
adapter_configs[api.value] = {
|
provider_configs[api.value] = {
|
||||||
"adapter_id": adapter.adapter_id,
|
"provider_id": provider_spec.provider_id,
|
||||||
**config.dict(),
|
**config.dict(),
|
||||||
}
|
}
|
||||||
|
|
||||||
dist_config = {
|
dist_config = {
|
||||||
"adapters": adapter_configs,
|
"providers": provider_configs,
|
||||||
"conda_env": conda_env,
|
"conda_env": conda_env,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@ class DistributionCreate(Subcommand):
|
||||||
required=True,
|
required=True,
|
||||||
)
|
)
|
||||||
# for each Api the user wants to support, we should
|
# for each Api the user wants to support, we should
|
||||||
# get the list of available adapters, ask which one the user
|
# get the list of available providers, ask which one the user
|
||||||
# wants to pick and then ask for their configuration.
|
# wants to pick and then ask for their configuration.
|
||||||
|
|
||||||
def _run_distribution_create_cmd(self, args: argparse.Namespace) -> None:
|
def _run_distribution_create_cmd(self, args: argparse.Namespace) -> None:
|
||||||
|
|
|
@ -33,17 +33,17 @@ class DistributionList(Subcommand):
|
||||||
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
|
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
|
||||||
headers = [
|
headers = [
|
||||||
"Name",
|
"Name",
|
||||||
"Adapters",
|
"ProviderSpecs",
|
||||||
"Description",
|
"Description",
|
||||||
]
|
]
|
||||||
|
|
||||||
rows = []
|
rows = []
|
||||||
for dist in available_distributions():
|
for dist in available_distributions():
|
||||||
adapters = {k.value: v.adapter_id for k, v in dist.adapters.items()}
|
providers = {k.value: v.provider_id for k, v in dist.provider_specs.items()}
|
||||||
rows.append(
|
rows.append(
|
||||||
[
|
[
|
||||||
dist.name,
|
dist.name,
|
||||||
json.dumps(adapters, indent=2),
|
json.dumps(providers, indent=2),
|
||||||
dist.description,
|
dist.description,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -26,13 +26,13 @@ class ApiEndpoint(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class Adapter(BaseModel):
|
class ProviderSpec(BaseModel):
|
||||||
api: Api
|
api: Api
|
||||||
adapter_id: str
|
provider_id: str
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class SourceAdapter(Adapter):
|
class InlineProviderSpec(ProviderSpec):
|
||||||
pip_packages: List[str] = Field(
|
pip_packages: List[str] = Field(
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
description="The pip dependencies needed for this implementation",
|
description="The pip dependencies needed for this implementation",
|
||||||
|
@ -42,21 +42,21 @@ class SourceAdapter(Adapter):
|
||||||
description="""
|
description="""
|
||||||
Fully-qualified name of the module to import. The module is expected to have:
|
Fully-qualified name of the module to import. The module is expected to have:
|
||||||
|
|
||||||
- `get_adapter_impl(config, deps)`: returns the local implementation
|
- `get_provider_impl(config, deps)`: returns the local implementation
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
config_class: str = Field(
|
config_class: str = Field(
|
||||||
...,
|
...,
|
||||||
description="Fully-qualified classname of the config for this adapter",
|
description="Fully-qualified classname of the config for this provider",
|
||||||
)
|
)
|
||||||
adapter_dependencies: List[Api] = Field(
|
api_dependencies: List[Api] = Field(
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
description="Higher-level API surfaces may depend on other adapters to provide their functionality",
|
description="Higher-level API surfaces may depend on other providers to provide their functionality",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class PassthroughApiAdapter(Adapter):
|
class RemoteProviderSpec(ProviderSpec):
|
||||||
base_url: str = Field(..., description="The base URL for the llama stack provider")
|
base_url: str = Field(..., description="The base URL for the llama stack provider")
|
||||||
headers: Dict[str, str] = Field(
|
headers: Dict[str, str] = Field(
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
|
@ -75,12 +75,12 @@ class Distribution(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
|
|
||||||
adapters: Dict[Api, Adapter] = Field(
|
provider_specs: Dict[Api, ProviderSpec] = Field(
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
description="The API surfaces provided by this distribution",
|
description="Provider specifications for each of the APIs provided by this distribution",
|
||||||
)
|
)
|
||||||
|
|
||||||
additional_pip_packages: List[str] = Field(
|
additional_pip_packages: List[str] = Field(
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
description="Additional pip packages beyond those required by the adapters",
|
description="Additional pip packages beyond those required by the providers",
|
||||||
)
|
)
|
||||||
|
|
|
@ -11,16 +11,16 @@ from llama_toolchain.agentic_system.api.endpoints import AgenticSystem
|
||||||
from llama_toolchain.inference.api.endpoints import Inference
|
from llama_toolchain.inference.api.endpoints import Inference
|
||||||
from llama_toolchain.safety.api.endpoints import Safety
|
from llama_toolchain.safety.api.endpoints import Safety
|
||||||
|
|
||||||
from .datatypes import Api, ApiEndpoint, Distribution, SourceAdapter
|
from .datatypes import Api, ApiEndpoint, Distribution, InlineProviderSpec
|
||||||
|
|
||||||
|
|
||||||
def distribution_dependencies(distribution: Distribution) -> List[str]:
|
def distribution_dependencies(distribution: Distribution) -> List[str]:
|
||||||
# only consider SourceAdapters when calculating dependencies
|
# only consider InlineProviderSpecs when calculating dependencies
|
||||||
return [
|
return [
|
||||||
dep
|
dep
|
||||||
for adapter in distribution.adapters.values()
|
for provider_spec in distribution.provider_specs.values()
|
||||||
if isinstance(adapter, SourceAdapter)
|
if isinstance(provider_spec, InlineProviderSpec)
|
||||||
for dep in adapter.pip_packages
|
for dep in provider_spec.pip_packages
|
||||||
] + distribution.additional_pip_packages
|
] + distribution.additional_pip_packages
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ import asyncio
|
||||||
import importlib
|
import importlib
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
from .datatypes import Adapter, PassthroughApiAdapter, SourceAdapter
|
from .datatypes import InlineProviderSpec, ProviderSpec, RemoteProviderSpec
|
||||||
|
|
||||||
|
|
||||||
def instantiate_class_type(fully_qualified_name):
|
def instantiate_class_type(fully_qualified_name):
|
||||||
|
@ -18,17 +18,19 @@ def instantiate_class_type(fully_qualified_name):
|
||||||
|
|
||||||
|
|
||||||
# returns a class implementing the protocol corresponding to the Api
|
# returns a class implementing the protocol corresponding to the Api
|
||||||
def instantiate_adapter(
|
def instantiate_provider(
|
||||||
adapter: SourceAdapter, adapter_config: Dict[str, Any], deps: Dict[str, Adapter]
|
provider_spec: InlineProviderSpec,
|
||||||
|
provider_config: Dict[str, Any],
|
||||||
|
deps: Dict[str, ProviderSpec],
|
||||||
):
|
):
|
||||||
module = importlib.import_module(adapter.module)
|
module = importlib.import_module(provider_spec.module)
|
||||||
|
|
||||||
config_type = instantiate_class_type(adapter.config_class)
|
config_type = instantiate_class_type(provider_spec.config_class)
|
||||||
config = config_type(**adapter_config)
|
config = config_type(**provider_config)
|
||||||
return asyncio.run(module.get_adapter_impl(config, deps))
|
return asyncio.run(module.get_provider_impl(config, deps))
|
||||||
|
|
||||||
|
|
||||||
def instantiate_client(adapter: PassthroughApiAdapter, base_url: str):
|
def instantiate_client(provider_spec: RemoteProviderSpec, base_url: str):
|
||||||
module = importlib.import_module(adapter.module)
|
module = importlib.import_module(provider_spec.module)
|
||||||
|
|
||||||
return asyncio.run(module.get_client_impl(base_url))
|
return asyncio.run(module.get_client_impl(base_url))
|
||||||
|
|
|
@ -7,12 +7,12 @@
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from llama_toolchain.agentic_system.adapters import available_agentic_system_adapters
|
from llama_toolchain.agentic_system.providers import available_agentic_system_providers
|
||||||
|
|
||||||
from llama_toolchain.inference.adapters import available_inference_adapters
|
from llama_toolchain.inference.providers import available_inference_providers
|
||||||
from llama_toolchain.safety.adapters import available_safety_adapters
|
from llama_toolchain.safety.providers import available_safety_providers
|
||||||
|
|
||||||
from .datatypes import Api, Distribution, PassthroughApiAdapter
|
from .datatypes import Api, Distribution, RemoteProviderSpec
|
||||||
|
|
||||||
# This is currently duplicated from `requirements.txt` with a few minor changes
|
# This is currently duplicated from `requirements.txt` with a few minor changes
|
||||||
# dev-dependencies like "ufmt" etc. are nuked. A few specialized dependencies
|
# dev-dependencies like "ufmt" etc. are nuked. A few specialized dependencies
|
||||||
|
@ -49,10 +49,10 @@ def client_module(api: Api) -> str:
|
||||||
return f"llama_toolchain.{api.value}.client"
|
return f"llama_toolchain.{api.value}.client"
|
||||||
|
|
||||||
|
|
||||||
def passthrough(api: Api, port: int) -> PassthroughApiAdapter:
|
def remote(api: Api, port: int) -> RemoteProviderSpec:
|
||||||
return PassthroughApiAdapter(
|
return RemoteProviderSpec(
|
||||||
api=api,
|
api=api,
|
||||||
adapter_id=f"{api.value}-passthrough",
|
provider_id=f"{api.value}-remote",
|
||||||
base_url=f"http://localhost:{port}",
|
base_url=f"http://localhost:{port}",
|
||||||
module=client_module(api),
|
module=client_module(api),
|
||||||
)
|
)
|
||||||
|
@ -60,25 +60,28 @@ def passthrough(api: Api, port: int) -> PassthroughApiAdapter:
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def available_distributions() -> List[Distribution]:
|
def available_distributions() -> List[Distribution]:
|
||||||
inference_adapters_by_id = {a.adapter_id: a for a in available_inference_adapters()}
|
inference_providers_by_id = {
|
||||||
safety_adapters_by_id = {a.adapter_id: a for a in available_safety_adapters()}
|
a.provider_id: a for a in available_inference_providers()
|
||||||
agentic_system_adapters_by_id = {
|
}
|
||||||
a.adapter_id: a for a in available_agentic_system_adapters()
|
safety_providers_by_id = {a.provider_id: a for a in available_safety_providers()}
|
||||||
|
agentic_system_providers_by_id = {
|
||||||
|
a.provider_id: a for a in available_agentic_system_providers()
|
||||||
}
|
}
|
||||||
|
|
||||||
return [
|
return [
|
||||||
Distribution(
|
Distribution(
|
||||||
name="local-source",
|
name="local-inline",
|
||||||
description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
|
description="Use code from `llama_toolchain` itself to serve all llama stack APIs",
|
||||||
additional_pip_packages=COMMON_DEPENDENCIES,
|
additional_pip_packages=COMMON_DEPENDENCIES,
|
||||||
adapters={
|
provider_specs={
|
||||||
Api.inference: inference_adapters_by_id["meta-reference"],
|
Api.inference: inference_providers_by_id["meta-reference"],
|
||||||
Api.safety: safety_adapters_by_id["meta-reference"],
|
Api.safety: safety_providers_by_id["meta-reference"],
|
||||||
Api.agentic_system: agentic_system_adapters_by_id["meta-reference"],
|
Api.agentic_system: agentic_system_providers_by_id["meta-reference"],
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
# NOTE: this hardcodes the ports to which things point to
|
||||||
Distribution(
|
Distribution(
|
||||||
name="full-passthrough",
|
name="full-remote",
|
||||||
description="Point to remote services for all llama stack APIs",
|
description="Point to remote services for all llama stack APIs",
|
||||||
additional_pip_packages=[
|
additional_pip_packages=[
|
||||||
"python-dotenv",
|
"python-dotenv",
|
||||||
|
@ -94,20 +97,20 @@ def available_distributions() -> List[Distribution]:
|
||||||
"pydantic_core==2.18.2",
|
"pydantic_core==2.18.2",
|
||||||
"uvicorn",
|
"uvicorn",
|
||||||
],
|
],
|
||||||
adapters={
|
provider_specs={
|
||||||
Api.inference: passthrough(Api.inference, 5001),
|
Api.inference: remote(Api.inference, 5001),
|
||||||
Api.safety: passthrough(Api.safety, 5001),
|
Api.safety: remote(Api.safety, 5001),
|
||||||
Api.agentic_system: passthrough(Api.agentic_system, 5001),
|
Api.agentic_system: remote(Api.agentic_system, 5001),
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
Distribution(
|
Distribution(
|
||||||
name="local-ollama",
|
name="local-ollama",
|
||||||
description="Like local-source, but use ollama for running LLM inference",
|
description="Like local-source, but use ollama for running LLM inference",
|
||||||
additional_pip_packages=COMMON_DEPENDENCIES,
|
additional_pip_packages=COMMON_DEPENDENCIES,
|
||||||
adapters={
|
provider_specs={
|
||||||
Api.inference: inference_adapters_by_id["meta-ollama"],
|
Api.inference: inference_providers_by_id["meta-ollama"],
|
||||||
Api.safety: safety_adapters_by_id["meta-reference"],
|
Api.safety: safety_providers_by_id["meta-reference"],
|
||||||
Api.agentic_system: agentic_system_adapters_by_id["meta-reference"],
|
Api.agentic_system: agentic_system_providers_by_id["meta-reference"],
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -36,9 +36,9 @@ from fastapi.routing import APIRoute
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
from .datatypes import Adapter, Api, Distribution, PassthroughApiAdapter
|
from .datatypes import Api, Distribution, ProviderSpec, RemoteProviderSpec
|
||||||
from .distribution import api_endpoints
|
from .distribution import api_endpoints
|
||||||
from .dynamic import instantiate_adapter, instantiate_client
|
from .dynamic import instantiate_client, instantiate_provider
|
||||||
|
|
||||||
from .registry import resolve_distribution
|
from .registry import resolve_distribution
|
||||||
|
|
||||||
|
@ -226,15 +226,15 @@ def create_dynamic_typed_route(func: Any):
|
||||||
return endpoint
|
return endpoint
|
||||||
|
|
||||||
|
|
||||||
def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
|
def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
|
||||||
|
|
||||||
by_id = {x.api: x for x in adapters}
|
by_id = {x.api: x for x in providers}
|
||||||
|
|
||||||
def dfs(a: Adapter, visited: Set[Api], stack: List[Api]):
|
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
|
||||||
visited.add(a.api)
|
visited.add(a.api)
|
||||||
|
|
||||||
if not isinstance(a, PassthroughApiAdapter):
|
if not isinstance(a, RemoteProviderSpec):
|
||||||
for api in a.adapter_dependencies:
|
for api in a.api_dependencies:
|
||||||
if api not in visited:
|
if api not in visited:
|
||||||
dfs(by_id[api], visited, stack)
|
dfs(by_id[api], visited, stack)
|
||||||
|
|
||||||
|
@ -243,7 +243,7 @@ def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
|
||||||
visited = set()
|
visited = set()
|
||||||
stack = []
|
stack = []
|
||||||
|
|
||||||
for a in adapters:
|
for a in providers:
|
||||||
if a.api not in visited:
|
if a.api not in visited:
|
||||||
dfs(a, visited, stack)
|
dfs(a, visited, stack)
|
||||||
|
|
||||||
|
@ -251,23 +251,25 @@ def topological_sort(adapters: List[Adapter]) -> List[Adapter]:
|
||||||
|
|
||||||
|
|
||||||
def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]:
|
def resolve_impls(dist: Distribution, config: Dict[str, Any]) -> Dict[Api, Any]:
|
||||||
adapter_configs = config["adapters"]
|
provider_configs = config["providers"]
|
||||||
adapters = topological_sort(dist.adapters.values())
|
provider_specs = topological_sort(dist.provider_specs.values())
|
||||||
|
|
||||||
impls = {}
|
impls = {}
|
||||||
for adapter in adapters:
|
for provider_spec in provider_specs:
|
||||||
api = adapter.api
|
api = provider_spec.api
|
||||||
if api.value not in adapter_configs:
|
if api.value not in provider_configs:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Could not find adapter config for {api}. Please add it to the config"
|
f"Could not find provider_spec config for {api}. Please add it to the config"
|
||||||
)
|
)
|
||||||
|
|
||||||
adapter_config = adapter_configs[api.value]
|
provider_config = provider_configs[api.value]
|
||||||
if isinstance(adapter, PassthroughApiAdapter):
|
if isinstance(provider_spec, RemoteProviderSpec):
|
||||||
impls[api] = instantiate_client(adapter, adapter.base_url.rstrip("/"))
|
impls[api] = instantiate_client(
|
||||||
|
provider_spec, provider_spec.base_url.rstrip("/")
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
deps = {api: impls[api] for api in adapter.adapter_dependencies}
|
deps = {api: impls[api] for api in provider_spec.api_dependencies}
|
||||||
impl = instantiate_adapter(adapter, adapter_config, deps)
|
impl = instantiate_provider(provider_spec, provider_config, deps)
|
||||||
impls[api] = impl
|
impls[api] = impl
|
||||||
|
|
||||||
return impls
|
return impls
|
||||||
|
@ -288,12 +290,12 @@ def main(
|
||||||
all_endpoints = api_endpoints()
|
all_endpoints = api_endpoints()
|
||||||
impls = resolve_impls(dist, config)
|
impls = resolve_impls(dist, config)
|
||||||
|
|
||||||
for adapter in dist.adapters.values():
|
for provider_spec in dist.provider_specs.values():
|
||||||
api = adapter.api
|
api = provider_spec.api
|
||||||
endpoints = all_endpoints[api]
|
endpoints = all_endpoints[api]
|
||||||
if isinstance(adapter, PassthroughApiAdapter):
|
if isinstance(provider_spec, RemoteProviderSpec):
|
||||||
for endpoint in endpoints:
|
for endpoint in endpoints:
|
||||||
url = adapter.base_url.rstrip("/") + endpoint.route
|
url = provider_spec.base_url.rstrip("/") + endpoint.route
|
||||||
getattr(app, endpoint.method)(endpoint.route)(
|
getattr(app, endpoint.method)(endpoint.route)(
|
||||||
create_dynamic_passthrough(url)
|
create_dynamic_passthrough(url)
|
||||||
)
|
)
|
||||||
|
|
|
@ -11,7 +11,7 @@ from typing import AsyncIterator, Dict, Union
|
||||||
from llama_models.llama3_1.api.datatypes import StopReason
|
from llama_models.llama3_1.api.datatypes import StopReason
|
||||||
from llama_models.sku_list import resolve_model
|
from llama_models.sku_list import resolve_model
|
||||||
|
|
||||||
from llama_toolchain.distribution.datatypes import Adapter, Api
|
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
|
||||||
|
|
||||||
from .api.config import MetaReferenceImplConfig
|
from .api.config import MetaReferenceImplConfig
|
||||||
from .api.datatypes import (
|
from .api.datatypes import (
|
||||||
|
@ -29,7 +29,9 @@ from .api.endpoints import (
|
||||||
from .model_parallel import LlamaModelParallelGenerator
|
from .model_parallel import LlamaModelParallelGenerator
|
||||||
|
|
||||||
|
|
||||||
async def get_adapter_impl(config: MetaReferenceImplConfig, _deps: Dict[Api, Adapter]):
|
async def get_provider_impl(
|
||||||
|
config: MetaReferenceImplConfig, _deps: Dict[Api, ProviderSpec]
|
||||||
|
):
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
config, MetaReferenceImplConfig
|
config, MetaReferenceImplConfig
|
||||||
), f"Unexpected config type: {type(config)}"
|
), f"Unexpected config type: {type(config)}"
|
||||||
|
|
|
@ -37,7 +37,7 @@ from .api.endpoints import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_adapter_impl(config: OllamaImplConfig) -> Inference:
|
def get_provider_impl(config: OllamaImplConfig) -> Inference:
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
config, OllamaImplConfig
|
config, OllamaImplConfig
|
||||||
), f"Unexpected config type: {type(config)}"
|
), f"Unexpected config type: {type(config)}"
|
||||||
|
|
|
@ -6,14 +6,14 @@
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from llama_toolchain.distribution.datatypes import Adapter, Api, SourceAdapter
|
from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
|
||||||
|
|
||||||
|
|
||||||
def available_inference_adapters() -> List[Adapter]:
|
def available_inference_providers() -> List[ProviderSpec]:
|
||||||
return [
|
return [
|
||||||
SourceAdapter(
|
InlineProviderSpec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter_id="meta-reference",
|
provider_id="meta-reference",
|
||||||
pip_packages=[
|
pip_packages=[
|
||||||
"torch",
|
"torch",
|
||||||
"zmq",
|
"zmq",
|
||||||
|
@ -21,9 +21,9 @@ def available_inference_adapters() -> List[Adapter]:
|
||||||
module="llama_toolchain.inference.inference",
|
module="llama_toolchain.inference.inference",
|
||||||
config_class="llama_toolchain.inference.inference.MetaReferenceImplConfig",
|
config_class="llama_toolchain.inference.inference.MetaReferenceImplConfig",
|
||||||
),
|
),
|
||||||
SourceAdapter(
|
InlineProviderSpec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter_id="meta-ollama",
|
provider_id="meta-ollama",
|
||||||
pip_packages=[
|
pip_packages=[
|
||||||
"ollama",
|
"ollama",
|
||||||
],
|
],
|
|
@ -6,14 +6,14 @@
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from llama_toolchain.distribution.datatypes import Adapter, Api, SourceAdapter
|
from llama_toolchain.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
|
||||||
|
|
||||||
|
|
||||||
def available_safety_adapters() -> List[Adapter]:
|
def available_safety_providers() -> List[ProviderSpec]:
|
||||||
return [
|
return [
|
||||||
SourceAdapter(
|
InlineProviderSpec(
|
||||||
api=Api.safety,
|
api=Api.safety,
|
||||||
adapter_id="meta-reference",
|
provider_id="meta-reference",
|
||||||
pip_packages=[
|
pip_packages=[
|
||||||
"codeshield",
|
"codeshield",
|
||||||
"torch",
|
"torch",
|
|
@ -8,7 +8,7 @@ import asyncio
|
||||||
|
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from llama_toolchain.distribution.datatypes import Adapter, Api
|
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
|
||||||
|
|
||||||
from .config import SafetyConfig
|
from .config import SafetyConfig
|
||||||
from .api.endpoints import * # noqa
|
from .api.endpoints import * # noqa
|
||||||
|
@ -23,7 +23,7 @@ from .shields import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def get_adapter_impl(config: SafetyConfig, _deps: Dict[Api, Adapter]):
|
async def get_provider_impl(config: SafetyConfig, _deps: Dict[Api, ProviderSpec]):
|
||||||
assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
|
assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
|
||||||
|
|
||||||
impl = MetaReferenceSafetyImpl(config)
|
impl = MetaReferenceSafetyImpl(config)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue