ollama remote adapter works

This commit is contained in:
Ashwin Bharambe 2024-08-28 06:51:07 -07:00
parent 2076d2b6db
commit 2a1552a5eb
14 changed files with 196 additions and 128 deletions

View file

@ -38,6 +38,36 @@ class ProviderSpec(BaseModel):
)
@json_schema_type
class AdapterSpec(BaseModel):
"""
If some code is needed to convert the remote responses into Llama Stack compatible
API responses, specify the adapter here. If not specified, it indicates the remote
as being "Llama Stack compatible"
"""
adapter_id: str = Field(
...,
description="Unique identifier for this adapter",
)
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
""",
)
pip_packages: List[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
config_class: Optional[str] = Field(
default=None,
description="Fully-qualified classname of the config for this provider",
)
@json_schema_type
class InlineProviderSpec(ProviderSpec):
pip_packages: List[str] = Field(
@ -63,30 +93,7 @@ Fully-qualified name of the module to import. The module is expected to have:
default_factory=list,
description="Higher-level API surfaces may depend on other providers to provide their functionality",
)
@json_schema_type
class AdapterSpec(BaseModel):
adapter_id: str = Field(
...,
description="Unique identifier for this adapter",
)
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
""",
)
pip_packages: List[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
config_class: Optional[str] = Field(
default=None,
description="Fully-qualified classname of the config for this provider",
)
is_adapter: bool = False
class RemoteProviderConfig(BaseModel):
@ -106,40 +113,34 @@ def remote_provider_id(adapter_id: str) -> str:
@json_schema_type
class RemoteProviderSpec(ProviderSpec):
adapter: Optional[AdapterSpec] = Field(
default=None,
description="""
If some code is needed to convert the remote responses into Llama Stack compatible
API responses, specify the adapter here. If not specified, it indicates the remote
as being "Llama Stack compatible"
""",
)
provider_id: str = "remote"
config_class: str = "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
@property
def module(self) -> str:
return f"llama_toolchain.{self.api.value}.client"
# need this wrapper since we don't have Pydantic v2 and that means we don't have
def remote_provider_spec(api: Api) -> RemoteProviderSpec:
return RemoteProviderSpec(api=api)
# TODO: use computed_field to avoid this wrapper
# the @computed_field decorator
def remote_provider_spec(
api: Api, adapter: Optional[AdapterSpec] = None
) -> RemoteProviderSpec:
provider_id = (
remote_provider_id(adapter.adapter_id) if adapter is not None else "remote"
)
module = (
adapter.module if adapter is not None else f"llama_toolchain.{api.value}.client"
)
def adapter_provider_spec(api: Api, adapter: AdapterSpec) -> InlineProviderSpec:
config_class = (
adapter.config_class
if adapter and adapter.config_class
if adapter.config_class
else "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
)
return RemoteProviderSpec(
return InlineProviderSpec(
api=api,
provider_id=provider_id,
pip_packages=adapter.pip_packages if adapter is not None else [],
module=module,
provider_id=remote_provider_id(adapter.adapter_id),
pip_packages=adapter.pip_packages,
module=adapter.module,
config_class=config_class,
is_adapter=True,
)

View file

@ -8,7 +8,7 @@ import asyncio
import importlib
from typing import Any, Dict
from .datatypes import InlineProviderSpec, ProviderSpec, RemoteProviderSpec
from .datatypes import InlineProviderSpec, ProviderSpec, RemoteProviderConfig
def instantiate_class_type(fully_qualified_name):
@ -19,38 +19,23 @@ def instantiate_class_type(fully_qualified_name):
# returns a class implementing the protocol corresponding to the Api
def instantiate_provider(
provider_spec: InlineProviderSpec,
provider_spec: ProviderSpec,
provider_config: Dict[str, Any],
deps: Dict[str, ProviderSpec],
):
module = importlib.import_module(provider_spec.module)
config_type = instantiate_class_type(provider_spec.config_class)
if isinstance(provider_spec, InlineProviderSpec):
if provider_spec.is_adapter:
if not issubclass(config_type, RemoteProviderConfig):
raise ValueError(
f"Config class {provider_spec.config_class} does not inherit from RemoteProviderConfig"
)
config = config_type(**provider_config)
return asyncio.run(module.get_provider_impl(config, deps))
def instantiate_client(
provider_spec: RemoteProviderSpec, provider_config: Dict[str, Any]
):
module = importlib.import_module(provider_spec.module)
adapter = provider_spec.adapter
if adapter is not None:
if "adapter" not in provider_config:
raise ValueError(
f"Adapter is specified but not present in provider config: {provider_config}"
)
adapter_config = provider_config["adapter"]
config_type = instantiate_class_type(adapter.config_class)
if not issubclass(config_type, RemoteProviderConfig):
raise ValueError(
f"Config class {adapter.config_class} does not inherit from RemoteProviderConfig"
)
config = config_type(**adapter_config)
if isinstance(provider_spec, InlineProviderSpec):
args = [config, deps]
else:
config = RemoteProviderConfig(**provider_config)
return asyncio.run(module.get_adapter_impl(config))
args = [config]
return asyncio.run(module.get_provider_impl(*args))

View file

@ -38,11 +38,9 @@ from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from .datatypes import Api, DistributionSpec, ProviderSpec, RemoteProviderSpec
from .distribution import api_endpoints
from .dynamic import instantiate_client, instantiate_provider
from .registry import resolve_distribution_spec
from .datatypes import Api, ProviderSpec, RemoteProviderSpec
from .distribution import api_endpoints, api_providers
from .dynamic import instantiate_provider
def is_async_iterator_type(typ):
@ -249,9 +247,11 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
return [by_id[x] for x in stack]
def resolve_impls(dist: DistributionSpec, config: Dict[str, Any]) -> Dict[Api, Any]:
def resolve_impls(
provider_specs: Dict[str, ProviderSpec], config: Dict[str, Any]
) -> Dict[Api, Any]:
provider_configs = config["providers"]
provider_specs = topological_sort(dist.provider_specs.values())
provider_specs = topological_sort(provider_specs.values())
impls = {}
for provider_spec in provider_specs:
@ -261,16 +261,10 @@ def resolve_impls(dist: DistributionSpec, config: Dict[str, Any]) -> Dict[Api, A
f"Could not find provider_spec config for {api}. Please add it to the config"
)
deps = {api: impls[api] for api in provider_spec.api_dependencies}
provider_config = provider_configs[api.value]
if isinstance(provider_spec, RemoteProviderSpec):
impls[api] = instantiate_client(
provider_spec,
provider_config,
)
else:
deps = {api: impls[api] for api in provider_spec.api_dependencies}
impl = instantiate_provider(provider_spec, provider_config, deps)
impls[api] = impl
impl = instantiate_provider(provider_spec, provider_config, deps)
impls[api] = impl
return impls
@ -279,22 +273,34 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
with open(yaml_config, "r") as fp:
config = yaml.safe_load(fp)
spec = config["spec"]
dist = resolve_distribution_spec(spec)
if dist is None:
raise ValueError(f"Could not find distribution specification `{spec}`")
app = FastAPI()
all_endpoints = api_endpoints()
impls = resolve_impls(dist, config)
all_providers = api_providers()
for provider_spec in dist.provider_specs.values():
provider_specs = {}
for api_str, provider_config in config["providers"].items():
api = Api(api_str)
providers = all_providers[api]
provider_id = provider_config["provider_id"]
if provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api}`"
)
provider_specs[api] = providers[provider_id]
impls = resolve_impls(provider_specs, config)
for provider_spec in provider_specs.values():
api = provider_spec.api
endpoints = all_endpoints[api]
impl = impls[api]
if isinstance(provider_spec, RemoteProviderSpec):
if (
isinstance(provider_spec, RemoteProviderSpec)
and provider_spec.adapter is None
):
for endpoint in endpoints:
url = impl.base_url + endpoint.route
getattr(app, endpoint.method)(endpoint.route)(

View file

@ -8,7 +8,6 @@
set -euo pipefail
# Define color codes
RED='\033[0;31m'
NC='\033[0m' # No Color
@ -17,20 +16,17 @@ error_handler() {
exit 1
}
# Set up the error trap
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 2 ]; then
echo "Usage: $0 <environment_name> <script_args...>"
exit 1
echo "Usage: $0 <environment_name> <script_args...>"
exit 1
fi
env_name="$1"
shift
eval "$(conda shell.bash hook)"
conda deactivate && conda activate "$env_name"
python_interp=$(conda run -n "$env_name" which python)
$python_interp -m llama_toolchain.distribution.server "$@"
$CONDA_PREFIX/bin/python -m llama_toolchain.distribution.server "$@"