mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-06 04:34:57 +00:00
bunch more work to make adapters work
This commit is contained in:
parent
68f3db62e9
commit
c4fe72c3a3
20 changed files with 461 additions and 173 deletions
|
@ -10,20 +10,29 @@ LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
|
|||
LLAMA_TOOLCHAIN_DIR=${LLAMA_TOOLCHAIN_DIR:-}
|
||||
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
|
||||
|
||||
echo "llama-toolchain-dir=$LLAMA_TOOLCHAIN_DIR"
|
||||
set -euo pipefail
|
||||
|
||||
if [ "$#" -ne 3 ]; then
|
||||
echo "Usage: $0 <api_or_stack> <environment_name> <pip_dependencies>" >&2
|
||||
echo "Example: $0 [api|stack] conda-env 'numpy pandas scipy'" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
api_or_stack="$1"
|
||||
env_name="$2"
|
||||
pip_dependencies="$3"
|
||||
|
||||
# Define color codes
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
error_handler() {
|
||||
echo "Error occurred in script at line: ${1}" >&2
|
||||
exit 1
|
||||
}
|
||||
# this is set if we actually create a new conda in which case we need to clean up
|
||||
ENVNAME=""
|
||||
|
||||
# Set up the error trap
|
||||
trap 'error_handler ${LINENO}' ERR
|
||||
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
||||
source "$SCRIPT_DIR/common.sh"
|
||||
|
||||
ensure_conda_env_python310() {
|
||||
local env_name="$1"
|
||||
|
@ -52,6 +61,9 @@ ensure_conda_env_python310() {
|
|||
else
|
||||
echo "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}..."
|
||||
conda create -n "${env_name}" python="${python_version}" -y
|
||||
|
||||
ENVNAME="${env_name}"
|
||||
setup_cleanup_handlers
|
||||
fi
|
||||
|
||||
eval "$(conda shell.bash hook)"
|
||||
|
@ -94,19 +106,8 @@ ensure_conda_env_python310() {
|
|||
fi
|
||||
}
|
||||
|
||||
if [ "$#" -ne 3 ]; then
|
||||
echo "Usage: $0 <environment_name> <distribution_name> <pip_dependencies>" >&2
|
||||
echo "Example: $0 my_env local-llama-8b 'numpy pandas scipy'" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
env_name="$1"
|
||||
distribution_name="$2"
|
||||
pip_dependencies="$3"
|
||||
|
||||
ensure_conda_env_python310 "$env_name" "$pip_dependencies"
|
||||
|
||||
echo -e "${GREEN}Successfully setup distribution environment. Configuring...${NC}"
|
||||
echo -e "${GREEN}Successfully setup conda environment. Configuring build...${NC}"
|
||||
|
||||
which python3
|
||||
python3 -m llama_toolchain.cli.llama distribution configure --name "$distribution_name"
|
||||
$CONDA_PREFIX/bin/python3 -m llama_toolchain.cli.llama api configure "$api_or_stack" --name "$env_name"
|
0
llama_toolchain/distribution/build_image.sh → llama_toolchain/distribution/build_container.sh
Normal file → Executable file
0
llama_toolchain/distribution/build_image.sh → llama_toolchain/distribution/build_container.sh
Normal file → Executable file
40
llama_toolchain/distribution/common.sh
Normal file
40
llama_toolchain/distribution/common.sh
Normal file
|
@ -0,0 +1,40 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
cleanup() {
|
||||
envname="$1"
|
||||
|
||||
set +x
|
||||
echo "Cleaning up..."
|
||||
conda deactivate
|
||||
conda env remove --name $envname -y
|
||||
}
|
||||
|
||||
handle_int() {
|
||||
if [ -n $ENVNAME ]; then
|
||||
cleanup $ENVNAME
|
||||
fi
|
||||
exit 1
|
||||
}
|
||||
|
||||
handle_exit() {
|
||||
if [ $? -ne 0 ]; then
|
||||
echo -e "\033[1;31mABORTING.\033[0m"
|
||||
if [ -n $ENVNAME ]; then
|
||||
cleanup $ENVNAME
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
setup_cleanup_handlers() {
|
||||
trap handle_int INT
|
||||
trap handle_exit EXIT
|
||||
|
||||
__conda_setup="$('conda' 'shell.bash' 'hook' 2>/dev/null)"
|
||||
eval "$__conda_setup"
|
||||
|
||||
conda deactivate
|
||||
}
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
@ -66,36 +67,45 @@ Fully-qualified name of the module to import. The module is expected to have:
|
|||
|
||||
@json_schema_type
|
||||
class AdapterSpec(BaseModel):
|
||||
adapter_id: str = Field(
|
||||
...,
|
||||
description="Unique identifier for this adapter",
|
||||
)
|
||||
module: str = Field(
|
||||
...,
|
||||
description="""
|
||||
Fully-qualified name of the module to import. The module is expected to have:
|
||||
|
||||
- `get_adapter_impl(config, deps)`: returns the adapter implementation
|
||||
""",
|
||||
)
|
||||
pip_packages: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="The pip dependencies needed for this implementation",
|
||||
)
|
||||
config_class: str = Field(
|
||||
...,
|
||||
config_class: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Fully-qualified classname of the config for this provider",
|
||||
)
|
||||
|
||||
|
||||
class RemoteProviderConfig(BaseModel):
|
||||
base_url: str = Field(..., description="The base URL for the llama stack provider")
|
||||
url: str = Field(..., description="The URL for the provider")
|
||||
|
||||
@validator("base_url")
|
||||
@validator("url")
|
||||
@classmethod
|
||||
def validate_base_url(cls, base_url: str) -> str:
|
||||
if not base_url.startswith("http"):
|
||||
raise ValueError(f"URL must start with http: {base_url}")
|
||||
return base_url
|
||||
def validate_url(cls, url: str) -> str:
|
||||
if not url.startswith("http"):
|
||||
raise ValueError(f"URL must start with http: {url}")
|
||||
return url
|
||||
|
||||
|
||||
def remote_provider_id(adapter_id: str) -> str:
|
||||
return f"remote::{adapter_id}"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RemoteProviderSpec(ProviderSpec):
|
||||
module: str = Field(
|
||||
...,
|
||||
description="""
|
||||
Fully-qualified name of the module to import. The module is expected to have:
|
||||
- `get_client_impl(base_url)`: returns a client which can be used to call the remote implementation
|
||||
""",
|
||||
)
|
||||
adapter: Optional[AdapterSpec] = Field(
|
||||
default=None,
|
||||
description="""
|
||||
|
@ -107,6 +117,32 @@ as being "Llama Stack compatible"
|
|||
config_class: str = "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
|
||||
|
||||
|
||||
# need this wrapper since we don't have Pydantic v2 and that means we don't have
|
||||
# the @computed_field decorator
|
||||
def remote_provider_spec(
|
||||
api: Api, adapter: Optional[AdapterSpec] = None
|
||||
) -> RemoteProviderSpec:
|
||||
provider_id = (
|
||||
remote_provider_id(adapter.adapter_id) if adapter is not None else "remote"
|
||||
)
|
||||
module = (
|
||||
adapter.module if adapter is not None else f"llama_toolchain.{api.value}.client"
|
||||
)
|
||||
config_class = (
|
||||
adapter.config_class
|
||||
if adapter and adapter.config_class
|
||||
else "llama_toolchain.distribution.datatypes.RemoteProviderConfig"
|
||||
)
|
||||
|
||||
return RemoteProviderSpec(
|
||||
api=api,
|
||||
provider_id=provider_id,
|
||||
pip_packages=adapter.pip_packages if adapter is not None else [],
|
||||
module=module,
|
||||
config_class=config_class,
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DistributionSpec(BaseModel):
|
||||
spec_id: str
|
||||
|
@ -119,13 +155,28 @@ class DistributionSpec(BaseModel):
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class DistributionConfig(BaseModel):
|
||||
"""References to a installed / configured DistributionSpec"""
|
||||
class PackageConfig(BaseModel):
|
||||
built_at: datetime
|
||||
|
||||
name: str
|
||||
spec: str
|
||||
conda_env: str
|
||||
package_name: str = Field(
|
||||
...,
|
||||
description="""
|
||||
Reference to the distribution this package refers to. For unregistered (adhoc) packages,
|
||||
this could be just a hash
|
||||
""",
|
||||
)
|
||||
docker_image: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Reference to the docker image if this package refers to a container",
|
||||
)
|
||||
conda_env: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Reference to the conda environment if this package refers to a conda environment",
|
||||
)
|
||||
providers: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Provider configurations for each of the APIs provided by this distribution",
|
||||
description="""
|
||||
Provider configurations for each of the APIs provided by this package. This includes configurations for
|
||||
the dependencies of these providers as well.
|
||||
""",
|
||||
)
|
||||
|
|
|
@ -30,7 +30,27 @@ def instantiate_provider(
|
|||
return asyncio.run(module.get_provider_impl(config, deps))
|
||||
|
||||
|
||||
def instantiate_client(provider_spec: RemoteProviderSpec, base_url: str):
|
||||
def instantiate_client(
|
||||
provider_spec: RemoteProviderSpec, provider_config: Dict[str, Any]
|
||||
):
|
||||
module = importlib.import_module(provider_spec.module)
|
||||
|
||||
return asyncio.run(module.get_client_impl(base_url))
|
||||
adapter = provider_spec.adapter
|
||||
if adapter is not None:
|
||||
if "adapter" not in provider_config:
|
||||
raise ValueError(
|
||||
f"Adapter is specified but not present in provider config: {provider_config}"
|
||||
)
|
||||
adapter_config = provider_config["adapter"]
|
||||
|
||||
config_type = instantiate_class_type(adapter.config_class)
|
||||
if not issubclass(config_type, RemoteProviderConfig):
|
||||
raise ValueError(
|
||||
f"Config class {adapter.config_class} does not inherit from RemoteProviderConfig"
|
||||
)
|
||||
|
||||
config = config_type(**adapter_config)
|
||||
else:
|
||||
config = RemoteProviderConfig(**provider_config)
|
||||
|
||||
return asyncio.run(module.get_adapter_impl(config))
|
||||
|
|
|
@ -7,22 +7,10 @@
|
|||
from functools import lru_cache
|
||||
from typing import List, Optional
|
||||
|
||||
from .datatypes import Api, DistributionSpec, RemoteProviderSpec
|
||||
from .datatypes import * # noqa: F403
|
||||
from .distribution import api_providers
|
||||
|
||||
|
||||
def client_module(api: Api) -> str:
|
||||
return f"llama_toolchain.{api.value}.client"
|
||||
|
||||
|
||||
def remote_spec(api: Api) -> RemoteProviderSpec:
|
||||
return RemoteProviderSpec(
|
||||
api=api,
|
||||
provider_id=f"{api.value}-remote",
|
||||
module=client_module(api),
|
||||
)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def available_distribution_specs() -> List[DistributionSpec]:
|
||||
providers = api_providers()
|
||||
|
@ -40,13 +28,14 @@ def available_distribution_specs() -> List[DistributionSpec]:
|
|||
DistributionSpec(
|
||||
spec_id="remote",
|
||||
description="Point to remote services for all llama stack APIs",
|
||||
provider_specs={x: remote_spec(x) for x in providers},
|
||||
provider_specs={x: remote_provider_spec(x) for x in providers},
|
||||
),
|
||||
DistributionSpec(
|
||||
spec_id="local-ollama",
|
||||
description="Like local, but use ollama for running LLM inference",
|
||||
provider_specs={
|
||||
Api.inference: providers[Api.inference]["meta-ollama"],
|
||||
# this is ODD; make this easier -- we just need a better function to retrieve registered providers
|
||||
Api.inference: providers[Api.inference][remote_provider_id("ollama")],
|
||||
Api.safety: providers[Api.safety]["meta-reference"],
|
||||
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
|
||||
Api.memory: providers[Api.memory]["meta-reference-faiss"],
|
||||
|
@ -57,9 +46,9 @@ def available_distribution_specs() -> List[DistributionSpec]:
|
|||
description="Test agentic with others as remote",
|
||||
provider_specs={
|
||||
Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
|
||||
Api.inference: remote_spec(Api.inference),
|
||||
Api.memory: remote_spec(Api.memory),
|
||||
Api.safety: remote_spec(Api.safety),
|
||||
Api.inference: remote_provider_spec(Api.inference),
|
||||
Api.memory: remote_provider_spec(Api.memory),
|
||||
Api.safety: remote_provider_spec(Api.safety),
|
||||
},
|
||||
),
|
||||
DistributionSpec(
|
||||
|
|
|
@ -264,7 +264,8 @@ def resolve_impls(dist: DistributionSpec, config: Dict[str, Any]) -> Dict[Api, A
|
|||
provider_config = provider_configs[api.value]
|
||||
if isinstance(provider_spec, RemoteProviderSpec):
|
||||
impls[api] = instantiate_client(
|
||||
provider_spec, provider_config["base_url"].rstrip("/")
|
||||
provider_spec,
|
||||
provider_config,
|
||||
)
|
||||
else:
|
||||
deps = {api: impls[api] for api in provider_spec.api_dependencies}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue