mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-08 04:54:38 +00:00
Several smaller fixes to make adapters work
Also reorganized the `__init__` pattern inside providers so that configuration modules stay lightweight
This commit is contained in:
parent
2a1552a5eb
commit
45987996c4
23 changed files with 164 additions and 160 deletions
|
@ -69,7 +69,7 @@ ensure_conda_env_python310() {
|
|||
conda create -n "${env_name}" python="${python_version}" -y
|
||||
|
||||
ENVNAME="${env_name}"
|
||||
setup_cleanup_handlers
|
||||
# setup_cleanup_handlers
|
||||
fi
|
||||
|
||||
eval "$(conda shell.bash hook)"
|
||||
|
|
|
@ -36,16 +36,14 @@ class ProviderSpec(BaseModel):
|
|||
...,
|
||||
description="Fully-qualified classname of the config for this provider",
|
||||
)
|
||||
api_dependencies: List[Api] = Field(
|
||||
default_factory=list,
|
||||
description="Higher-level API surfaces may depend on other providers to provide their functionality",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AdapterSpec(BaseModel):
|
||||
"""
|
||||
If some code is needed to convert the remote responses into Llama Stack compatible
|
||||
API responses, specify the adapter here. If not specified, it indicates the remote
|
||||
as being "Llama Stack compatible"
|
||||
"""
|
||||
|
||||
adapter_id: str = Field(
|
||||
...,
|
||||
description="Unique identifier for this adapter",
|
||||
|
@ -89,11 +87,6 @@ Fully-qualified name of the module to import. The module is expected to have:
|
|||
- `get_provider_impl(config, deps)`: returns the local implementation
|
||||
""",
|
||||
)
|
||||
api_dependencies: List[Api] = Field(
|
||||
default_factory=list,
|
||||
description="Higher-level API surfaces may depend on other providers to provide their functionality",
|
||||
)
|
||||
is_adapter: bool = False
|
||||
|
||||
|
||||
class RemoteProviderConfig(BaseModel):
|
||||
|
@ -113,34 +106,41 @@ def remote_provider_id(adapter_id: str) -> str:
|
|||
|
||||
@json_schema_type
|
||||
class RemoteProviderSpec(ProviderSpec):
|
||||
provider_id: str = "remote"
|
||||
config_class: str = "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
|
||||
adapter: Optional[AdapterSpec] = Field(
|
||||
default=None,
|
||||
description="""
|
||||
If some code is needed to convert the remote responses into Llama Stack compatible
|
||||
API responses, specify the adapter here. If not specified, it indicates the remote
|
||||
as being "Llama Stack compatible"
|
||||
""",
|
||||
)
|
||||
|
||||
@property
|
||||
def module(self) -> str:
|
||||
if self.adapter:
|
||||
return self.adapter.module
|
||||
return f"llama_toolchain.{self.api.value}.client"
|
||||
|
||||
|
||||
def remote_provider_spec(api: Api) -> RemoteProviderSpec:
|
||||
return RemoteProviderSpec(api=api)
|
||||
@property
|
||||
def pip_packages(self) -> List[str]:
|
||||
if self.adapter:
|
||||
return self.adapter.pip_packages
|
||||
return []
|
||||
|
||||
|
||||
# TODO: use computed_field to avoid this wrapper
|
||||
# the @computed_field decorator
|
||||
def adapter_provider_spec(api: Api, adapter: AdapterSpec) -> InlineProviderSpec:
|
||||
# Can avoid this by using Pydantic computed_field
|
||||
def remote_provider_spec(
|
||||
api: Api, adapter: Optional[AdapterSpec] = None
|
||||
) -> RemoteProviderSpec:
|
||||
config_class = (
|
||||
adapter.config_class
|
||||
if adapter.config_class
|
||||
if adapter and adapter.config_class
|
||||
else "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
|
||||
)
|
||||
provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote"
|
||||
|
||||
return InlineProviderSpec(
|
||||
api=api,
|
||||
provider_id=remote_provider_id(adapter.adapter_id),
|
||||
pip_packages=adapter.pip_packages,
|
||||
module=adapter.module,
|
||||
config_class=config_class,
|
||||
is_adapter=True,
|
||||
return RemoteProviderSpec(
|
||||
api=api, provider_id=provider_id, config_class=config_class, adapter=adapter
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ from .datatypes import (
|
|||
DistributionSpec,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
)
|
||||
|
||||
# These are the dependencies needed by the distribution server.
|
||||
|
@ -89,9 +90,12 @@ def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
|
|||
a.provider_id: a for a in available_agentic_system_providers()
|
||||
}
|
||||
|
||||
return {
|
||||
ret = {
|
||||
Api.inference: inference_providers_by_id,
|
||||
Api.safety: safety_providers_by_id,
|
||||
Api.agentic_system: agentic_system_providers_by_id,
|
||||
Api.memory: {a.provider_id: a for a in available_memory_providers()},
|
||||
}
|
||||
for k, v in ret.items():
|
||||
v["remote"] = remote_provider_spec(k)
|
||||
return ret
|
||||
|
|
|
@ -8,7 +8,7 @@ import asyncio
|
|||
import importlib
|
||||
from typing import Any, Dict
|
||||
|
||||
from .datatypes import InlineProviderSpec, ProviderSpec, RemoteProviderConfig
|
||||
from .datatypes import ProviderSpec, RemoteProviderConfig, RemoteProviderSpec
|
||||
|
||||
|
||||
def instantiate_class_type(fully_qualified_name):
|
||||
|
@ -26,16 +26,21 @@ def instantiate_provider(
|
|||
module = importlib.import_module(provider_spec.module)
|
||||
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
if isinstance(provider_spec, InlineProviderSpec):
|
||||
if provider_spec.is_adapter:
|
||||
if isinstance(provider_spec, RemoteProviderSpec):
|
||||
if provider_spec.adapter:
|
||||
if not issubclass(config_type, RemoteProviderConfig):
|
||||
raise ValueError(
|
||||
f"Config class {provider_spec.config_class} does not inherit from RemoteProviderConfig"
|
||||
)
|
||||
config = config_type(**provider_config)
|
||||
|
||||
if isinstance(provider_spec, InlineProviderSpec):
|
||||
args = [config, deps]
|
||||
method = "get_adapter_impl"
|
||||
else:
|
||||
method = "get_client_impl"
|
||||
else:
|
||||
args = [config]
|
||||
return asyncio.run(module.get_provider_impl(*args))
|
||||
method = "get_provider_impl"
|
||||
|
||||
config = config_type(**provider_config)
|
||||
fn = getattr(module, method)
|
||||
impl = asyncio.run(fn(config, deps))
|
||||
impl.__provider_spec__ = provider_spec
|
||||
impl.__provider_config__ = config
|
||||
return impl
|
||||
|
|
|
@ -38,7 +38,7 @@ from pydantic import BaseModel, ValidationError
|
|||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from .datatypes import Api, ProviderSpec, RemoteProviderSpec
|
||||
from .datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
|
||||
from .distribution import api_endpoints, api_providers
|
||||
from .dynamic import instantiate_provider
|
||||
|
||||
|
@ -230,10 +230,9 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
|
|||
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
|
||||
visited.add(a.api)
|
||||
|
||||
if not isinstance(a, RemoteProviderSpec):
|
||||
for api in a.api_dependencies:
|
||||
if api not in visited:
|
||||
dfs(by_id[api], visited, stack)
|
||||
for api in a.api_dependencies:
|
||||
if api not in visited:
|
||||
dfs(by_id[api], visited, stack)
|
||||
|
||||
stack.append(a.api)
|
||||
|
||||
|
@ -261,7 +260,10 @@ def resolve_impls(
|
|||
f"Could not find provider_spec config for {api}. Please add it to the config"
|
||||
)
|
||||
|
||||
deps = {api: impls[api] for api in provider_spec.api_dependencies}
|
||||
if isinstance(provider_spec, InlineProviderSpec):
|
||||
deps = {api: impls[api] for api in provider_spec.api_dependencies}
|
||||
else:
|
||||
deps = {}
|
||||
provider_config = provider_configs[api.value]
|
||||
impl = instantiate_provider(provider_spec, provider_config, deps)
|
||||
impls[api] = impl
|
||||
|
@ -302,7 +304,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
|||
and provider_spec.adapter is None
|
||||
):
|
||||
for endpoint in endpoints:
|
||||
url = impl.base_url + endpoint.route
|
||||
url = impl.__provider_config__.url
|
||||
getattr(app, endpoint.method)(endpoint.route)(
|
||||
create_dynamic_passthrough(url)
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue