Several smaller fixes to make adapters work

Also, reorganized the pattern of __init__ inside providers so
configuration can stay lightweight
This commit is contained in:
Ashwin Bharambe 2024-08-28 09:42:08 -07:00
parent 2a1552a5eb
commit 45987996c4
23 changed files with 164 additions and 160 deletions

View file

@ -69,7 +69,7 @@ ensure_conda_env_python310() {
conda create -n "${env_name}" python="${python_version}" -y
ENVNAME="${env_name}"
setup_cleanup_handlers
# setup_cleanup_handlers
fi
eval "$(conda shell.bash hook)"

View file

@ -36,16 +36,14 @@ class ProviderSpec(BaseModel):
...,
description="Fully-qualified classname of the config for this provider",
)
api_dependencies: List[Api] = Field(
default_factory=list,
description="Higher-level API surfaces may depend on other providers to provide their functionality",
)
@json_schema_type
class AdapterSpec(BaseModel):
"""
If some code is needed to convert the remote responses into Llama Stack compatible
API responses, specify the adapter here. If not specified, it indicates the remote
as being "Llama Stack compatible"
"""
adapter_id: str = Field(
...,
description="Unique identifier for this adapter",
@ -89,11 +87,6 @@ Fully-qualified name of the module to import. The module is expected to have:
- `get_provider_impl(config, deps)`: returns the local implementation
""",
)
api_dependencies: List[Api] = Field(
default_factory=list,
description="Higher-level API surfaces may depend on other providers to provide their functionality",
)
is_adapter: bool = False
class RemoteProviderConfig(BaseModel):
@ -113,34 +106,41 @@ def remote_provider_id(adapter_id: str) -> str:
@json_schema_type
class RemoteProviderSpec(ProviderSpec):
provider_id: str = "remote"
config_class: str = "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
adapter: Optional[AdapterSpec] = Field(
default=None,
description="""
If some code is needed to convert the remote responses into Llama Stack compatible
API responses, specify the adapter here. If not specified, it indicates the remote
as being "Llama Stack compatible"
""",
)
@property
def module(self) -> str:
if self.adapter:
return self.adapter.module
return f"llama_toolchain.{self.api.value}.client"
def remote_provider_spec(api: Api) -> RemoteProviderSpec:
return RemoteProviderSpec(api=api)
@property
def pip_packages(self) -> List[str]:
if self.adapter:
return self.adapter.pip_packages
return []
# TODO: use computed_field to avoid this wrapper
# the @computed_field decorator
def adapter_provider_spec(api: Api, adapter: AdapterSpec) -> InlineProviderSpec:
# Can avoid this by using Pydantic computed_field
def remote_provider_spec(
api: Api, adapter: Optional[AdapterSpec] = None
) -> RemoteProviderSpec:
config_class = (
adapter.config_class
if adapter.config_class
if adapter and adapter.config_class
else "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
)
provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote"
return InlineProviderSpec(
api=api,
provider_id=remote_provider_id(adapter.adapter_id),
pip_packages=adapter.pip_packages,
module=adapter.module,
config_class=config_class,
is_adapter=True,
return RemoteProviderSpec(
api=api, provider_id=provider_id, config_class=config_class, adapter=adapter
)

View file

@ -22,6 +22,7 @@ from .datatypes import (
DistributionSpec,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
)
# These are the dependencies needed by the distribution server.
@ -89,9 +90,12 @@ def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
a.provider_id: a for a in available_agentic_system_providers()
}
return {
ret = {
Api.inference: inference_providers_by_id,
Api.safety: safety_providers_by_id,
Api.agentic_system: agentic_system_providers_by_id,
Api.memory: {a.provider_id: a for a in available_memory_providers()},
}
for k, v in ret.items():
v["remote"] = remote_provider_spec(k)
return ret

View file

@ -8,7 +8,7 @@ import asyncio
import importlib
from typing import Any, Dict
from .datatypes import InlineProviderSpec, ProviderSpec, RemoteProviderConfig
from .datatypes import ProviderSpec, RemoteProviderConfig, RemoteProviderSpec
def instantiate_class_type(fully_qualified_name):
@ -26,16 +26,21 @@ def instantiate_provider(
module = importlib.import_module(provider_spec.module)
config_type = instantiate_class_type(provider_spec.config_class)
if isinstance(provider_spec, InlineProviderSpec):
if provider_spec.is_adapter:
if isinstance(provider_spec, RemoteProviderSpec):
if provider_spec.adapter:
if not issubclass(config_type, RemoteProviderConfig):
raise ValueError(
f"Config class {provider_spec.config_class} does not inherit from RemoteProviderConfig"
)
config = config_type(**provider_config)
if isinstance(provider_spec, InlineProviderSpec):
args = [config, deps]
method = "get_adapter_impl"
else:
method = "get_client_impl"
else:
args = [config]
return asyncio.run(module.get_provider_impl(*args))
method = "get_provider_impl"
config = config_type(**provider_config)
fn = getattr(module, method)
impl = asyncio.run(fn(config, deps))
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
return impl

View file

@ -38,7 +38,7 @@ from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from .datatypes import Api, ProviderSpec, RemoteProviderSpec
from .datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
from .distribution import api_endpoints, api_providers
from .dynamic import instantiate_provider
@ -230,10 +230,9 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
def dfs(a: ProviderSpec, visited: Set[Api], stack: List[Api]):
visited.add(a.api)
if not isinstance(a, RemoteProviderSpec):
for api in a.api_dependencies:
if api not in visited:
dfs(by_id[api], visited, stack)
for api in a.api_dependencies:
if api not in visited:
dfs(by_id[api], visited, stack)
stack.append(a.api)
@ -261,7 +260,10 @@ def resolve_impls(
f"Could not find provider_spec config for {api}. Please add it to the config"
)
deps = {api: impls[api] for api in provider_spec.api_dependencies}
if isinstance(provider_spec, InlineProviderSpec):
deps = {api: impls[api] for api in provider_spec.api_dependencies}
else:
deps = {}
provider_config = provider_configs[api.value]
impl = instantiate_provider(provider_spec, provider_config, deps)
impls[api] = impl
@ -302,7 +304,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
and provider_spec.adapter is None
):
for endpoint in endpoints:
url = impl.base_url + endpoint.route
url = impl.__provider_config__.url
getattr(app, endpoint.method)(endpoint.route)(
create_dynamic_passthrough(url)
)