mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
ollama remote adapter works
This commit is contained in:
parent
2076d2b6db
commit
2a1552a5eb
14 changed files with 196 additions and 128 deletions
|
@ -10,6 +10,7 @@ from llama_toolchain.cli.subcommand import Subcommand
|
||||||
|
|
||||||
from .build import ApiBuild
|
from .build import ApiBuild
|
||||||
from .configure import ApiConfigure
|
from .configure import ApiConfigure
|
||||||
|
from .start import ApiStart
|
||||||
|
|
||||||
|
|
||||||
class ApiParser(Subcommand):
|
class ApiParser(Subcommand):
|
||||||
|
@ -26,3 +27,4 @@ class ApiParser(Subcommand):
|
||||||
# Add sub-commands
|
# Add sub-commands
|
||||||
ApiBuild.create(subparsers)
|
ApiBuild.create(subparsers)
|
||||||
ApiConfigure.create(subparsers)
|
ApiConfigure.create(subparsers)
|
||||||
|
ApiStart.create(subparsers)
|
||||||
|
|
|
@ -7,9 +7,6 @@
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import random
|
|
||||||
import string
|
|
||||||
import uuid
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
@ -25,10 +22,6 @@ from llama_toolchain.common.config_dirs import BUILDS_BASE_DIR
|
||||||
from llama_toolchain.distribution.datatypes import * # noqa: F403
|
from llama_toolchain.distribution.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
def random_string():
|
|
||||||
return "".join(random.choices(string.ascii_letters + string.digits, k=8))
|
|
||||||
|
|
||||||
|
|
||||||
class BuildType(Enum):
|
class BuildType(Enum):
|
||||||
container = "container"
|
container = "container"
|
||||||
conda_env = "conda_env"
|
conda_env = "conda_env"
|
||||||
|
@ -42,6 +35,8 @@ class Dependencies(BaseModel):
|
||||||
def get_dependencies(
|
def get_dependencies(
|
||||||
provider: ProviderSpec, dependencies: Dict[str, ProviderSpec]
|
provider: ProviderSpec, dependencies: Dict[str, ProviderSpec]
|
||||||
) -> Dependencies:
|
) -> Dependencies:
|
||||||
|
from llama_toolchain.distribution.distribution import SERVER_DEPENDENCIES
|
||||||
|
|
||||||
def _deps(provider: ProviderSpec) -> Tuple[List[str], Optional[str]]:
|
def _deps(provider: ProviderSpec) -> Tuple[List[str], Optional[str]]:
|
||||||
if isinstance(provider, InlineProviderSpec):
|
if isinstance(provider, InlineProviderSpec):
|
||||||
return provider.pip_packages, provider.docker_image
|
return provider.pip_packages, provider.docker_image
|
||||||
|
@ -60,7 +55,9 @@ def get_dependencies(
|
||||||
|
|
||||||
pip_packages.extend(dep_pip_packages)
|
pip_packages.extend(dep_pip_packages)
|
||||||
|
|
||||||
return Dependencies(docker_image=docker_image, pip_packages=pip_packages)
|
return Dependencies(
|
||||||
|
docker_image=docker_image, pip_packages=pip_packages + SERVER_DEPENDENCIES
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_dependencies(
|
def parse_dependencies(
|
||||||
|
@ -88,7 +85,6 @@ def parse_dependencies(
|
||||||
|
|
||||||
|
|
||||||
class ApiBuild(Subcommand):
|
class ApiBuild(Subcommand):
|
||||||
|
|
||||||
def __init__(self, subparsers: argparse._SubParsersAction):
|
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.parser = subparsers.add_parser(
|
self.parser = subparsers.add_parser(
|
||||||
|
@ -125,8 +121,8 @@ class ApiBuild(Subcommand):
|
||||||
self.parser.add_argument(
|
self.parser.add_argument(
|
||||||
"--name",
|
"--name",
|
||||||
type=str,
|
type=str,
|
||||||
help="Name of the build target (image, conda env). Defaults to a random UUID",
|
help="Name of the build target (image, conda env)",
|
||||||
required=False,
|
required=True,
|
||||||
)
|
)
|
||||||
self.parser.add_argument(
|
self.parser.add_argument(
|
||||||
"--type",
|
"--type",
|
||||||
|
@ -153,11 +149,10 @@ class ApiBuild(Subcommand):
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
name = args.name or random_string()
|
|
||||||
if args.type == BuildType.container.value:
|
if args.type == BuildType.container.value:
|
||||||
package_name = f"image-{args.provider}-{name}"
|
package_name = f"image-{args.provider}-{args.name}"
|
||||||
else:
|
else:
|
||||||
package_name = f"env-{args.provider}-{name}"
|
package_name = f"env-{args.provider}-{args.name}"
|
||||||
package_name = package_name.replace("::", "-")
|
package_name = package_name.replace("::", "-")
|
||||||
|
|
||||||
build_dir = BUILDS_BASE_DIR / args.api
|
build_dir = BUILDS_BASE_DIR / args.api
|
||||||
|
@ -176,7 +171,7 @@ class ApiBuild(Subcommand):
|
||||||
}
|
}
|
||||||
with open(package_file, "w") as f:
|
with open(package_file, "w") as f:
|
||||||
c = PackageConfig(
|
c = PackageConfig(
|
||||||
built_at=str(datetime.now()),
|
built_at=datetime.now(),
|
||||||
package_name=package_name,
|
package_name=package_name,
|
||||||
docker_image=(
|
docker_image=(
|
||||||
package_name if args.type == BuildType.container.value else None
|
package_name if args.type == BuildType.container.value else None
|
||||||
|
|
83
llama_toolchain/cli/api/start.py
Normal file
83
llama_toolchain/cli/api/start.py
Normal file
|
@ -0,0 +1,83 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pkg_resources
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from llama_toolchain.cli.subcommand import Subcommand
|
||||||
|
from llama_toolchain.distribution.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
class ApiStart(Subcommand):
|
||||||
|
def __init__(self, subparsers: argparse._SubParsersAction):
|
||||||
|
super().__init__()
|
||||||
|
self.parser = subparsers.add_parser(
|
||||||
|
"start",
|
||||||
|
prog="llama api start",
|
||||||
|
description="""start the server for a Llama API provider. You should have already built and configured the provider.""",
|
||||||
|
formatter_class=argparse.RawTextHelpFormatter,
|
||||||
|
)
|
||||||
|
self._add_arguments()
|
||||||
|
self.parser.set_defaults(func=self._run_api_start_cmd)
|
||||||
|
|
||||||
|
def _add_arguments(self):
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--yaml-config",
|
||||||
|
type=str,
|
||||||
|
help="Yaml config containing the API build configuration",
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--port",
|
||||||
|
type=int,
|
||||||
|
help="Port to run the server on. Defaults to 5000",
|
||||||
|
default=5000,
|
||||||
|
)
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--disable-ipv6",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable IPv6 support",
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run_api_start_cmd(self, args: argparse.Namespace) -> None:
|
||||||
|
from llama_toolchain.common.exec import run_with_pty
|
||||||
|
|
||||||
|
config_file = Path(args.yaml_config)
|
||||||
|
if not config_file.exists():
|
||||||
|
self.parser.error(
|
||||||
|
f"Could not find {config_file}. Please run `llama api build` first"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
with open(config_file, "r") as f:
|
||||||
|
config = PackageConfig(**yaml.safe_load(f))
|
||||||
|
|
||||||
|
if config.docker_image:
|
||||||
|
script = pkg_resources.resource_filename(
|
||||||
|
"llama_toolchain",
|
||||||
|
"distribution/start_container.sh",
|
||||||
|
)
|
||||||
|
run_args = [script, config.docker_image]
|
||||||
|
else:
|
||||||
|
script = pkg_resources.resource_filename(
|
||||||
|
"llama_toolchain",
|
||||||
|
"distribution/start_conda_env.sh",
|
||||||
|
)
|
||||||
|
run_args = [
|
||||||
|
script,
|
||||||
|
config.conda_env,
|
||||||
|
]
|
||||||
|
|
||||||
|
run_args.extend(["--yaml_config", str(config_file), "--port", str(args.port)])
|
||||||
|
if args.disable_ipv6:
|
||||||
|
run_args.append("--disable-ipv6")
|
||||||
|
|
||||||
|
run_with_pty(run_args)
|
|
@ -38,6 +38,36 @@ class ProviderSpec(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class AdapterSpec(BaseModel):
|
||||||
|
"""
|
||||||
|
If some code is needed to convert the remote responses into Llama Stack compatible
|
||||||
|
API responses, specify the adapter here. If not specified, it indicates the remote
|
||||||
|
as being "Llama Stack compatible"
|
||||||
|
"""
|
||||||
|
|
||||||
|
adapter_id: str = Field(
|
||||||
|
...,
|
||||||
|
description="Unique identifier for this adapter",
|
||||||
|
)
|
||||||
|
module: str = Field(
|
||||||
|
...,
|
||||||
|
description="""
|
||||||
|
Fully-qualified name of the module to import. The module is expected to have:
|
||||||
|
|
||||||
|
- `get_adapter_impl(config, deps)`: returns the adapter implementation
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
pip_packages: List[str] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="The pip dependencies needed for this implementation",
|
||||||
|
)
|
||||||
|
config_class: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Fully-qualified classname of the config for this provider",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class InlineProviderSpec(ProviderSpec):
|
class InlineProviderSpec(ProviderSpec):
|
||||||
pip_packages: List[str] = Field(
|
pip_packages: List[str] = Field(
|
||||||
|
@ -63,30 +93,7 @@ Fully-qualified name of the module to import. The module is expected to have:
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
description="Higher-level API surfaces may depend on other providers to provide their functionality",
|
description="Higher-level API surfaces may depend on other providers to provide their functionality",
|
||||||
)
|
)
|
||||||
|
is_adapter: bool = False
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class AdapterSpec(BaseModel):
|
|
||||||
adapter_id: str = Field(
|
|
||||||
...,
|
|
||||||
description="Unique identifier for this adapter",
|
|
||||||
)
|
|
||||||
module: str = Field(
|
|
||||||
...,
|
|
||||||
description="""
|
|
||||||
Fully-qualified name of the module to import. The module is expected to have:
|
|
||||||
|
|
||||||
- `get_adapter_impl(config, deps)`: returns the adapter implementation
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
pip_packages: List[str] = Field(
|
|
||||||
default_factory=list,
|
|
||||||
description="The pip dependencies needed for this implementation",
|
|
||||||
)
|
|
||||||
config_class: Optional[str] = Field(
|
|
||||||
default=None,
|
|
||||||
description="Fully-qualified classname of the config for this provider",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RemoteProviderConfig(BaseModel):
|
class RemoteProviderConfig(BaseModel):
|
||||||
|
@ -106,40 +113,34 @@ def remote_provider_id(adapter_id: str) -> str:
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class RemoteProviderSpec(ProviderSpec):
|
class RemoteProviderSpec(ProviderSpec):
|
||||||
adapter: Optional[AdapterSpec] = Field(
|
provider_id: str = "remote"
|
||||||
default=None,
|
|
||||||
description="""
|
|
||||||
If some code is needed to convert the remote responses into Llama Stack compatible
|
|
||||||
API responses, specify the adapter here. If not specified, it indicates the remote
|
|
||||||
as being "Llama Stack compatible"
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
config_class: str = "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
|
config_class: str = "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def module(self) -> str:
|
||||||
|
return f"llama_toolchain.{self.api.value}.client"
|
||||||
|
|
||||||
# need this wrapper since we don't have Pydantic v2 and that means we don't have
|
|
||||||
|
def remote_provider_spec(api: Api) -> RemoteProviderSpec:
|
||||||
|
return RemoteProviderSpec(api=api)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: use computed_field to avoid this wrapper
|
||||||
# the @computed_field decorator
|
# the @computed_field decorator
|
||||||
def remote_provider_spec(
|
def adapter_provider_spec(api: Api, adapter: AdapterSpec) -> InlineProviderSpec:
|
||||||
api: Api, adapter: Optional[AdapterSpec] = None
|
|
||||||
) -> RemoteProviderSpec:
|
|
||||||
provider_id = (
|
|
||||||
remote_provider_id(adapter.adapter_id) if adapter is not None else "remote"
|
|
||||||
)
|
|
||||||
module = (
|
|
||||||
adapter.module if adapter is not None else f"llama_toolchain.{api.value}.client"
|
|
||||||
)
|
|
||||||
config_class = (
|
config_class = (
|
||||||
adapter.config_class
|
adapter.config_class
|
||||||
if adapter and adapter.config_class
|
if adapter.config_class
|
||||||
else "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
|
else "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
|
||||||
)
|
)
|
||||||
|
|
||||||
return RemoteProviderSpec(
|
return InlineProviderSpec(
|
||||||
api=api,
|
api=api,
|
||||||
provider_id=provider_id,
|
provider_id=remote_provider_id(adapter.adapter_id),
|
||||||
pip_packages=adapter.pip_packages if adapter is not None else [],
|
pip_packages=adapter.pip_packages,
|
||||||
module=module,
|
module=adapter.module,
|
||||||
config_class=config_class,
|
config_class=config_class,
|
||||||
|
is_adapter=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ import asyncio
|
||||||
import importlib
|
import importlib
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
from .datatypes import InlineProviderSpec, ProviderSpec, RemoteProviderSpec
|
from .datatypes import InlineProviderSpec, ProviderSpec, RemoteProviderConfig
|
||||||
|
|
||||||
|
|
||||||
def instantiate_class_type(fully_qualified_name):
|
def instantiate_class_type(fully_qualified_name):
|
||||||
|
@ -19,38 +19,23 @@ def instantiate_class_type(fully_qualified_name):
|
||||||
|
|
||||||
# returns a class implementing the protocol corresponding to the Api
|
# returns a class implementing the protocol corresponding to the Api
|
||||||
def instantiate_provider(
|
def instantiate_provider(
|
||||||
provider_spec: InlineProviderSpec,
|
provider_spec: ProviderSpec,
|
||||||
provider_config: Dict[str, Any],
|
provider_config: Dict[str, Any],
|
||||||
deps: Dict[str, ProviderSpec],
|
deps: Dict[str, ProviderSpec],
|
||||||
):
|
):
|
||||||
module = importlib.import_module(provider_spec.module)
|
module = importlib.import_module(provider_spec.module)
|
||||||
|
|
||||||
config_type = instantiate_class_type(provider_spec.config_class)
|
config_type = instantiate_class_type(provider_spec.config_class)
|
||||||
|
if isinstance(provider_spec, InlineProviderSpec):
|
||||||
|
if provider_spec.is_adapter:
|
||||||
|
if not issubclass(config_type, RemoteProviderConfig):
|
||||||
|
raise ValueError(
|
||||||
|
f"Config class {provider_spec.config_class} does not inherit from RemoteProviderConfig"
|
||||||
|
)
|
||||||
config = config_type(**provider_config)
|
config = config_type(**provider_config)
|
||||||
return asyncio.run(module.get_provider_impl(config, deps))
|
|
||||||
|
|
||||||
|
if isinstance(provider_spec, InlineProviderSpec):
|
||||||
def instantiate_client(
|
args = [config, deps]
|
||||||
provider_spec: RemoteProviderSpec, provider_config: Dict[str, Any]
|
|
||||||
):
|
|
||||||
module = importlib.import_module(provider_spec.module)
|
|
||||||
|
|
||||||
adapter = provider_spec.adapter
|
|
||||||
if adapter is not None:
|
|
||||||
if "adapter" not in provider_config:
|
|
||||||
raise ValueError(
|
|
||||||
f"Adapter is specified but not present in provider config: {provider_config}"
|
|
||||||
)
|
|
||||||
adapter_config = provider_config["adapter"]
|
|
||||||
|
|
||||||
config_type = instantiate_class_type(adapter.config_class)
|
|
||||||
if not issubclass(config_type, RemoteProviderConfig):
|
|
||||||
raise ValueError(
|
|
||||||
f"Config class {adapter.config_class} does not inherit from RemoteProviderConfig"
|
|
||||||
)
|
|
||||||
|
|
||||||
config = config_type(**adapter_config)
|
|
||||||
else:
|
else:
|
||||||
config = RemoteProviderConfig(**provider_config)
|
args = [config]
|
||||||
|
return asyncio.run(module.get_provider_impl(*args))
|
||||||
return asyncio.run(module.get_adapter_impl(config))
|
|
||||||
|
|
|
@ -38,11 +38,9 @@ from pydantic import BaseModel, ValidationError
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from .datatypes import Api, DistributionSpec, ProviderSpec, RemoteProviderSpec
|
from .datatypes import Api, ProviderSpec, RemoteProviderSpec
|
||||||
from .distribution import api_endpoints
|
from .distribution import api_endpoints, api_providers
|
||||||
from .dynamic import instantiate_client, instantiate_provider
|
from .dynamic import instantiate_provider
|
||||||
|
|
||||||
from .registry import resolve_distribution_spec
|
|
||||||
|
|
||||||
|
|
||||||
def is_async_iterator_type(typ):
|
def is_async_iterator_type(typ):
|
||||||
|
@ -249,9 +247,11 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
|
||||||
return [by_id[x] for x in stack]
|
return [by_id[x] for x in stack]
|
||||||
|
|
||||||
|
|
||||||
def resolve_impls(dist: DistributionSpec, config: Dict[str, Any]) -> Dict[Api, Any]:
|
def resolve_impls(
|
||||||
|
provider_specs: Dict[str, ProviderSpec], config: Dict[str, Any]
|
||||||
|
) -> Dict[Api, Any]:
|
||||||
provider_configs = config["providers"]
|
provider_configs = config["providers"]
|
||||||
provider_specs = topological_sort(dist.provider_specs.values())
|
provider_specs = topological_sort(provider_specs.values())
|
||||||
|
|
||||||
impls = {}
|
impls = {}
|
||||||
for provider_spec in provider_specs:
|
for provider_spec in provider_specs:
|
||||||
|
@ -261,16 +261,10 @@ def resolve_impls(dist: DistributionSpec, config: Dict[str, Any]) -> Dict[Api, A
|
||||||
f"Could not find provider_spec config for {api}. Please add it to the config"
|
f"Could not find provider_spec config for {api}. Please add it to the config"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
deps = {api: impls[api] for api in provider_spec.api_dependencies}
|
||||||
provider_config = provider_configs[api.value]
|
provider_config = provider_configs[api.value]
|
||||||
if isinstance(provider_spec, RemoteProviderSpec):
|
impl = instantiate_provider(provider_spec, provider_config, deps)
|
||||||
impls[api] = instantiate_client(
|
impls[api] = impl
|
||||||
provider_spec,
|
|
||||||
provider_config,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
deps = {api: impls[api] for api in provider_spec.api_dependencies}
|
|
||||||
impl = instantiate_provider(provider_spec, provider_config, deps)
|
|
||||||
impls[api] = impl
|
|
||||||
|
|
||||||
return impls
|
return impls
|
||||||
|
|
||||||
|
@ -279,22 +273,34 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
||||||
with open(yaml_config, "r") as fp:
|
with open(yaml_config, "r") as fp:
|
||||||
config = yaml.safe_load(fp)
|
config = yaml.safe_load(fp)
|
||||||
|
|
||||||
spec = config["spec"]
|
|
||||||
dist = resolve_distribution_spec(spec)
|
|
||||||
if dist is None:
|
|
||||||
raise ValueError(f"Could not find distribution specification `{spec}`")
|
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
all_endpoints = api_endpoints()
|
all_endpoints = api_endpoints()
|
||||||
impls = resolve_impls(dist, config)
|
all_providers = api_providers()
|
||||||
|
|
||||||
for provider_spec in dist.provider_specs.values():
|
provider_specs = {}
|
||||||
|
for api_str, provider_config in config["providers"].items():
|
||||||
|
api = Api(api_str)
|
||||||
|
providers = all_providers[api]
|
||||||
|
provider_id = provider_config["provider_id"]
|
||||||
|
if provider_id not in providers:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown provider `{provider_id}` is not available for API `{api}`"
|
||||||
|
)
|
||||||
|
|
||||||
|
provider_specs[api] = providers[provider_id]
|
||||||
|
|
||||||
|
impls = resolve_impls(provider_specs, config)
|
||||||
|
|
||||||
|
for provider_spec in provider_specs.values():
|
||||||
api = provider_spec.api
|
api = provider_spec.api
|
||||||
endpoints = all_endpoints[api]
|
endpoints = all_endpoints[api]
|
||||||
impl = impls[api]
|
impl = impls[api]
|
||||||
|
|
||||||
if isinstance(provider_spec, RemoteProviderSpec):
|
if (
|
||||||
|
isinstance(provider_spec, RemoteProviderSpec)
|
||||||
|
and provider_spec.adapter is None
|
||||||
|
):
|
||||||
for endpoint in endpoints:
|
for endpoint in endpoints:
|
||||||
url = impl.base_url + endpoint.route
|
url = impl.base_url + endpoint.route
|
||||||
getattr(app, endpoint.method)(endpoint.route)(
|
getattr(app, endpoint.method)(endpoint.route)(
|
||||||
|
|
|
@ -8,7 +8,6 @@
|
||||||
|
|
||||||
set -euo pipefail
|
set -euo pipefail
|
||||||
|
|
||||||
# Define color codes
|
|
||||||
RED='\033[0;31m'
|
RED='\033[0;31m'
|
||||||
NC='\033[0m' # No Color
|
NC='\033[0m' # No Color
|
||||||
|
|
||||||
|
@ -17,20 +16,17 @@ error_handler() {
|
||||||
exit 1
|
exit 1
|
||||||
}
|
}
|
||||||
|
|
||||||
# Set up the error trap
|
|
||||||
trap 'error_handler ${LINENO}' ERR
|
trap 'error_handler ${LINENO}' ERR
|
||||||
|
|
||||||
if [ $# -lt 2 ]; then
|
if [ $# -lt 2 ]; then
|
||||||
echo "Usage: $0 <environment_name> <script_args...>"
|
echo "Usage: $0 <environment_name> <script_args...>"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
||||||
env_name="$1"
|
env_name="$1"
|
||||||
shift
|
shift
|
||||||
|
|
||||||
eval "$(conda shell.bash hook)"
|
eval "$(conda shell.bash hook)"
|
||||||
conda deactivate && conda activate "$env_name"
|
conda deactivate && conda activate "$env_name"
|
||||||
|
|
||||||
python_interp=$(conda run -n "$env_name" which python)
|
$CONDA_PREFIX/bin/python -m llama_toolchain.distribution.server "$@"
|
||||||
$python_interp -m llama_toolchain.distribution.server "$@"
|
|
0
llama_toolchain/distribution/run_image.sh → llama_toolchain/distribution/start_container.sh
Normal file → Executable file
0
llama_toolchain/distribution/run_image.sh → llama_toolchain/distribution/start_container.sh
Normal file → Executable file
|
@ -4,4 +4,4 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from .ollama import get_adapter_impl # noqa
|
from .ollama import get_provider_impl # noqa
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import AsyncGenerator
|
from typing import Any, AsyncGenerator
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
@ -36,7 +36,7 @@ OLLAMA_SUPPORTED_SKUS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def get_adapter_impl(config: RemoteProviderConfig) -> Inference:
|
async def get_provider_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
|
||||||
impl = OllamaInferenceAdapter(config.url)
|
impl = OllamaInferenceAdapter(config.url)
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
||||||
|
|
|
@ -26,7 +26,7 @@ from .api import (
|
||||||
from .event_logger import EventLogger
|
from .event_logger import EventLogger
|
||||||
|
|
||||||
|
|
||||||
async def get_adapter_impl(config: RemoteProviderConfig) -> Inference:
|
async def get_provider_impl(config: RemoteProviderConfig) -> Inference:
|
||||||
return InferenceClient(config.url)
|
return InferenceClient(config.url)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -27,11 +27,11 @@ def available_inference_providers() -> List[ProviderSpec]:
|
||||||
module="llama_toolchain.inference.meta_reference",
|
module="llama_toolchain.inference.meta_reference",
|
||||||
config_class="llama_toolchain.inference.meta_reference.MetaReferenceImplConfig",
|
config_class="llama_toolchain.inference.meta_reference.MetaReferenceImplConfig",
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
adapter_provider_spec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
adapter_id="ollama",
|
adapter_id="ollama",
|
||||||
pip_packages=[],
|
pip_packages=["ollama"],
|
||||||
module="llama_toolchain.inference.adapters.ollama",
|
module="llama_toolchain.inference.adapters.ollama",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
|
|
@ -16,7 +16,7 @@ from llama_toolchain.distribution.datatypes import RemoteProviderConfig
|
||||||
from .api import * # noqa: F403
|
from .api import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
async def get_adapter_impl(config: RemoteProviderConfig) -> Memory:
|
async def get_provider_impl(config: RemoteProviderConfig) -> Memory:
|
||||||
return MemoryClient(config.url)
|
return MemoryClient(config.url)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ from llama_toolchain.distribution.datatypes import RemoteProviderConfig
|
||||||
from .api import * # noqa: F403
|
from .api import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
async def get_adapter_impl(config: RemoteProviderConfig) -> Safety:
|
async def get_provider_impl(config: RemoteProviderConfig) -> Safety:
|
||||||
return SafetyClient(config.url)
|
return SafetyClient(config.url)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue