mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-15 15:02:40 +00:00
Merge remote-tracking branch 'origin/main' into dependabot/uv/openai-2.5.0
This commit is contained in:
commit
13450c1a68
317 changed files with 86802 additions and 18957 deletions
|
|
@ -41,7 +41,7 @@ class AccessRule(BaseModel):
|
|||
A rule defines a list of action either to permit or to forbid. It may specify a
|
||||
principal or a resource that must match for the rule to take effect. The resource
|
||||
to match should be specified in the form of a type qualified identifier, e.g.
|
||||
model::my-model or vector_db::some-db, or a wildcard for all resources of a type,
|
||||
model::my-model or vector_store::some-db, or a wildcard for all resources of a type,
|
||||
e.g. model::*. If the principal or resource are not specified, they will match all
|
||||
requests.
|
||||
|
||||
|
|
@ -79,9 +79,9 @@ class AccessRule(BaseModel):
|
|||
description: any user has read access to any resource created by a member of their team
|
||||
- forbid:
|
||||
actions: [create, read, delete]
|
||||
resource: vector_db::*
|
||||
resource: vector_store::*
|
||||
unless: user with admin in roles
|
||||
description: only user with admin role can use vector_db resources
|
||||
description: only user with admin role can use vector_store resources
|
||||
|
||||
"""
|
||||
|
||||
|
|
|
|||
|
|
@ -1,410 +0,0 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
|
||||
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
|
||||
|
||||
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
|
||||
PYPI_VERSION=${PYPI_VERSION:-}
|
||||
BUILD_PLATFORM=${BUILD_PLATFORM:-}
|
||||
# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
|
||||
|
||||
# mounting is not supported by docker buildx, so we use COPY instead
|
||||
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}
|
||||
# Path to the run.yaml file in the container
|
||||
RUN_CONFIG_PATH=/app/run.yaml
|
||||
|
||||
BUILD_CONTEXT_DIR=$(pwd)
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Define color codes
|
||||
RED='\033[0;31m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# Usage function
|
||||
usage() {
|
||||
echo "Usage: $0 --image-name <image_name> --container-base <container_base> --normal-deps <pip_dependencies> [--run-config <run_config>] [--external-provider-deps <external_provider_deps>] [--optional-deps <special_pip_deps>]"
|
||||
echo "Example: $0 --image-name llama-stack-img --container-base python:3.12-slim --normal-deps 'numpy pandas' --run-config ./run.yaml --external-provider-deps 'foo' --optional-deps 'bar'"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Parse arguments
|
||||
image_name=""
|
||||
container_base=""
|
||||
normal_deps=""
|
||||
external_provider_deps=""
|
||||
optional_deps=""
|
||||
run_config=""
|
||||
distro_or_config=""
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
key="$1"
|
||||
case "$key" in
|
||||
--image-name)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --image-name requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
image_name="$2"
|
||||
shift 2
|
||||
;;
|
||||
--container-base)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --container-base requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
container_base="$2"
|
||||
shift 2
|
||||
;;
|
||||
--normal-deps)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --normal-deps requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
normal_deps="$2"
|
||||
shift 2
|
||||
;;
|
||||
--external-provider-deps)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --external-provider-deps requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
external_provider_deps="$2"
|
||||
shift 2
|
||||
;;
|
||||
--optional-deps)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --optional-deps requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
optional_deps="$2"
|
||||
shift 2
|
||||
;;
|
||||
--run-config)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --run-config requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
run_config="$2"
|
||||
shift 2
|
||||
;;
|
||||
--distro-or-config)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --distro-or-config requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
distro_or_config="$2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option: $1" >&2
|
||||
usage
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Check required arguments
|
||||
if [[ -z "$image_name" || -z "$container_base" || -z "$normal_deps" ]]; then
|
||||
echo "Error: --image-name, --container-base, and --normal-deps are required." >&2
|
||||
usage
|
||||
fi
|
||||
|
||||
CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
|
||||
CONTAINER_OPTS=${CONTAINER_OPTS:---progress=plain}
|
||||
TEMP_DIR=$(mktemp -d)
|
||||
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
||||
source "$SCRIPT_DIR/common.sh"
|
||||
|
||||
add_to_container() {
|
||||
output_file="$TEMP_DIR/Containerfile"
|
||||
if [ -t 0 ]; then
|
||||
printf '%s\n' "$1" >>"$output_file"
|
||||
else
|
||||
cat >>"$output_file"
|
||||
fi
|
||||
}
|
||||
|
||||
if ! is_command_available "$CONTAINER_BINARY"; then
|
||||
printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ $container_base == *"registry.access.redhat.com/ubi9"* ]]; then
|
||||
add_to_container << EOF
|
||||
FROM $container_base
|
||||
WORKDIR /app
|
||||
|
||||
# We install the Python 3.12 dev headers and build tools so that any
|
||||
# C-extension wheels (e.g. polyleven, faiss-cpu) can compile successfully.
|
||||
|
||||
RUN dnf -y update && dnf install -y iputils git net-tools wget \
|
||||
vim-minimal python3.12 python3.12-pip python3.12-wheel \
|
||||
python3.12-setuptools python3.12-devel gcc gcc-c++ make && \
|
||||
ln -s /bin/pip3.12 /bin/pip && ln -s /bin/python3.12 /bin/python && dnf clean all
|
||||
|
||||
ENV UV_SYSTEM_PYTHON=1
|
||||
RUN pip install uv
|
||||
EOF
|
||||
else
|
||||
add_to_container << EOF
|
||||
FROM $container_base
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && apt-get install -y \
|
||||
iputils-ping net-tools iproute2 dnsutils telnet \
|
||||
curl wget telnet git\
|
||||
procps psmisc lsof \
|
||||
traceroute \
|
||||
bubblewrap \
|
||||
gcc g++ \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
ENV UV_SYSTEM_PYTHON=1
|
||||
RUN pip install uv
|
||||
EOF
|
||||
fi
|
||||
|
||||
# Add pip dependencies first since llama-stack is what will change most often
|
||||
# so we can reuse layers.
|
||||
if [ -n "$normal_deps" ]; then
|
||||
read -ra pip_args <<< "$normal_deps"
|
||||
quoted_deps=$(printf " %q" "${pip_args[@]}")
|
||||
add_to_container << EOF
|
||||
RUN uv pip install --no-cache $quoted_deps
|
||||
EOF
|
||||
fi
|
||||
|
||||
if [ -n "$optional_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$optional_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
read -ra pip_args <<< "$part"
|
||||
quoted_deps=$(printf " %q" "${pip_args[@]}")
|
||||
add_to_container <<EOF
|
||||
RUN uv pip install --no-cache $quoted_deps
|
||||
EOF
|
||||
done
|
||||
fi
|
||||
|
||||
if [ -n "$external_provider_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$external_provider_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
read -ra pip_args <<< "$part"
|
||||
quoted_deps=$(printf " %q" "${pip_args[@]}")
|
||||
add_to_container <<EOF
|
||||
RUN uv pip install --no-cache $quoted_deps
|
||||
EOF
|
||||
add_to_container <<EOF
|
||||
RUN python3 - <<PYTHON | uv pip install --no-cache -r -
|
||||
import importlib
|
||||
import sys
|
||||
|
||||
try:
|
||||
package_name = '$part'.split('==')[0].split('>=')[0].split('<=')[0].split('!=')[0].split('<')[0].split('>')[0]
|
||||
module = importlib.import_module(f'{package_name}.provider')
|
||||
spec = module.get_provider_spec()
|
||||
if hasattr(spec, 'pip_packages') and spec.pip_packages:
|
||||
if isinstance(spec.pip_packages, (list, tuple)):
|
||||
print('\n'.join(spec.pip_packages))
|
||||
except Exception as e:
|
||||
print(f'Error getting provider spec for {package_name}: {e}', file=sys.stderr)
|
||||
PYTHON
|
||||
EOF
|
||||
done
|
||||
fi
|
||||
|
||||
get_python_cmd() {
|
||||
if is_command_available python; then
|
||||
echo "python"
|
||||
elif is_command_available python3; then
|
||||
echo "python3"
|
||||
else
|
||||
echo "Error: Neither python nor python3 is installed. Please install Python to continue." >&2
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
if [ -n "$run_config" ]; then
|
||||
# Copy the run config to the build context since it's an absolute path
|
||||
cp "$run_config" "$BUILD_CONTEXT_DIR/run.yaml"
|
||||
|
||||
# Parse the run.yaml configuration to identify external provider directories
|
||||
# If external providers are specified, copy their directory to the container
|
||||
# and update the configuration to reference the new container path
|
||||
python_cmd=$(get_python_cmd)
|
||||
external_providers_dir=$($python_cmd -c "import yaml; config = yaml.safe_load(open('$run_config')); print(config.get('external_providers_dir') or '')")
|
||||
external_providers_dir=$(eval echo "$external_providers_dir")
|
||||
if [ -n "$external_providers_dir" ]; then
|
||||
if [ -d "$external_providers_dir" ]; then
|
||||
echo "Copying external providers directory: $external_providers_dir"
|
||||
cp -r "$external_providers_dir" "$BUILD_CONTEXT_DIR/providers.d"
|
||||
add_to_container << EOF
|
||||
COPY providers.d /.llama/providers.d
|
||||
EOF
|
||||
fi
|
||||
|
||||
# Edit the run.yaml file to change the external_providers_dir to /.llama/providers.d
|
||||
if [ "$(uname)" = "Darwin" ]; then
|
||||
sed -i.bak -e 's|external_providers_dir:.*|external_providers_dir: /.llama/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
|
||||
rm -f "$BUILD_CONTEXT_DIR/run.yaml.bak"
|
||||
else
|
||||
sed -i 's|external_providers_dir:.*|external_providers_dir: /.llama/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Copy run config into docker image
|
||||
add_to_container << EOF
|
||||
COPY run.yaml $RUN_CONFIG_PATH
|
||||
EOF
|
||||
fi
|
||||
|
||||
stack_mount="/app/llama-stack-source"
|
||||
client_mount="/app/llama-stack-client-source"
|
||||
|
||||
install_local_package() {
|
||||
local dir="$1"
|
||||
local mount_point="$2"
|
||||
local name="$3"
|
||||
|
||||
if [ ! -d "$dir" ]; then
|
||||
echo "${RED}Warning: $name is set but directory does not exist: $dir${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ "$USE_COPY_NOT_MOUNT" = "true" ]; then
|
||||
add_to_container << EOF
|
||||
COPY $dir $mount_point
|
||||
EOF
|
||||
fi
|
||||
add_to_container << EOF
|
||||
RUN uv pip install --no-cache -e $mount_point
|
||||
EOF
|
||||
}
|
||||
|
||||
|
||||
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||
install_local_package "$LLAMA_STACK_CLIENT_DIR" "$client_mount" "LLAMA_STACK_CLIENT_DIR"
|
||||
fi
|
||||
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
install_local_package "$LLAMA_STACK_DIR" "$stack_mount" "LLAMA_STACK_DIR"
|
||||
else
|
||||
if [ -n "$TEST_PYPI_VERSION" ]; then
|
||||
# these packages are damaged in test-pypi, so install them first
|
||||
add_to_container << EOF
|
||||
RUN uv pip install --no-cache fastapi libcst
|
||||
EOF
|
||||
add_to_container << EOF
|
||||
RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
|
||||
--index-strategy unsafe-best-match \
|
||||
llama-stack==$TEST_PYPI_VERSION
|
||||
|
||||
EOF
|
||||
else
|
||||
if [ -n "$PYPI_VERSION" ]; then
|
||||
SPEC_VERSION="llama-stack==${PYPI_VERSION}"
|
||||
else
|
||||
SPEC_VERSION="llama-stack"
|
||||
fi
|
||||
add_to_container << EOF
|
||||
RUN uv pip install --no-cache $SPEC_VERSION
|
||||
EOF
|
||||
fi
|
||||
fi
|
||||
|
||||
# remove uv after installation
|
||||
add_to_container << EOF
|
||||
RUN pip uninstall -y uv
|
||||
EOF
|
||||
|
||||
# If a run config is provided, we use the llama stack CLI
|
||||
if [[ -n "$run_config" ]]; then
|
||||
add_to_container << EOF
|
||||
ENTRYPOINT ["llama", "stack", "run", "$RUN_CONFIG_PATH"]
|
||||
EOF
|
||||
elif [[ "$distro_or_config" != *.yaml ]]; then
|
||||
add_to_container << EOF
|
||||
ENTRYPOINT ["llama", "stack", "run", "$distro_or_config"]
|
||||
EOF
|
||||
fi
|
||||
|
||||
# Add other require item commands genearic to all containers
|
||||
add_to_container << EOF
|
||||
|
||||
RUN mkdir -p /.llama /.cache && chmod -R g+rw /.llama /.cache && (chmod -R g+rw /app 2>/dev/null || true)
|
||||
EOF
|
||||
|
||||
printf "Containerfile created successfully in %s/Containerfile\n\n" "$TEMP_DIR"
|
||||
cat "$TEMP_DIR"/Containerfile
|
||||
printf "\n"
|
||||
|
||||
# Start building the CLI arguments
|
||||
CLI_ARGS=()
|
||||
|
||||
# Read CONTAINER_OPTS and put it in an array
|
||||
read -ra CLI_ARGS <<< "$CONTAINER_OPTS"
|
||||
|
||||
if [ "$USE_COPY_NOT_MOUNT" != "true" ]; then
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_STACK_DIR"):$stack_mount")
|
||||
fi
|
||||
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||
CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_STACK_CLIENT_DIR"):$client_mount")
|
||||
fi
|
||||
fi
|
||||
|
||||
if is_command_available selinuxenabled && selinuxenabled; then
|
||||
# Disable SELinux labels -- we don't want to relabel the llama-stack source dir
|
||||
CLI_ARGS+=("--security-opt" "label=disable")
|
||||
fi
|
||||
|
||||
# Set version tag based on PyPI version
|
||||
if [ -n "$PYPI_VERSION" ]; then
|
||||
version_tag="$PYPI_VERSION"
|
||||
elif [ -n "$TEST_PYPI_VERSION" ]; then
|
||||
version_tag="test-$TEST_PYPI_VERSION"
|
||||
elif [[ -n "$LLAMA_STACK_DIR" || -n "$LLAMA_STACK_CLIENT_DIR" ]]; then
|
||||
version_tag="dev"
|
||||
else
|
||||
URL="https://pypi.org/pypi/llama-stack/json"
|
||||
version_tag=$(curl -s $URL | jq -r '.info.version')
|
||||
fi
|
||||
|
||||
# Add version tag to image name
|
||||
image_tag="$image_name:$version_tag"
|
||||
|
||||
# Detect platform architecture
|
||||
ARCH=$(uname -m)
|
||||
if [ -n "$BUILD_PLATFORM" ]; then
|
||||
CLI_ARGS+=("--platform" "$BUILD_PLATFORM")
|
||||
elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then
|
||||
CLI_ARGS+=("--platform" "linux/arm64")
|
||||
elif [ "$ARCH" = "x86_64" ]; then
|
||||
CLI_ARGS+=("--platform" "linux/amd64")
|
||||
else
|
||||
echo "Unsupported architecture: $ARCH"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "PWD: $(pwd)"
|
||||
echo "Containerfile: $TEMP_DIR/Containerfile"
|
||||
set -x
|
||||
|
||||
$CONTAINER_BINARY build \
|
||||
"${CLI_ARGS[@]}" \
|
||||
-t "$image_tag" \
|
||||
-f "$TEMP_DIR/Containerfile" \
|
||||
"$BUILD_CONTEXT_DIR"
|
||||
|
||||
# clean up tmp/configs
|
||||
rm -rf "$BUILD_CONTEXT_DIR/run.yaml" "$TEMP_DIR"
|
||||
set +x
|
||||
|
||||
echo "Success!"
|
||||
|
|
@ -1,220 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
|
||||
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
|
||||
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
|
||||
# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
|
||||
UV_SYSTEM_PYTHON=${UV_SYSTEM_PYTHON:-}
|
||||
VIRTUAL_ENV=${VIRTUAL_ENV:-}
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Define color codes
|
||||
RED='\033[0;31m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
||||
source "$SCRIPT_DIR/common.sh"
|
||||
|
||||
# Usage function
|
||||
usage() {
|
||||
echo "Usage: $0 --env-name <env_name> --normal-deps <pip_dependencies> [--external-provider-deps <external_provider_deps>] [--optional-deps <special_pip_deps>]"
|
||||
echo "Example: $0 --env-name mybuild --normal-deps 'numpy pandas scipy' --external-provider-deps 'foo' --optional-deps 'bar'"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Parse arguments
|
||||
env_name=""
|
||||
normal_deps=""
|
||||
external_provider_deps=""
|
||||
optional_deps=""
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
key="$1"
|
||||
case "$key" in
|
||||
--env-name)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --env-name requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
env_name="$2"
|
||||
shift 2
|
||||
;;
|
||||
--normal-deps)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --normal-deps requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
normal_deps="$2"
|
||||
shift 2
|
||||
;;
|
||||
--external-provider-deps)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --external-provider-deps requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
external_provider_deps="$2"
|
||||
shift 2
|
||||
;;
|
||||
--optional-deps)
|
||||
if [[ -z "$2" || "$2" == --* ]]; then
|
||||
echo "Error: --optional-deps requires a string value" >&2
|
||||
usage
|
||||
fi
|
||||
optional_deps="$2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option: $1" >&2
|
||||
usage
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Check required arguments
|
||||
if [[ -z "$env_name" || -z "$normal_deps" ]]; then
|
||||
echo "Error: --env-name and --normal-deps are required." >&2
|
||||
usage
|
||||
fi
|
||||
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
|
||||
fi
|
||||
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||
echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
|
||||
fi
|
||||
|
||||
ENVNAME=""
|
||||
|
||||
# pre-run checks to make sure we can proceed with the installation
|
||||
pre_run_checks() {
|
||||
local env_name="$1"
|
||||
|
||||
if ! is_command_available uv; then
|
||||
echo "uv is not installed, trying to install it."
|
||||
if ! is_command_available pip; then
|
||||
echo "pip is not installed, cannot automatically install 'uv'."
|
||||
echo "Follow this link to install it:"
|
||||
echo "https://docs.astral.sh/uv/getting-started/installation/"
|
||||
exit 1
|
||||
else
|
||||
pip install uv
|
||||
fi
|
||||
fi
|
||||
|
||||
# checking if an environment with the same name already exists
|
||||
if [ -d "$env_name" ]; then
|
||||
echo "Environment '$env_name' already exists, re-using it."
|
||||
fi
|
||||
}
|
||||
|
||||
run() {
|
||||
# Use only global variables set by flag parser
|
||||
if [ -n "$UV_SYSTEM_PYTHON" ] || [ "$env_name" == "__system__" ]; then
|
||||
echo "Installing dependencies in system Python environment"
|
||||
export UV_SYSTEM_PYTHON=1
|
||||
elif [ "$VIRTUAL_ENV" == "$env_name" ]; then
|
||||
echo "Virtual environment $env_name is already active"
|
||||
else
|
||||
echo "Using virtual environment $env_name"
|
||||
uv venv "$env_name"
|
||||
source "$env_name/bin/activate"
|
||||
fi
|
||||
|
||||
if [ -n "$TEST_PYPI_VERSION" ]; then
|
||||
uv pip install fastapi libcst
|
||||
uv pip install --extra-index-url https://test.pypi.org/simple/ \
|
||||
--index-strategy unsafe-best-match \
|
||||
llama-stack=="$TEST_PYPI_VERSION" \
|
||||
$normal_deps
|
||||
if [ -n "$optional_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$optional_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "$part"
|
||||
uv pip install $part
|
||||
done
|
||||
fi
|
||||
if [ -n "$external_provider_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$external_provider_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "$part"
|
||||
uv pip install "$part"
|
||||
done
|
||||
fi
|
||||
else
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
# only warn if DIR does not start with "git+"
|
||||
if [ ! -d "$LLAMA_STACK_DIR" ] && [[ "$LLAMA_STACK_DIR" != git+* ]]; then
|
||||
printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_DIR" >&2
|
||||
exit 1
|
||||
fi
|
||||
printf "Installing from LLAMA_STACK_DIR: %s\n" "$LLAMA_STACK_DIR"
|
||||
# editable only if LLAMA_STACK_DIR does not start with "git+"
|
||||
if [[ "$LLAMA_STACK_DIR" != git+* ]]; then
|
||||
EDITABLE="-e"
|
||||
else
|
||||
EDITABLE=""
|
||||
fi
|
||||
uv pip install --no-cache-dir $EDITABLE "$LLAMA_STACK_DIR"
|
||||
else
|
||||
uv pip install --no-cache-dir llama-stack
|
||||
fi
|
||||
|
||||
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||
# only warn if DIR does not start with "git+"
|
||||
if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ] && [[ "$LLAMA_STACK_CLIENT_DIR" != git+* ]]; then
|
||||
printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_CLIENT_DIR" >&2
|
||||
exit 1
|
||||
fi
|
||||
printf "Installing from LLAMA_STACK_CLIENT_DIR: %s\n" "$LLAMA_STACK_CLIENT_DIR"
|
||||
# editable only if LLAMA_STACK_CLIENT_DIR does not start with "git+"
|
||||
if [[ "$LLAMA_STACK_CLIENT_DIR" != git+* ]]; then
|
||||
EDITABLE="-e"
|
||||
else
|
||||
EDITABLE=""
|
||||
fi
|
||||
uv pip install --no-cache-dir $EDITABLE "$LLAMA_STACK_CLIENT_DIR"
|
||||
fi
|
||||
|
||||
printf "Installing pip dependencies\n"
|
||||
uv pip install $normal_deps
|
||||
if [ -n "$optional_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$optional_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "Installing special provider module: $part"
|
||||
uv pip install $part
|
||||
done
|
||||
fi
|
||||
if [ -n "$external_provider_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$external_provider_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "Installing external provider module: $part"
|
||||
uv pip install "$part"
|
||||
echo "Getting provider spec for module: $part and installing dependencies"
|
||||
package_name=$(echo "$part" | sed 's/[<>=!].*//')
|
||||
python3 -c "
|
||||
import importlib
|
||||
import sys
|
||||
try:
|
||||
module = importlib.import_module(f'$package_name.provider')
|
||||
spec = module.get_provider_spec()
|
||||
if hasattr(spec, 'pip_packages') and spec.pip_packages:
|
||||
print('\\n'.join(spec.pip_packages))
|
||||
except Exception as e:
|
||||
print(f'Error getting provider spec for $package_name: {e}', file=sys.stderr)
|
||||
" | uv pip install -r -
|
||||
done
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
pre_run_checks "$env_name"
|
||||
run
|
||||
|
|
@ -159,6 +159,37 @@ def upgrade_from_routing_table(
|
|||
config_dict["apis"] = config_dict["apis_to_serve"]
|
||||
config_dict.pop("apis_to_serve", None)
|
||||
|
||||
# Add default storage config if not present
|
||||
if "storage" not in config_dict:
|
||||
config_dict["storage"] = {
|
||||
"backends": {
|
||||
"kv_default": {
|
||||
"type": "kv_sqlite",
|
||||
"db_path": "~/.llama/kvstore.db",
|
||||
},
|
||||
"sql_default": {
|
||||
"type": "sql_sqlite",
|
||||
"db_path": "~/.llama/sql_store.db",
|
||||
},
|
||||
},
|
||||
"stores": {
|
||||
"metadata": {
|
||||
"namespace": "registry",
|
||||
"backend": "kv_default",
|
||||
},
|
||||
"inference": {
|
||||
"table_name": "inference_store",
|
||||
"backend": "sql_default",
|
||||
"max_write_queue_size": 10000,
|
||||
"num_writers": 4,
|
||||
},
|
||||
"conversations": {
|
||||
"table_name": "openai_conversations",
|
||||
"backend": "sql_default",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return config_dict
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import secrets
|
||||
import time
|
||||
from typing import Any, Literal
|
||||
|
|
@ -21,16 +20,11 @@ from llama_stack.apis.conversations.conversations import (
|
|||
Conversations,
|
||||
Metadata,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.core.datatypes import AccessRule, StackRunConfig
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import (
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreConfig,
|
||||
sqlstore_impl,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
|
||||
|
||||
logger = get_logger(name=__name__, category="openai_conversations")
|
||||
|
||||
|
|
@ -38,13 +32,11 @@ logger = get_logger(name=__name__, category="openai_conversations")
|
|||
class ConversationServiceConfig(BaseModel):
|
||||
"""Configuration for the built-in conversation service.
|
||||
|
||||
:param conversations_store: SQL store configuration for conversations (defaults to SQLite)
|
||||
:param run_config: Stack run configuration for resolving persistence
|
||||
:param policy: Access control rules
|
||||
"""
|
||||
|
||||
conversations_store: SqlStoreConfig = SqliteSqlStoreConfig(
|
||||
db_path=(DISTRIBS_BASE_DIR / "conversations.db").as_posix()
|
||||
)
|
||||
run_config: StackRunConfig
|
||||
policy: list[AccessRule] = []
|
||||
|
||||
|
||||
|
|
@ -63,14 +55,16 @@ class ConversationServiceImpl(Conversations):
|
|||
self.deps = deps
|
||||
self.policy = config.policy
|
||||
|
||||
base_sql_store = sqlstore_impl(config.conversations_store)
|
||||
# Use conversations store reference from run config
|
||||
conversations_ref = config.run_config.storage.stores.conversations
|
||||
if not conversations_ref:
|
||||
raise ValueError("storage.stores.conversations must be configured in run config")
|
||||
|
||||
base_sql_store = sqlstore_impl(conversations_ref)
|
||||
self.sql_store = AuthorizedSqlStore(base_sql_store, self.policy)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the store and create tables."""
|
||||
if isinstance(self.config.conversations_store, SqliteSqlStoreConfig):
|
||||
os.makedirs(os.path.dirname(self.config.conversations_store.db_path), exist_ok=True)
|
||||
|
||||
await self.sql_store.create_table(
|
||||
"openai_conversations",
|
||||
{
|
||||
|
|
|
|||
|
|
@ -23,12 +23,15 @@ from llama_stack.apis.scoring import Scoring
|
|||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
|
||||
from llama_stack.apis.shields import Shield, ShieldInput
|
||||
from llama_stack.apis.tools import ToolGroup, ToolGroupInput, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.apis.vector_stores import VectorStore, VectorStoreInput
|
||||
from llama_stack.core.access_control.datatypes import AccessRule
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
KVStoreReference,
|
||||
StorageBackendType,
|
||||
StorageConfig,
|
||||
)
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
|
||||
|
||||
LLAMA_STACK_BUILD_CONFIG_VERSION = 2
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION = 2
|
||||
|
|
@ -68,7 +71,7 @@ class ShieldWithOwner(Shield, ResourceWithOwner):
|
|||
pass
|
||||
|
||||
|
||||
class VectorDBWithOwner(VectorDB, ResourceWithOwner):
|
||||
class VectorStoreWithOwner(VectorStore, ResourceWithOwner):
|
||||
pass
|
||||
|
||||
|
||||
|
|
@ -88,12 +91,12 @@ class ToolGroupWithOwner(ToolGroup, ResourceWithOwner):
|
|||
pass
|
||||
|
||||
|
||||
RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | ToolGroup
|
||||
RoutableObject = Model | Shield | VectorStore | Dataset | ScoringFn | Benchmark | ToolGroup
|
||||
|
||||
RoutableObjectWithProvider = Annotated[
|
||||
ModelWithOwner
|
||||
| ShieldWithOwner
|
||||
| VectorDBWithOwner
|
||||
| VectorStoreWithOwner
|
||||
| DatasetWithOwner
|
||||
| ScoringFnWithOwner
|
||||
| BenchmarkWithOwner
|
||||
|
|
@ -351,12 +354,32 @@ class AuthenticationRequiredError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
class QualifiedModel(BaseModel):
|
||||
"""A qualified model identifier, consisting of a provider ID and a model ID."""
|
||||
|
||||
provider_id: str
|
||||
model_id: str
|
||||
|
||||
|
||||
class VectorStoresConfig(BaseModel):
|
||||
"""Configuration for vector stores in the stack."""
|
||||
|
||||
default_provider_id: str | None = Field(
|
||||
default=None,
|
||||
description="ID of the vector_io provider to use as default when multiple providers are available and none is specified.",
|
||||
)
|
||||
default_embedding_model: QualifiedModel | None = Field(
|
||||
default=None,
|
||||
description="Default embedding model configuration for vector stores.",
|
||||
)
|
||||
|
||||
|
||||
class QuotaPeriod(StrEnum):
|
||||
DAY = "day"
|
||||
|
||||
|
||||
class QuotaConfig(BaseModel):
|
||||
kvstore: SqliteKVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
|
||||
kvstore: KVStoreReference = Field(description="Config for KV store backend (SQLite only for now)")
|
||||
anonymous_max_requests: int = Field(default=100, description="Max requests for unauthenticated clients per period")
|
||||
authenticated_max_requests: int = Field(
|
||||
default=1000, description="Max requests for authenticated clients per period"
|
||||
|
|
@ -399,6 +422,18 @@ def process_cors_config(cors_config: bool | CORSConfig | None) -> CORSConfig | N
|
|||
raise ValueError(f"Expected bool or CORSConfig, got {type(cors_config).__name__}")
|
||||
|
||||
|
||||
class RegisteredResources(BaseModel):
|
||||
"""Registry of resources available in the distribution."""
|
||||
|
||||
models: list[ModelInput] = Field(default_factory=list)
|
||||
shields: list[ShieldInput] = Field(default_factory=list)
|
||||
vector_stores: list[VectorStoreInput] = Field(default_factory=list)
|
||||
datasets: list[DatasetInput] = Field(default_factory=list)
|
||||
scoring_fns: list[ScoringFnInput] = Field(default_factory=list)
|
||||
benchmarks: list[BenchmarkInput] = Field(default_factory=list)
|
||||
tool_groups: list[ToolGroupInput] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ServerConfig(BaseModel):
|
||||
port: int = Field(
|
||||
default=8321,
|
||||
|
|
@ -438,18 +473,6 @@ class ServerConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class InferenceStoreConfig(BaseModel):
|
||||
sql_store_config: SqlStoreConfig
|
||||
max_write_queue_size: int = Field(default=10000, description="Max queued writes for inference store")
|
||||
num_writers: int = Field(default=4, description="Number of concurrent background writers")
|
||||
|
||||
|
||||
class ResponsesStoreConfig(BaseModel):
|
||||
sql_store_config: SqlStoreConfig
|
||||
max_write_queue_size: int = Field(default=10000, description="Max queued writes for responses store")
|
||||
num_writers: int = Field(default=4, description="Number of concurrent background writers")
|
||||
|
||||
|
||||
class StackRunConfig(BaseModel):
|
||||
version: int = LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
|
||||
|
|
@ -476,37 +499,15 @@ One or more providers to use for each API. The same provider_type (e.g., meta-re
|
|||
can be instantiated multiple times (with different configs) if necessary.
|
||||
""",
|
||||
)
|
||||
metadata_store: KVStoreConfig | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Configuration for the persistence store used by the distribution registry. If not specified,
|
||||
a default SQLite store will be used.""",
|
||||
storage: StorageConfig = Field(
|
||||
description="Catalog of named storage backends and references available to the stack",
|
||||
)
|
||||
|
||||
inference_store: InferenceStoreConfig | SqlStoreConfig | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Configuration for the persistence store used by the inference API. Can be either a
|
||||
InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (deprecated).
|
||||
If not specified, a default SQLite store will be used.""",
|
||||
registered_resources: RegisteredResources = Field(
|
||||
default_factory=RegisteredResources,
|
||||
description="Registry of resources available in the distribution",
|
||||
)
|
||||
|
||||
conversations_store: SqlStoreConfig | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Configuration for the persistence store used by the conversations API.
|
||||
If not specified, a default SQLite store will be used.""",
|
||||
)
|
||||
|
||||
# registry of "resources" in the distribution
|
||||
models: list[ModelInput] = Field(default_factory=list)
|
||||
shields: list[ShieldInput] = Field(default_factory=list)
|
||||
vector_dbs: list[VectorDBInput] = Field(default_factory=list)
|
||||
datasets: list[DatasetInput] = Field(default_factory=list)
|
||||
scoring_fns: list[ScoringFnInput] = Field(default_factory=list)
|
||||
benchmarks: list[BenchmarkInput] = Field(default_factory=list)
|
||||
tool_groups: list[ToolGroupInput] = Field(default_factory=list)
|
||||
|
||||
logging: LoggingConfig | None = Field(default=None, description="Configuration for Llama Stack Logging")
|
||||
|
||||
telemetry: TelemetryConfig = Field(default_factory=TelemetryConfig, description="Configuration for telemetry")
|
||||
|
|
@ -526,6 +527,11 @@ If not specified, a default SQLite store will be used.""",
|
|||
description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
|
||||
)
|
||||
|
||||
vector_stores: VectorStoresConfig | None = Field(
|
||||
default=None,
|
||||
description="Configuration for vector stores, including default embedding model",
|
||||
)
|
||||
|
||||
@field_validator("external_providers_dir")
|
||||
@classmethod
|
||||
def validate_external_providers_dir(cls, v):
|
||||
|
|
@ -535,6 +541,49 @@ If not specified, a default SQLite store will be used.""",
|
|||
return Path(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_server_stores(self) -> "StackRunConfig":
|
||||
backend_map = self.storage.backends
|
||||
stores = self.storage.stores
|
||||
kv_backends = {
|
||||
name
|
||||
for name, cfg in backend_map.items()
|
||||
if cfg.type
|
||||
in {
|
||||
StorageBackendType.KV_REDIS,
|
||||
StorageBackendType.KV_SQLITE,
|
||||
StorageBackendType.KV_POSTGRES,
|
||||
StorageBackendType.KV_MONGODB,
|
||||
}
|
||||
}
|
||||
sql_backends = {
|
||||
name
|
||||
for name, cfg in backend_map.items()
|
||||
if cfg.type in {StorageBackendType.SQL_SQLITE, StorageBackendType.SQL_POSTGRES}
|
||||
}
|
||||
|
||||
def _ensure_backend(reference, expected_set, store_name: str) -> None:
|
||||
if reference is None:
|
||||
return
|
||||
backend_name = reference.backend
|
||||
if backend_name not in backend_map:
|
||||
raise ValueError(
|
||||
f"{store_name} references unknown backend '{backend_name}'. "
|
||||
f"Available backends: {sorted(backend_map)}"
|
||||
)
|
||||
if backend_name not in expected_set:
|
||||
raise ValueError(
|
||||
f"{store_name} references backend '{backend_name}' of type "
|
||||
f"'{backend_map[backend_name].type.value}', but a backend of type "
|
||||
f"{'kv_*' if expected_set is kv_backends else 'sql_*'} is required."
|
||||
)
|
||||
|
||||
_ensure_backend(stores.metadata, kv_backends, "storage.stores.metadata")
|
||||
_ensure_backend(stores.inference, sql_backends, "storage.stores.inference")
|
||||
_ensure_backend(stores.conversations, sql_backends, "storage.stores.conversations")
|
||||
_ensure_backend(stores.responses, sql_backends, "storage.stores.responses")
|
||||
return self
|
||||
|
||||
|
||||
class BuildConfig(BaseModel):
|
||||
version: int = LLAMA_STACK_BUILD_CONFIG_VERSION
|
||||
|
|
|
|||
|
|
@ -63,6 +63,10 @@ def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
|
|||
routing_table_api=Api.tool_groups,
|
||||
router_api=Api.tool_runtime,
|
||||
),
|
||||
AutoRoutedApiInfo(
|
||||
routing_table_api=Api.vector_stores,
|
||||
router_api=Api.vector_io,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ from llama_stack.core.stack import (
|
|||
from llama_stack.core.utils.config import redact_sensitive_fields
|
||||
from llama_stack.core.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.core.utils.exec import in_notebook
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.log import get_logger, setup_logging
|
||||
from llama_stack.providers.utils.telemetry.tracing import CURRENT_TRACE_CONTEXT, end_trace, setup_logger, start_trace
|
||||
from llama_stack.strong_typing.inspection import is_unwrapped_body_param
|
||||
|
||||
|
|
@ -200,6 +200,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
skip_logger_removal: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
# Initialize logging from environment variables first
|
||||
setup_logging()
|
||||
|
||||
# when using the library client, we should not log to console since many
|
||||
# of our logs are intended for server-side usage
|
||||
if sinks_from_env := os.environ.get("TELEMETRY_SINKS", None):
|
||||
|
|
@ -278,7 +281,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
else:
|
||||
prefix = "!" if in_notebook() else ""
|
||||
cprint(
|
||||
f"Please run:\n\n{prefix}llama stack build --distro {self.config_path_or_distro_name} --image-type venv\n\n",
|
||||
f"Please run:\n\n{prefix}llama stack list-deps {self.config_path_or_distro_name} | xargs -L1 uv pip install\n\n",
|
||||
"yellow",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -11,9 +11,8 @@ from pydantic import BaseModel
|
|||
|
||||
from llama_stack.apis.prompts import ListPromptsResponse, Prompt, Prompts
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
|
||||
class PromptServiceConfig(BaseModel):
|
||||
|
|
@ -41,10 +40,12 @@ class PromptServiceImpl(Prompts):
|
|||
self.kvstore: KVStore
|
||||
|
||||
async def initialize(self) -> None:
|
||||
kvstore_config = SqliteKVStoreConfig(
|
||||
db_path=(DISTRIBS_BASE_DIR / self.config.run_config.image_name / "prompts.db").as_posix()
|
||||
)
|
||||
self.kvstore = await kvstore_impl(kvstore_config)
|
||||
# Use metadata store backend with prompts-specific namespace
|
||||
metadata_ref = self.config.run_config.storage.stores.metadata
|
||||
if not metadata_ref:
|
||||
raise ValueError("storage.stores.metadata must be configured in run config")
|
||||
prompts_ref = KVStoreReference(namespace="prompts", backend=metadata_ref.backend)
|
||||
self.kvstore = await kvstore_impl(prompts_ref)
|
||||
|
||||
def _get_default_key(self, prompt_id: str) -> str:
|
||||
"""Get the KVStore key that stores the default version number."""
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from llama_stack.apis.shields import Shields
|
|||
from llama_stack.apis.telemetry import Telemetry
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.core.client import get_client_impl
|
||||
from llama_stack.core.datatypes import (
|
||||
|
|
@ -81,6 +82,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
|
|||
Api.inspect: Inspect,
|
||||
Api.batches: Batches,
|
||||
Api.vector_io: VectorIO,
|
||||
Api.vector_stores: VectorStore,
|
||||
Api.models: Models,
|
||||
Api.safety: Safety,
|
||||
Api.shields: Shields,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,10 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import AccessRule, RoutedProtocol
|
||||
from llama_stack.core.datatypes import (
|
||||
AccessRule,
|
||||
RoutedProtocol,
|
||||
)
|
||||
from llama_stack.core.stack import StackRunConfig
|
||||
from llama_stack.core.store import DistributionRegistry
|
||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||
|
|
@ -26,6 +29,7 @@ async def get_routing_table_impl(
|
|||
from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
|
||||
from ..routing_tables.shields import ShieldsRoutingTable
|
||||
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
|
||||
from ..routing_tables.vector_stores import VectorStoresRoutingTable
|
||||
|
||||
api_to_tables = {
|
||||
"models": ModelsRoutingTable,
|
||||
|
|
@ -34,6 +38,7 @@ async def get_routing_table_impl(
|
|||
"scoring_functions": ScoringFunctionsRoutingTable,
|
||||
"benchmarks": BenchmarksRoutingTable,
|
||||
"tool_groups": ToolGroupsRoutingTable,
|
||||
"vector_stores": VectorStoresRoutingTable,
|
||||
}
|
||||
|
||||
if api.value not in api_to_tables:
|
||||
|
|
@ -76,14 +81,21 @@ async def get_auto_router_impl(
|
|||
api_to_dep_impl[dep_name] = deps[dep_api]
|
||||
|
||||
# TODO: move pass configs to routers instead
|
||||
if api == Api.inference and run_config.inference_store:
|
||||
if api == Api.inference:
|
||||
inference_ref = run_config.storage.stores.inference
|
||||
if not inference_ref:
|
||||
raise ValueError("storage.stores.inference must be configured in run config")
|
||||
|
||||
inference_store = InferenceStore(
|
||||
config=run_config.inference_store,
|
||||
reference=inference_ref,
|
||||
policy=policy,
|
||||
)
|
||||
await inference_store.initialize()
|
||||
api_to_dep_impl["store"] = inference_store
|
||||
|
||||
elif api == Api.vector_io:
|
||||
api_to_dep_impl["vector_stores_config"] = run_config.vector_stores
|
||||
|
||||
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -44,9 +44,14 @@ from llama_stack.apis.inference import (
|
|||
OpenAIEmbeddingsResponse,
|
||||
OpenAIMessageParam,
|
||||
Order,
|
||||
RerankResponse,
|
||||
StopReason,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -182,6 +187,23 @@ class InferenceRouter(Inference):
|
|||
raise ModelTypeError(model_id, model.model_type, expected_model_type)
|
||||
return model
|
||||
|
||||
async def rerank(
|
||||
self,
|
||||
model: str,
|
||||
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
|
||||
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
|
||||
max_num_results: int | None = None,
|
||||
) -> RerankResponse:
|
||||
logger.debug(f"InferenceRouter.rerank: {model}")
|
||||
model_obj = await self._get_model(model, ModelType.rerank)
|
||||
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
return await provider.rerank(
|
||||
model=model_obj.identifier,
|
||||
query=query,
|
||||
items=items,
|
||||
max_num_results=max_num_results,
|
||||
)
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)],
|
||||
|
|
|
|||
|
|
@ -37,24 +37,24 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
async def query(
|
||||
self,
|
||||
content: InterleavedContent,
|
||||
vector_db_ids: list[str],
|
||||
vector_store_ids: list[str],
|
||||
query_config: RAGQueryConfig | None = None,
|
||||
) -> RAGQueryResult:
|
||||
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
|
||||
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_store_ids}")
|
||||
provider = await self.routing_table.get_provider_impl("knowledge_search")
|
||||
return await provider.query(content, vector_db_ids, query_config)
|
||||
return await provider.query(content, vector_store_ids, query_config)
|
||||
|
||||
async def insert(
|
||||
self,
|
||||
documents: list[RAGDocument],
|
||||
vector_db_id: str,
|
||||
vector_store_id: str,
|
||||
chunk_size_in_tokens: int = 512,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
|
||||
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_store_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl("insert_into_memory")
|
||||
return await provider.insert(documents, vector_db_id, chunk_size_in_tokens)
|
||||
return await provider.insert(documents, vector_store_id, chunk_size_in_tokens)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ from llama_stack.apis.vector_io import (
|
|||
VectorStoreObject,
|
||||
VectorStoreSearchResponsePage,
|
||||
)
|
||||
from llama_stack.core.datatypes import VectorStoresConfig
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
|
||||
|
||||
|
|
@ -43,9 +44,11 @@ class VectorIORouter(VectorIO):
|
|||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
vector_stores_config: VectorStoresConfig | None = None,
|
||||
) -> None:
|
||||
logger.debug("Initializing VectorIORouter")
|
||||
self.routing_table = routing_table
|
||||
self.vector_stores_config = vector_stores_config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug("VectorIORouter.initialize")
|
||||
|
|
@ -68,25 +71,6 @@ class VectorIORouter(VectorIO):
|
|||
|
||||
raise ValueError(f"Embedding model '{embedding_model_id}' not found or not an embedding model")
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
embedding_model: str,
|
||||
embedding_dimension: int | None = 384,
|
||||
provider_id: str | None = None,
|
||||
vector_db_name: str | None = None,
|
||||
provider_vector_db_id: str | None = None,
|
||||
) -> None:
|
||||
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
|
||||
await self.routing_table.register_vector_db(
|
||||
vector_db_id,
|
||||
embedding_model,
|
||||
embedding_dimension,
|
||||
provider_id,
|
||||
vector_db_name,
|
||||
provider_vector_db_id,
|
||||
)
|
||||
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
|
|
@ -122,6 +106,17 @@ class VectorIORouter(VectorIO):
|
|||
embedding_dimension = extra.get("embedding_dimension")
|
||||
provider_id = extra.get("provider_id")
|
||||
|
||||
# Use default embedding model if not specified
|
||||
if (
|
||||
embedding_model is None
|
||||
and self.vector_stores_config
|
||||
and self.vector_stores_config.default_embedding_model is not None
|
||||
):
|
||||
# Construct the full model ID with provider prefix
|
||||
embedding_provider_id = self.vector_stores_config.default_embedding_model.provider_id
|
||||
model_id = self.vector_stores_config.default_embedding_model.model_id
|
||||
embedding_model = f"{embedding_provider_id}/{model_id}"
|
||||
|
||||
if embedding_model is not None and embedding_dimension is None:
|
||||
embedding_dimension = await self._get_embedding_model_dimension(embedding_model)
|
||||
|
||||
|
|
@ -132,28 +127,41 @@ class VectorIORouter(VectorIO):
|
|||
raise ValueError("No vector_io providers available")
|
||||
if num_providers > 1:
|
||||
available_providers = list(self.routing_table.impls_by_provider_id.keys())
|
||||
raise ValueError(
|
||||
f"Multiple vector_io providers available. Please specify provider_id in extra_body. "
|
||||
f"Available providers: {available_providers}"
|
||||
)
|
||||
provider_id = list(self.routing_table.impls_by_provider_id.keys())[0]
|
||||
# Use default configured provider
|
||||
if self.vector_stores_config and self.vector_stores_config.default_provider_id:
|
||||
default_provider = self.vector_stores_config.default_provider_id
|
||||
if default_provider in available_providers:
|
||||
provider_id = default_provider
|
||||
logger.debug(f"Using configured default vector store provider: {provider_id}")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Configured default vector store provider '{default_provider}' not found. "
|
||||
f"Available providers: {available_providers}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Multiple vector_io providers available. Please specify provider_id in extra_body. "
|
||||
f"Available providers: {available_providers}"
|
||||
)
|
||||
else:
|
||||
provider_id = list(self.routing_table.impls_by_provider_id.keys())[0]
|
||||
|
||||
vector_db_id = f"vs_{uuid.uuid4()}"
|
||||
registered_vector_db = await self.routing_table.register_vector_db(
|
||||
vector_db_id=vector_db_id,
|
||||
vector_store_id = f"vs_{uuid.uuid4()}"
|
||||
registered_vector_store = await self.routing_table.register_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
embedding_model=embedding_model,
|
||||
embedding_dimension=embedding_dimension,
|
||||
provider_id=provider_id,
|
||||
provider_vector_db_id=vector_db_id,
|
||||
vector_db_name=params.name,
|
||||
provider_vector_store_id=vector_store_id,
|
||||
vector_store_name=params.name,
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl(registered_vector_db.identifier)
|
||||
provider = await self.routing_table.get_provider_impl(registered_vector_store.identifier)
|
||||
|
||||
# Update model_extra with registered values so provider uses the already-registered vector_db
|
||||
# Update model_extra with registered values so provider uses the already-registered vector_store
|
||||
if params.model_extra is None:
|
||||
params.model_extra = {}
|
||||
params.model_extra["provider_vector_db_id"] = registered_vector_db.provider_resource_id
|
||||
params.model_extra["provider_id"] = registered_vector_db.provider_id
|
||||
params.model_extra["provider_vector_store_id"] = registered_vector_store.provider_resource_id
|
||||
params.model_extra["provider_id"] = registered_vector_store.provider_id
|
||||
if embedding_model is not None:
|
||||
params.model_extra["embedding_model"] = embedding_model
|
||||
if embedding_dimension is not None:
|
||||
|
|
@ -171,15 +179,15 @@ class VectorIORouter(VectorIO):
|
|||
logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
|
||||
# Route to default provider for now - could aggregate from all providers in the future
|
||||
# call retrieve on each vector dbs to get list of vector stores
|
||||
vector_dbs = await self.routing_table.get_all_with_type("vector_db")
|
||||
vector_stores = await self.routing_table.get_all_with_type("vector_store")
|
||||
all_stores = []
|
||||
for vector_db in vector_dbs:
|
||||
for vector_store in vector_stores:
|
||||
try:
|
||||
provider = await self.routing_table.get_provider_impl(vector_db.identifier)
|
||||
vector_store = await provider.openai_retrieve_vector_store(vector_db.identifier)
|
||||
provider = await self.routing_table.get_provider_impl(vector_store.identifier)
|
||||
vector_store = await provider.openai_retrieve_vector_store(vector_store.identifier)
|
||||
all_stores.append(vector_store)
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}")
|
||||
logger.error(f"Error retrieving vector store {vector_store.identifier}: {e}")
|
||||
continue
|
||||
|
||||
# Sort by created_at
|
||||
|
|
@ -243,8 +251,7 @@ class VectorIORouter(VectorIO):
|
|||
vector_store_id: str,
|
||||
) -> VectorStoreDeleteResponse:
|
||||
logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
|
||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_delete_vector_store(vector_store_id)
|
||||
return await self.routing_table.openai_delete_vector_store(vector_store_id)
|
||||
|
||||
async def openai_search_vector_store(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
|||
elif api == Api.safety:
|
||||
return await p.register_shield(obj)
|
||||
elif api == Api.vector_io:
|
||||
return await p.register_vector_db(obj)
|
||||
return await p.register_vector_store(obj)
|
||||
elif api == Api.datasetio:
|
||||
return await p.register_dataset(obj)
|
||||
elif api == Api.scoring:
|
||||
|
|
@ -57,7 +57,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
|||
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
||||
api = get_impl_api(p)
|
||||
if api == Api.vector_io:
|
||||
return await p.unregister_vector_db(obj.identifier)
|
||||
return await p.unregister_vector_store(obj.identifier)
|
||||
elif api == Api.inference:
|
||||
return await p.unregister_model(obj.identifier)
|
||||
elif api == Api.safety:
|
||||
|
|
@ -108,7 +108,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
elif api == Api.safety:
|
||||
p.shield_store = self
|
||||
elif api == Api.vector_io:
|
||||
p.vector_db_store = self
|
||||
p.vector_store_store = self
|
||||
elif api == Api.datasetio:
|
||||
p.dataset_store = self
|
||||
elif api == Api.scoring:
|
||||
|
|
@ -134,12 +134,15 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
from .scoring_functions import ScoringFunctionsRoutingTable
|
||||
from .shields import ShieldsRoutingTable
|
||||
from .toolgroups import ToolGroupsRoutingTable
|
||||
from .vector_stores import VectorStoresRoutingTable
|
||||
|
||||
def apiname_object():
|
||||
if isinstance(self, ModelsRoutingTable):
|
||||
return ("Inference", "model")
|
||||
elif isinstance(self, ShieldsRoutingTable):
|
||||
return ("Safety", "shield")
|
||||
elif isinstance(self, VectorStoresRoutingTable):
|
||||
return ("VectorIO", "vector_store")
|
||||
elif isinstance(self, DatasetsRoutingTable):
|
||||
return ("DatasetIO", "dataset")
|
||||
elif isinstance(self, ScoringFunctionsRoutingTable):
|
||||
|
|
|
|||
292
llama_stack/core/routing_tables/vector_stores.py
Normal file
292
llama_stack/core/routing_tables/vector_stores.py
Normal file
|
|
@ -0,0 +1,292 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
|
||||
# Removed VectorStores import to avoid exposing public API
|
||||
from llama_stack.apis.vector_io.vector_io import (
|
||||
SearchRankingOptions,
|
||||
VectorStoreChunkingStrategy,
|
||||
VectorStoreDeleteResponse,
|
||||
VectorStoreFileContentsResponse,
|
||||
VectorStoreFileDeleteResponse,
|
||||
VectorStoreFileObject,
|
||||
VectorStoreFileStatus,
|
||||
VectorStoreObject,
|
||||
VectorStoreSearchResponsePage,
|
||||
)
|
||||
from llama_stack.core.datatypes import (
|
||||
VectorStoreWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .common import CommonRoutingTableImpl, lookup_model
|
||||
|
||||
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||
|
||||
|
||||
class VectorStoresRoutingTable(CommonRoutingTableImpl):
|
||||
"""Internal routing table for vector_store operations.
|
||||
|
||||
Does not inherit from VectorStores to avoid exposing public API endpoints.
|
||||
Only provides internal routing functionality for VectorIORouter.
|
||||
"""
|
||||
|
||||
# Internal methods only - no public API exposure
|
||||
|
||||
async def register_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
embedding_model: str,
|
||||
embedding_dimension: int | None = 384,
|
||||
provider_id: str | None = None,
|
||||
provider_vector_store_id: str | None = None,
|
||||
vector_store_name: str | None = None,
|
||||
) -> Any:
|
||||
if provider_id is None:
|
||||
if len(self.impls_by_provider_id) > 0:
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
if len(self.impls_by_provider_id) > 1:
|
||||
logger.warning(
|
||||
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
|
||||
)
|
||||
else:
|
||||
raise ValueError("No provider available. Please configure a vector_io provider.")
|
||||
model = await lookup_model(self, embedding_model)
|
||||
if model is None:
|
||||
raise ModelNotFoundError(embedding_model)
|
||||
if model.model_type != ModelType.embedding:
|
||||
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
|
||||
|
||||
vector_store = VectorStoreWithOwner(
|
||||
identifier=vector_store_id,
|
||||
type=ResourceType.vector_store.value,
|
||||
provider_id=provider_id,
|
||||
provider_resource_id=provider_vector_store_id,
|
||||
embedding_model=embedding_model,
|
||||
embedding_dimension=embedding_dimension,
|
||||
vector_store_name=vector_store_name,
|
||||
)
|
||||
await self.register_object(vector_store)
|
||||
return vector_store
|
||||
|
||||
async def openai_retrieve_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreObject:
|
||||
await self.assert_action_allowed("read", "vector_store", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store(vector_store_id)
|
||||
|
||||
async def openai_update_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
name: str | None = None,
|
||||
expires_after: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> VectorStoreObject:
|
||||
await self.assert_action_allowed("update", "vector_store", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_update_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
name=name,
|
||||
expires_after=expires_after,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
async def openai_delete_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreDeleteResponse:
|
||||
await self.assert_action_allowed("delete", "vector_store", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
result = await provider.openai_delete_vector_store(vector_store_id)
|
||||
await self.unregister_vector_store(vector_store_id)
|
||||
return result
|
||||
|
||||
async def unregister_vector_store(self, vector_store_id: str) -> None:
|
||||
"""Remove the vector store from the routing table registry."""
|
||||
try:
|
||||
vector_store_obj = await self.get_object_by_identifier("vector_store", vector_store_id)
|
||||
if vector_store_obj:
|
||||
await self.unregister_object(vector_store_obj)
|
||||
except Exception as e:
|
||||
# Log the error but don't fail the operation
|
||||
logger.warning(f"Failed to unregister vector store {vector_store_id} from routing table: {e}")
|
||||
|
||||
async def openai_search_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
query: str | list[str],
|
||||
filters: dict[str, Any] | None = None,
|
||||
max_num_results: int | None = 10,
|
||||
ranking_options: SearchRankingOptions | None = None,
|
||||
rewrite_query: bool | None = False,
|
||||
search_mode: str | None = "vector",
|
||||
) -> VectorStoreSearchResponsePage:
|
||||
await self.assert_action_allowed("read", "vector_store", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_search_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
query=query,
|
||||
filters=filters,
|
||||
max_num_results=max_num_results,
|
||||
ranking_options=ranking_options,
|
||||
rewrite_query=rewrite_query,
|
||||
search_mode=search_mode,
|
||||
)
|
||||
|
||||
async def openai_attach_file_to_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
attributes: dict[str, Any] | None = None,
|
||||
chunking_strategy: VectorStoreChunkingStrategy | None = None,
|
||||
) -> VectorStoreFileObject:
|
||||
await self.assert_action_allowed("update", "vector_store", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_attach_file_to_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
attributes=attributes,
|
||||
chunking_strategy=chunking_strategy,
|
||||
)
|
||||
|
||||
async def openai_list_files_in_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
limit: int | None = 20,
|
||||
order: str | None = "desc",
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
filter: VectorStoreFileStatus | None = None,
|
||||
) -> list[VectorStoreFileObject]:
|
||||
await self.assert_action_allowed("read", "vector_store", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_list_files_in_vector_store(
|
||||
vector_store_id=vector_store_id,
|
||||
limit=limit,
|
||||
order=order,
|
||||
after=after,
|
||||
before=before,
|
||||
filter=filter,
|
||||
)
|
||||
|
||||
async def openai_retrieve_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileObject:
|
||||
await self.assert_action_allowed("read", "vector_store", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
||||
async def openai_retrieve_vector_store_file_contents(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileContentsResponse:
|
||||
await self.assert_action_allowed("read", "vector_store", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store_file_contents(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
||||
async def openai_update_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
attributes: dict[str, Any],
|
||||
) -> VectorStoreFileObject:
|
||||
await self.assert_action_allowed("update", "vector_store", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_update_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
attributes=attributes,
|
||||
)
|
||||
|
||||
async def openai_delete_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileDeleteResponse:
|
||||
await self.assert_action_allowed("delete", "vector_store", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_delete_vector_store_file(
|
||||
vector_store_id=vector_store_id,
|
||||
file_id=file_id,
|
||||
)
|
||||
|
||||
async def openai_create_vector_store_file_batch(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_ids: list[str],
|
||||
attributes: dict[str, Any] | None = None,
|
||||
chunking_strategy: Any | None = None,
|
||||
):
|
||||
await self.assert_action_allowed("update", "vector_store", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_create_vector_store_file_batch(
|
||||
vector_store_id=vector_store_id,
|
||||
file_ids=file_ids,
|
||||
attributes=attributes,
|
||||
chunking_strategy=chunking_strategy,
|
||||
)
|
||||
|
||||
async def openai_retrieve_vector_store_file_batch(
|
||||
self,
|
||||
batch_id: str,
|
||||
vector_store_id: str,
|
||||
):
|
||||
await self.assert_action_allowed("read", "vector_store", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_retrieve_vector_store_file_batch(
|
||||
batch_id=batch_id,
|
||||
vector_store_id=vector_store_id,
|
||||
)
|
||||
|
||||
async def openai_list_files_in_vector_store_file_batch(
|
||||
self,
|
||||
batch_id: str,
|
||||
vector_store_id: str,
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
filter: str | None = None,
|
||||
limit: int | None = 20,
|
||||
order: str | None = "desc",
|
||||
):
|
||||
await self.assert_action_allowed("read", "vector_store", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_list_files_in_vector_store_file_batch(
|
||||
batch_id=batch_id,
|
||||
vector_store_id=vector_store_id,
|
||||
after=after,
|
||||
before=before,
|
||||
filter=filter,
|
||||
limit=limit,
|
||||
order=order,
|
||||
)
|
||||
|
||||
async def openai_cancel_vector_store_file_batch(
|
||||
self,
|
||||
batch_id: str,
|
||||
vector_store_id: str,
|
||||
):
|
||||
await self.assert_action_allowed("update", "vector_store", vector_store_id)
|
||||
provider = await self.get_provider_impl(vector_store_id)
|
||||
return await provider.openai_cancel_vector_store_file_batch(
|
||||
batch_id=batch_id,
|
||||
vector_store_id=vector_store_id,
|
||||
)
|
||||
|
|
@ -72,13 +72,30 @@ class AuthProvider(ABC):
|
|||
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
|
||||
attributes: dict[str, list[str]] = {}
|
||||
for claim_key, attribute_key in mapping.items():
|
||||
if claim_key not in claims:
|
||||
# First try dot notation for nested traversal (e.g., "resource_access.llamastack.roles")
|
||||
# Then fall back to literal key with dots (e.g., "my.dotted.key")
|
||||
claim: object = claims
|
||||
keys = claim_key.split(".")
|
||||
for key in keys:
|
||||
if isinstance(claim, dict) and key in claim:
|
||||
claim = claim[key]
|
||||
else:
|
||||
claim = None
|
||||
break
|
||||
|
||||
if claim is None and claim_key in claims:
|
||||
# Fall back to checking if claim_key exists as a literal key
|
||||
claim = claims[claim_key]
|
||||
|
||||
if claim is None:
|
||||
continue
|
||||
claim = claims[claim_key]
|
||||
|
||||
if isinstance(claim, list):
|
||||
values = claim
|
||||
else:
|
||||
elif isinstance(claim, str):
|
||||
values = claim.split()
|
||||
else:
|
||||
continue
|
||||
|
||||
if attribute_key in attributes:
|
||||
attributes[attribute_key].extend(values)
|
||||
|
|
|
|||
|
|
@ -10,10 +10,10 @@ from datetime import UTC, datetime, timedelta
|
|||
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.kvstore import _KVSTORE_BACKENDS, kvstore_impl
|
||||
|
||||
logger = get_logger(name=__name__, category="core::server")
|
||||
|
||||
|
|
@ -33,7 +33,7 @@ class QuotaMiddleware:
|
|||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
kv_config: KVStoreConfig,
|
||||
kv_config: KVStoreReference,
|
||||
anonymous_max_requests: int,
|
||||
authenticated_max_requests: int,
|
||||
window_seconds: int = 86400,
|
||||
|
|
@ -45,15 +45,15 @@ class QuotaMiddleware:
|
|||
self.authenticated_max_requests = authenticated_max_requests
|
||||
self.window_seconds = window_seconds
|
||||
|
||||
if isinstance(self.kv_config, SqliteKVStoreConfig):
|
||||
logger.warning(
|
||||
"QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. "
|
||||
f"window_seconds={self.window_seconds}"
|
||||
)
|
||||
|
||||
async def _get_kv(self) -> KVStore:
|
||||
if self.kv is None:
|
||||
self.kv = await kvstore_impl(self.kv_config)
|
||||
backend_config = _KVSTORE_BACKENDS.get(self.kv_config.backend)
|
||||
if backend_config and backend_config.type == StorageBackendType.KV_SQLITE:
|
||||
logger.warning(
|
||||
"QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. "
|
||||
f"window_seconds={self.window_seconds}"
|
||||
)
|
||||
return self.kv
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ from llama_stack.core.stack import (
|
|||
from llama_stack.core.utils.config import redact_sensitive_fields
|
||||
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
||||
from llama_stack.core.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.log import get_logger, setup_logging
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
|
||||
|
|
@ -374,6 +374,9 @@ def create_app() -> StackApp:
|
|||
Returns:
|
||||
Configured StackApp instance.
|
||||
"""
|
||||
# Initialize logging from environment variables first
|
||||
setup_logging()
|
||||
|
||||
config_file = os.getenv("LLAMA_STACK_CONFIG")
|
||||
if config_file is None:
|
||||
raise ValueError("LLAMA_STACK_CONFIG environment variable is required")
|
||||
|
|
|
|||
|
|
@ -35,13 +35,23 @@ from llama_stack.apis.telemetry import Telemetry
|
|||
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.datatypes import Provider, StackRunConfig, VectorStoresConfig
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
|
||||
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
|
||||
from llama_stack.core.providers import ProviderImpl, ProviderImplConfig
|
||||
from llama_stack.core.resolver import ProviderRegistry, resolve_impls
|
||||
from llama_stack.core.routing_tables.common import CommonRoutingTableImpl
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
InferenceStoreReference,
|
||||
KVStoreReference,
|
||||
ServerStoresConfig,
|
||||
SqliteKVStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreReference,
|
||||
StorageBackendConfig,
|
||||
StorageConfig,
|
||||
)
|
||||
from llama_stack.core.store.registry import create_dist_registry
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -98,33 +108,9 @@ REGISTRY_REFRESH_TASK = None
|
|||
TEST_RECORDING_CONTEXT = None
|
||||
|
||||
|
||||
async def validate_default_embedding_model(impls: dict[Api, Any]):
|
||||
"""Validate that at most one embedding model is marked as default."""
|
||||
if Api.models not in impls:
|
||||
return
|
||||
|
||||
models_impl = impls[Api.models]
|
||||
response = await models_impl.list_models()
|
||||
models_list = response.data if hasattr(response, "data") else response
|
||||
|
||||
default_embedding_models = []
|
||||
for model in models_list:
|
||||
if model.model_type == "embedding" and model.metadata.get("default_configured") is True:
|
||||
default_embedding_models.append(model.identifier)
|
||||
|
||||
if len(default_embedding_models) > 1:
|
||||
raise ValueError(
|
||||
f"Multiple embedding models marked as default_configured=True: {default_embedding_models}. "
|
||||
"Only one embedding model can be marked as default."
|
||||
)
|
||||
|
||||
if default_embedding_models:
|
||||
logger.info(f"Default embedding model configured: {default_embedding_models[0]}")
|
||||
|
||||
|
||||
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
||||
for rsrc, api, register_method, list_method in RESOURCES:
|
||||
objects = getattr(run_config, rsrc)
|
||||
objects = getattr(run_config.registered_resources, rsrc)
|
||||
if api not in impls:
|
||||
continue
|
||||
|
||||
|
|
@ -152,7 +138,41 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
|||
f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
|
||||
)
|
||||
|
||||
await validate_default_embedding_model(impls)
|
||||
|
||||
async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig | None, impls: dict[Api, Any]):
|
||||
"""Validate vector stores configuration."""
|
||||
if vector_stores_config is None:
|
||||
return
|
||||
|
||||
default_embedding_model = vector_stores_config.default_embedding_model
|
||||
if default_embedding_model is None:
|
||||
return
|
||||
|
||||
provider_id = default_embedding_model.provider_id
|
||||
model_id = default_embedding_model.model_id
|
||||
default_model_id = f"{provider_id}/{model_id}"
|
||||
|
||||
if Api.models not in impls:
|
||||
raise ValueError(f"Models API is not available but vector_stores config requires model '{default_model_id}'")
|
||||
|
||||
models_impl = impls[Api.models]
|
||||
response = await models_impl.list_models()
|
||||
models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}
|
||||
|
||||
default_model = models_list.get(default_model_id)
|
||||
if default_model is None:
|
||||
raise ValueError(f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}")
|
||||
|
||||
embedding_dimension = default_model.metadata.get("embedding_dimension")
|
||||
if embedding_dimension is None:
|
||||
raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata")
|
||||
|
||||
try:
|
||||
int(embedding_dimension)
|
||||
except ValueError as err:
|
||||
raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err
|
||||
|
||||
logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
|
||||
|
||||
|
||||
class EnvVarError(Exception):
|
||||
|
|
@ -329,6 +349,25 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
|
|||
impls[Api.conversations] = conversations_impl
|
||||
|
||||
|
||||
def _initialize_storage(run_config: StackRunConfig):
|
||||
kv_backends: dict[str, StorageBackendConfig] = {}
|
||||
sql_backends: dict[str, StorageBackendConfig] = {}
|
||||
for backend_name, backend_config in run_config.storage.backends.items():
|
||||
type = backend_config.type.value
|
||||
if type.startswith("kv_"):
|
||||
kv_backends[backend_name] = backend_config
|
||||
elif type.startswith("sql_"):
|
||||
sql_backends[backend_name] = backend_config
|
||||
else:
|
||||
raise ValueError(f"Unknown storage backend type: {type}")
|
||||
|
||||
from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
register_kvstore_backends(kv_backends)
|
||||
register_sqlstore_backends(sql_backends)
|
||||
|
||||
|
||||
class Stack:
|
||||
def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
|
||||
self.run_config = run_config
|
||||
|
|
@ -347,7 +386,11 @@ class Stack:
|
|||
TEST_RECORDING_CONTEXT.__enter__()
|
||||
logger.info(f"API recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
|
||||
|
||||
dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name)
|
||||
_initialize_storage(self.run_config)
|
||||
stores = self.run_config.storage.stores
|
||||
if not stores.metadata:
|
||||
raise ValueError("storage.stores.metadata must be configured with a kv_* backend")
|
||||
dist_registry, _ = await create_dist_registry(stores.metadata, self.run_config.image_name)
|
||||
policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
|
||||
|
||||
internal_impls = {}
|
||||
|
|
@ -367,8 +410,8 @@ class Stack:
|
|||
await impls[Api.conversations].initialize()
|
||||
|
||||
await register_resources(self.run_config, impls)
|
||||
|
||||
await refresh_registry_once(impls)
|
||||
await validate_vector_stores_config(self.run_config.vector_stores, impls)
|
||||
self.impls = impls
|
||||
|
||||
def create_registry_refresh_task(self):
|
||||
|
|
@ -488,5 +531,16 @@ def run_config_from_adhoc_config_spec(
|
|||
image_name="distro-test",
|
||||
apis=list(provider_configs_by_api.keys()),
|
||||
providers=provider_configs_by_api,
|
||||
storage=StorageConfig(
|
||||
backends={
|
||||
"kv_default": SqliteKVStoreConfig(db_path=f"{distro_dir}/kvstore.db"),
|
||||
"sql_default": SqliteSqlStoreConfig(db_path=f"{distro_dir}/sql_store.db"),
|
||||
},
|
||||
stores=ServerStoresConfig(
|
||||
metadata=KVStoreReference(backend="kv_default", namespace="registry"),
|
||||
inference=InferenceStoreReference(backend="sql_default", table_name="inference_store"),
|
||||
conversations=SqlStoreReference(backend="sql_default", table_name="openai_conversations"),
|
||||
),
|
||||
),
|
||||
)
|
||||
return config
|
||||
|
|
|
|||
5
llama_stack/core/storage/__init__.py
Normal file
5
llama_stack/core/storage/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
283
llama_stack/core/storage/datatypes.py
Normal file
283
llama_stack/core/storage/datatypes.py
Normal file
|
|
@ -0,0 +1,283 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import re
|
||||
from abc import abstractmethod
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class StorageBackendType(StrEnum):
|
||||
KV_REDIS = "kv_redis"
|
||||
KV_SQLITE = "kv_sqlite"
|
||||
KV_POSTGRES = "kv_postgres"
|
||||
KV_MONGODB = "kv_mongodb"
|
||||
SQL_SQLITE = "sql_sqlite"
|
||||
SQL_POSTGRES = "sql_postgres"
|
||||
|
||||
|
||||
class CommonConfig(BaseModel):
|
||||
namespace: str | None = Field(
|
||||
default=None,
|
||||
description="All keys will be prefixed with this namespace",
|
||||
)
|
||||
|
||||
|
||||
class RedisKVStoreConfig(CommonConfig):
|
||||
type: Literal[StorageBackendType.KV_REDIS] = StorageBackendType.KV_REDIS
|
||||
host: str = "localhost"
|
||||
port: int = 6379
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
return f"redis://{self.host}:{self.port}"
|
||||
|
||||
@classmethod
|
||||
def pip_packages(cls) -> list[str]:
|
||||
return ["redis"]
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls):
|
||||
return {
|
||||
"type": StorageBackendType.KV_REDIS.value,
|
||||
"host": "${env.REDIS_HOST:=localhost}",
|
||||
"port": "${env.REDIS_PORT:=6379}",
|
||||
}
|
||||
|
||||
|
||||
class SqliteKVStoreConfig(CommonConfig):
|
||||
type: Literal[StorageBackendType.KV_SQLITE] = StorageBackendType.KV_SQLITE
|
||||
db_path: str = Field(
|
||||
description="File path for the sqlite database",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def pip_packages(cls) -> list[str]:
|
||||
return ["aiosqlite"]
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, db_name: str = "kvstore.db"):
|
||||
return {
|
||||
"type": StorageBackendType.KV_SQLITE.value,
|
||||
"db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
|
||||
}
|
||||
|
||||
|
||||
class PostgresKVStoreConfig(CommonConfig):
|
||||
type: Literal[StorageBackendType.KV_POSTGRES] = StorageBackendType.KV_POSTGRES
|
||||
host: str = "localhost"
|
||||
port: int | str = 5432
|
||||
db: str = "llamastack"
|
||||
user: str
|
||||
password: str | None = None
|
||||
ssl_mode: str | None = None
|
||||
ca_cert_path: str | None = None
|
||||
table_name: str = "llamastack_kvstore"
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, table_name: str = "llamastack_kvstore", **kwargs):
|
||||
return {
|
||||
"type": StorageBackendType.KV_POSTGRES.value,
|
||||
"host": "${env.POSTGRES_HOST:=localhost}",
|
||||
"port": "${env.POSTGRES_PORT:=5432}",
|
||||
"db": "${env.POSTGRES_DB:=llamastack}",
|
||||
"user": "${env.POSTGRES_USER:=llamastack}",
|
||||
"password": "${env.POSTGRES_PASSWORD:=llamastack}",
|
||||
"table_name": "${env.POSTGRES_TABLE_NAME:=" + table_name + "}",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@field_validator("table_name")
|
||||
def validate_table_name(cls, v: str) -> str:
|
||||
# PostgreSQL identifiers rules:
|
||||
# - Must start with a letter or underscore
|
||||
# - Can contain letters, numbers, and underscores
|
||||
# - Maximum length is 63 bytes
|
||||
pattern = r"^[a-zA-Z_][a-zA-Z0-9_]*$"
|
||||
if not re.match(pattern, v):
|
||||
raise ValueError(
|
||||
"Invalid table name. Must start with letter or underscore and contain only letters, numbers, and underscores"
|
||||
)
|
||||
if len(v) > 63:
|
||||
raise ValueError("Table name must be less than 63 characters")
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def pip_packages(cls) -> list[str]:
|
||||
return ["psycopg2-binary"]
|
||||
|
||||
|
||||
class MongoDBKVStoreConfig(CommonConfig):
|
||||
type: Literal[StorageBackendType.KV_MONGODB] = StorageBackendType.KV_MONGODB
|
||||
host: str = "localhost"
|
||||
port: int = 27017
|
||||
db: str = "llamastack"
|
||||
user: str | None = None
|
||||
password: str | None = None
|
||||
collection_name: str = "llamastack_kvstore"
|
||||
|
||||
@classmethod
|
||||
def pip_packages(cls) -> list[str]:
|
||||
return ["pymongo"]
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, collection_name: str = "llamastack_kvstore"):
|
||||
return {
|
||||
"type": StorageBackendType.KV_MONGODB.value,
|
||||
"host": "${env.MONGODB_HOST:=localhost}",
|
||||
"port": "${env.MONGODB_PORT:=5432}",
|
||||
"db": "${env.MONGODB_DB}",
|
||||
"user": "${env.MONGODB_USER}",
|
||||
"password": "${env.MONGODB_PASSWORD}",
|
||||
"collection_name": "${env.MONGODB_COLLECTION_NAME:=" + collection_name + "}",
|
||||
}
|
||||
|
||||
|
||||
class SqlAlchemySqlStoreConfig(BaseModel):
|
||||
@property
|
||||
@abstractmethod
|
||||
def engine_str(self) -> str: ...
|
||||
|
||||
# TODO: move this when we have a better way to specify dependencies with internal APIs
|
||||
@classmethod
|
||||
def pip_packages(cls) -> list[str]:
|
||||
return ["sqlalchemy[asyncio]"]
|
||||
|
||||
|
||||
class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
||||
type: Literal[StorageBackendType.SQL_SQLITE] = StorageBackendType.SQL_SQLITE
|
||||
db_path: str = Field(
|
||||
description="Database path, e.g. ~/.llama/distributions/ollama/sqlstore.db",
|
||||
)
|
||||
|
||||
@property
|
||||
def engine_str(self) -> str:
|
||||
return "sqlite+aiosqlite:///" + Path(self.db_path).expanduser().as_posix()
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, db_name: str = "sqlstore.db"):
|
||||
return {
|
||||
"type": StorageBackendType.SQL_SQLITE.value,
|
||||
"db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def pip_packages(cls) -> list[str]:
|
||||
return super().pip_packages() + ["aiosqlite"]
|
||||
|
||||
|
||||
class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
||||
type: Literal[StorageBackendType.SQL_POSTGRES] = StorageBackendType.SQL_POSTGRES
|
||||
host: str = "localhost"
|
||||
port: int | str = 5432
|
||||
db: str = "llamastack"
|
||||
user: str
|
||||
password: str | None = None
|
||||
|
||||
@property
|
||||
def engine_str(self) -> str:
|
||||
return f"postgresql+asyncpg://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}"
|
||||
|
||||
@classmethod
|
||||
def pip_packages(cls) -> list[str]:
|
||||
return super().pip_packages() + ["asyncpg"]
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs):
|
||||
return {
|
||||
"type": StorageBackendType.SQL_POSTGRES.value,
|
||||
"host": "${env.POSTGRES_HOST:=localhost}",
|
||||
"port": "${env.POSTGRES_PORT:=5432}",
|
||||
"db": "${env.POSTGRES_DB:=llamastack}",
|
||||
"user": "${env.POSTGRES_USER:=llamastack}",
|
||||
"password": "${env.POSTGRES_PASSWORD:=llamastack}",
|
||||
}
|
||||
|
||||
|
||||
# reference = (backend_name, table_name)
|
||||
class SqlStoreReference(BaseModel):
|
||||
"""A reference to a 'SQL-like' persistent store. A table name must be provided."""
|
||||
|
||||
table_name: str = Field(
|
||||
description="Name of the table to use for the SqlStore",
|
||||
)
|
||||
|
||||
backend: str = Field(
|
||||
description="Name of backend from storage.backends",
|
||||
)
|
||||
|
||||
|
||||
# reference = (backend_name, namespace)
|
||||
class KVStoreReference(BaseModel):
|
||||
"""A reference to a 'key-value' persistent store. A namespace must be provided."""
|
||||
|
||||
namespace: str = Field(
|
||||
description="Key prefix for KVStore backends",
|
||||
)
|
||||
|
||||
backend: str = Field(
|
||||
description="Name of backend from storage.backends",
|
||||
)
|
||||
|
||||
|
||||
StorageBackendConfig = Annotated[
|
||||
RedisKVStoreConfig
|
||||
| SqliteKVStoreConfig
|
||||
| PostgresKVStoreConfig
|
||||
| MongoDBKVStoreConfig
|
||||
| SqliteSqlStoreConfig
|
||||
| PostgresSqlStoreConfig,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class InferenceStoreReference(SqlStoreReference):
|
||||
"""Inference store configuration with queue tuning."""
|
||||
|
||||
max_write_queue_size: int = Field(
|
||||
default=10000,
|
||||
description="Max queued writes for inference store",
|
||||
)
|
||||
num_writers: int = Field(
|
||||
default=4,
|
||||
description="Number of concurrent background writers",
|
||||
)
|
||||
|
||||
|
||||
class ResponsesStoreReference(InferenceStoreReference):
|
||||
"""Responses store configuration with queue tuning."""
|
||||
|
||||
|
||||
class ServerStoresConfig(BaseModel):
|
||||
metadata: KVStoreReference | None = Field(
|
||||
default=None,
|
||||
description="Metadata store configuration (uses KV backend)",
|
||||
)
|
||||
inference: InferenceStoreReference | None = Field(
|
||||
default=None,
|
||||
description="Inference store configuration (uses SQL backend)",
|
||||
)
|
||||
conversations: SqlStoreReference | None = Field(
|
||||
default=None,
|
||||
description="Conversations store configuration (uses SQL backend)",
|
||||
)
|
||||
responses: ResponsesStoreReference | None = Field(
|
||||
default=None,
|
||||
description="Responses store configuration (uses SQL backend)",
|
||||
)
|
||||
|
||||
|
||||
class StorageConfig(BaseModel):
|
||||
backends: dict[str, StorageBackendConfig] = Field(
|
||||
description="Named backend configurations (e.g., 'default', 'cache')",
|
||||
)
|
||||
stores: ServerStoresConfig = Field(
|
||||
default_factory=lambda: ServerStoresConfig(),
|
||||
description="Named references to storage backends used by the stack core",
|
||||
)
|
||||
|
|
@ -11,10 +11,9 @@ from typing import Protocol
|
|||
import pydantic
|
||||
|
||||
from llama_stack.core.datatypes import RoutableObjectWithProvider
|
||||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
|
||||
logger = get_logger(__name__, category="core::registry")
|
||||
|
||||
|
|
@ -191,16 +190,10 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
|
|||
|
||||
|
||||
async def create_dist_registry(
|
||||
metadata_store: KVStoreConfig | None,
|
||||
image_name: str,
|
||||
metadata_store: KVStoreReference, image_name: str
|
||||
) -> tuple[CachedDiskDistributionRegistry, KVStore]:
|
||||
# instantiate kvstore for storing and retrieving distribution metadata
|
||||
if metadata_store:
|
||||
dist_kvstore = await kvstore_impl(metadata_store)
|
||||
else:
|
||||
dist_kvstore = await kvstore_impl(
|
||||
SqliteKVStoreConfig(db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix())
|
||||
)
|
||||
dist_kvstore = await kvstore_impl(metadata_store)
|
||||
dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
|
||||
await dist_registry.initialize()
|
||||
return dist_registry, dist_kvstore
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
1. Start up Llama Stack API server. More details [here](https://llamastack.github.io/latest/getting_started/index.htmll).
|
||||
|
||||
```
|
||||
llama stack build --distro together --image-type venv
|
||||
llama stack list-deps together | xargs -L1 uv pip install
|
||||
|
||||
llama stack run together
|
||||
```
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ def tool_chat_page():
|
|||
tool_groups_list = [tool_group.identifier for tool_group in tool_groups]
|
||||
mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")]
|
||||
builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")]
|
||||
selected_vector_dbs = []
|
||||
selected_vector_stores = []
|
||||
|
||||
def reset_agent():
|
||||
st.session_state.clear()
|
||||
|
|
@ -55,13 +55,13 @@ def tool_chat_page():
|
|||
)
|
||||
|
||||
if "builtin::rag" in toolgroup_selection:
|
||||
vector_dbs = llama_stack_api.client.vector_dbs.list() or []
|
||||
if not vector_dbs:
|
||||
vector_stores = llama_stack_api.client.vector_stores.list() or []
|
||||
if not vector_stores:
|
||||
st.info("No vector databases available for selection.")
|
||||
vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
|
||||
selected_vector_dbs = st.multiselect(
|
||||
vector_stores = [vector_store.identifier for vector_store in vector_stores]
|
||||
selected_vector_stores = st.multiselect(
|
||||
label="Select Document Collections to use in RAG queries",
|
||||
options=vector_dbs,
|
||||
options=vector_stores,
|
||||
on_change=reset_agent,
|
||||
)
|
||||
|
||||
|
|
@ -119,7 +119,7 @@ def tool_chat_page():
|
|||
tool_dict = dict(
|
||||
name="builtin::rag",
|
||||
args={
|
||||
"vector_db_ids": list(selected_vector_dbs),
|
||||
"vector_store_ids": list(selected_vector_stores),
|
||||
},
|
||||
)
|
||||
toolgroup_selection[i] = tool_dict
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue