mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
ollama remote adapter works
This commit is contained in:
parent
2076d2b6db
commit
2a1552a5eb
14 changed files with 196 additions and 128 deletions
|
@ -38,6 +38,36 @@ class ProviderSpec(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AdapterSpec(BaseModel):
|
||||
"""
|
||||
If some code is needed to convert the remote responses into Llama Stack compatible
|
||||
API responses, specify the adapter here. If not specified, it indicates the remote
|
||||
as being "Llama Stack compatible"
|
||||
"""
|
||||
|
||||
adapter_id: str = Field(
|
||||
...,
|
||||
description="Unique identifier for this adapter",
|
||||
)
|
||||
module: str = Field(
|
||||
...,
|
||||
description="""
|
||||
Fully-qualified name of the module to import. The module is expected to have:
|
||||
|
||||
- `get_adapter_impl(config, deps)`: returns the adapter implementation
|
||||
""",
|
||||
)
|
||||
pip_packages: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="The pip dependencies needed for this implementation",
|
||||
)
|
||||
config_class: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Fully-qualified classname of the config for this provider",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class InlineProviderSpec(ProviderSpec):
|
||||
pip_packages: List[str] = Field(
|
||||
|
@ -63,30 +93,7 @@ Fully-qualified name of the module to import. The module is expected to have:
|
|||
default_factory=list,
|
||||
description="Higher-level API surfaces may depend on other providers to provide their functionality",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AdapterSpec(BaseModel):
|
||||
adapter_id: str = Field(
|
||||
...,
|
||||
description="Unique identifier for this adapter",
|
||||
)
|
||||
module: str = Field(
|
||||
...,
|
||||
description="""
|
||||
Fully-qualified name of the module to import. The module is expected to have:
|
||||
|
||||
- `get_adapter_impl(config, deps)`: returns the adapter implementation
|
||||
""",
|
||||
)
|
||||
pip_packages: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="The pip dependencies needed for this implementation",
|
||||
)
|
||||
config_class: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Fully-qualified classname of the config for this provider",
|
||||
)
|
||||
is_adapter: bool = False
|
||||
|
||||
|
||||
class RemoteProviderConfig(BaseModel):
|
||||
|
@ -106,40 +113,34 @@ def remote_provider_id(adapter_id: str) -> str:
|
|||
|
||||
@json_schema_type
|
||||
class RemoteProviderSpec(ProviderSpec):
|
||||
adapter: Optional[AdapterSpec] = Field(
|
||||
default=None,
|
||||
description="""
|
||||
If some code is needed to convert the remote responses into Llama Stack compatible
|
||||
API responses, specify the adapter here. If not specified, it indicates the remote
|
||||
as being "Llama Stack compatible"
|
||||
""",
|
||||
)
|
||||
provider_id: str = "remote"
|
||||
config_class: str = "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
|
||||
|
||||
@property
|
||||
def module(self) -> str:
|
||||
return f"llama_toolchain.{self.api.value}.client"
|
||||
|
||||
# need this wrapper since we don't have Pydantic v2 and that means we don't have
|
||||
|
||||
def remote_provider_spec(api: Api) -> RemoteProviderSpec:
|
||||
return RemoteProviderSpec(api=api)
|
||||
|
||||
|
||||
# TODO: use computed_field to avoid this wrapper
|
||||
# the @computed_field decorator
|
||||
def remote_provider_spec(
|
||||
api: Api, adapter: Optional[AdapterSpec] = None
|
||||
) -> RemoteProviderSpec:
|
||||
provider_id = (
|
||||
remote_provider_id(adapter.adapter_id) if adapter is not None else "remote"
|
||||
)
|
||||
module = (
|
||||
adapter.module if adapter is not None else f"llama_toolchain.{api.value}.client"
|
||||
)
|
||||
def adapter_provider_spec(api: Api, adapter: AdapterSpec) -> InlineProviderSpec:
|
||||
config_class = (
|
||||
adapter.config_class
|
||||
if adapter and adapter.config_class
|
||||
if adapter.config_class
|
||||
else "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
|
||||
)
|
||||
|
||||
return RemoteProviderSpec(
|
||||
return InlineProviderSpec(
|
||||
api=api,
|
||||
provider_id=provider_id,
|
||||
pip_packages=adapter.pip_packages if adapter is not None else [],
|
||||
module=module,
|
||||
provider_id=remote_provider_id(adapter.adapter_id),
|
||||
pip_packages=adapter.pip_packages,
|
||||
module=adapter.module,
|
||||
config_class=config_class,
|
||||
is_adapter=True,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ import asyncio
|
|||
import importlib
|
||||
from typing import Any, Dict
|
||||
|
||||
from .datatypes import InlineProviderSpec, ProviderSpec, RemoteProviderSpec
|
||||
from .datatypes import InlineProviderSpec, ProviderSpec, RemoteProviderConfig
|
||||
|
||||
|
||||
def instantiate_class_type(fully_qualified_name):
|
||||
|
@ -19,38 +19,23 @@ def instantiate_class_type(fully_qualified_name):
|
|||
|
||||
# returns a class implementing the protocol corresponding to the Api
|
||||
def instantiate_provider(
|
||||
provider_spec: InlineProviderSpec,
|
||||
provider_spec: ProviderSpec,
|
||||
provider_config: Dict[str, Any],
|
||||
deps: Dict[str, ProviderSpec],
|
||||
):
|
||||
module = importlib.import_module(provider_spec.module)
|
||||
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
if isinstance(provider_spec, InlineProviderSpec):
|
||||
if provider_spec.is_adapter:
|
||||
if not issubclass(config_type, RemoteProviderConfig):
|
||||
raise ValueError(
|
||||
f"Config class {provider_spec.config_class} does not inherit from RemoteProviderConfig"
|
||||
)
|
||||
config = config_type(**provider_config)
|
||||
return asyncio.run(module.get_provider_impl(config, deps))
|
||||
|
||||
|
||||
def instantiate_client(
|
||||
provider_spec: RemoteProviderSpec, provider_config: Dict[str, Any]
|
||||
):
|
||||
module = importlib.import_module(provider_spec.module)
|
||||
|
||||
adapter = provider_spec.adapter
|
||||
if adapter is not None:
|
||||
if "adapter" not in provider_config:
|
||||
raise ValueError(
|
||||
f"Adapter is specified but not present in provider config: {provider_config}"
|
||||
)
|
||||
adapter_config = provider_config["adapter"]
|
||||
|
||||
config_type = instantiate_class_type(adapter.config_class)
|
||||
if not issubclass(config_type, RemoteProviderConfig):
|
||||
raise ValueError(
|
||||
f"Config class {adapter.config_class} does not inherit from RemoteProviderConfig"
|
||||
)
|
||||
|
||||
config = config_type(**adapter_config)
|
||||
if isinstance(provider_spec, InlineProviderSpec):
|
||||
args = [config, deps]
|
||||
else:
|
||||
config = RemoteProviderConfig(**provider_config)
|
||||
|
||||
return asyncio.run(module.get_adapter_impl(config))
|
||||
args = [config]
|
||||
return asyncio.run(module.get_provider_impl(*args))
|
||||
|
|
|
@ -38,11 +38,9 @@ from pydantic import BaseModel, ValidationError
|
|||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from .datatypes import Api, DistributionSpec, ProviderSpec, RemoteProviderSpec
|
||||
from .distribution import api_endpoints
|
||||
from .dynamic import instantiate_client, instantiate_provider
|
||||
|
||||
from .registry import resolve_distribution_spec
|
||||
from .datatypes import Api, ProviderSpec, RemoteProviderSpec
|
||||
from .distribution import api_endpoints, api_providers
|
||||
from .dynamic import instantiate_provider
|
||||
|
||||
|
||||
def is_async_iterator_type(typ):
|
||||
|
@ -249,9 +247,11 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
|
|||
return [by_id[x] for x in stack]
|
||||
|
||||
|
||||
def resolve_impls(dist: DistributionSpec, config: Dict[str, Any]) -> Dict[Api, Any]:
|
||||
def resolve_impls(
|
||||
provider_specs: Dict[str, ProviderSpec], config: Dict[str, Any]
|
||||
) -> Dict[Api, Any]:
|
||||
provider_configs = config["providers"]
|
||||
provider_specs = topological_sort(dist.provider_specs.values())
|
||||
provider_specs = topological_sort(provider_specs.values())
|
||||
|
||||
impls = {}
|
||||
for provider_spec in provider_specs:
|
||||
|
@ -261,16 +261,10 @@ def resolve_impls(dist: DistributionSpec, config: Dict[str, Any]) -> Dict[Api, A
|
|||
f"Could not find provider_spec config for {api}. Please add it to the config"
|
||||
)
|
||||
|
||||
deps = {api: impls[api] for api in provider_spec.api_dependencies}
|
||||
provider_config = provider_configs[api.value]
|
||||
if isinstance(provider_spec, RemoteProviderSpec):
|
||||
impls[api] = instantiate_client(
|
||||
provider_spec,
|
||||
provider_config,
|
||||
)
|
||||
else:
|
||||
deps = {api: impls[api] for api in provider_spec.api_dependencies}
|
||||
impl = instantiate_provider(provider_spec, provider_config, deps)
|
||||
impls[api] = impl
|
||||
impl = instantiate_provider(provider_spec, provider_config, deps)
|
||||
impls[api] = impl
|
||||
|
||||
return impls
|
||||
|
||||
|
@ -279,22 +273,34 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
|||
with open(yaml_config, "r") as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
|
||||
spec = config["spec"]
|
||||
dist = resolve_distribution_spec(spec)
|
||||
if dist is None:
|
||||
raise ValueError(f"Could not find distribution specification `{spec}`")
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
all_endpoints = api_endpoints()
|
||||
impls = resolve_impls(dist, config)
|
||||
all_providers = api_providers()
|
||||
|
||||
for provider_spec in dist.provider_specs.values():
|
||||
provider_specs = {}
|
||||
for api_str, provider_config in config["providers"].items():
|
||||
api = Api(api_str)
|
||||
providers = all_providers[api]
|
||||
provider_id = provider_config["provider_id"]
|
||||
if provider_id not in providers:
|
||||
raise ValueError(
|
||||
f"Unknown provider `{provider_id}` is not available for API `{api}`"
|
||||
)
|
||||
|
||||
provider_specs[api] = providers[provider_id]
|
||||
|
||||
impls = resolve_impls(provider_specs, config)
|
||||
|
||||
for provider_spec in provider_specs.values():
|
||||
api = provider_spec.api
|
||||
endpoints = all_endpoints[api]
|
||||
impl = impls[api]
|
||||
|
||||
if isinstance(provider_spec, RemoteProviderSpec):
|
||||
if (
|
||||
isinstance(provider_spec, RemoteProviderSpec)
|
||||
and provider_spec.adapter is None
|
||||
):
|
||||
for endpoint in endpoints:
|
||||
url = impl.base_url + endpoint.route
|
||||
getattr(app, endpoint.method)(endpoint.route)(
|
||||
|
|
|
@ -8,7 +8,6 @@
|
|||
|
||||
set -euo pipefail
|
||||
|
||||
# Define color codes
|
||||
RED='\033[0;31m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
|
@ -17,20 +16,17 @@ error_handler() {
|
|||
exit 1
|
||||
}
|
||||
|
||||
# Set up the error trap
|
||||
trap 'error_handler ${LINENO}' ERR
|
||||
|
||||
if [ $# -lt 2 ]; then
|
||||
echo "Usage: $0 <environment_name> <script_args...>"
|
||||
exit 1
|
||||
echo "Usage: $0 <environment_name> <script_args...>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
env_name="$1"
|
||||
shift
|
||||
|
||||
eval "$(conda shell.bash hook)"
|
||||
conda deactivate && conda activate "$env_name"
|
||||
|
||||
python_interp=$(conda run -n "$env_name" which python)
|
||||
$python_interp -m llama_toolchain.distribution.server "$@"
|
||||
$CONDA_PREFIX/bin/python -m llama_toolchain.distribution.server "$@"
|
0
llama_toolchain/distribution/run_image.sh → llama_toolchain/distribution/start_container.sh
Normal file → Executable file
0
llama_toolchain/distribution/run_image.sh → llama_toolchain/distribution/start_container.sh
Normal file → Executable file
Loading…
Add table
Add a link
Reference in a new issue