Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)

Commit 3ee415dc35: Merge branch 'meta-llama:main' into main

16 changed files with 140 additions and 116 deletions
@@ -190,7 +190,7 @@ class Inference(Protocol):
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         # zero-shot tool definitions as input to the model
-        tools: Optional[List[ToolDefinition]] = list,
+        tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
@@ -13,9 +13,9 @@ from typing import Any, Dict, List, Optional

 import fire
 import httpx
+from termcolor import cprint

 from llama_stack.distribution.datatypes import RemoteProviderConfig
-from termcolor import cprint

 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.providers.utils.memory.file_utils import data_url_from_file
@@ -120,7 +120,7 @@ async def run_main(host: str, port: int, stream: bool):
         name="test_bank",
         config=VectorMemoryBankConfig(
             bank_id="test_bank",
-            embedding_model="dragon-roberta-query-2",
+            embedding_model="all-MiniLM-L6-v2",
             chunk_size_in_tokens=512,
             overlap_size_in_tokens=64,
         ),
@@ -129,7 +129,7 @@ async def run_main(host: str, port: int, stream: bool):

     retrieved_bank = await client.get_memory_bank(bank.bank_id)
     assert retrieved_bank is not None
-    assert retrieved_bank.config.embedding_model == "dragon-roberta-query-2"
+    assert retrieved_bank.config.embedding_model == "all-MiniLM-L6-v2"

     urls = [
         "memory_optimizations.rst",
@@ -9,12 +9,12 @@ import json

 from llama_models.sku_list import resolve_model

-from termcolor import colored

 from llama_stack.cli.subcommand import Subcommand
 from llama_stack.cli.table import print_table
 from llama_stack.distribution.utils.serialize import EnumEncoder

+from termcolor import colored


 class ModelDescribe(Subcommand):
     """Show details about a model"""
@@ -52,7 +52,7 @@ class ModelDescribe(Subcommand):
                 colored(model.descriptor(), "white", attrs=["bold"]),
             ),
             ("HuggingFace ID", model.huggingface_repo or "<Not Available>"),
-            ("Description", model.description_markdown),
+            ("Description", model.description),
             ("Context Length", f"{model.max_seq_length // 1024}K tokens"),
             ("Weights format", model.quantization_format.value),
             ("Model params.json", json.dumps(model.arch_args, indent=4)),
@@ -66,6 +66,14 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
         if provider_spec.docker_image:
             raise ValueError("A stack's dependencies cannot have a docker image")

+    special_deps = []
+    deps = []
+    for package in package_deps.pip_packages:
+        if "--no-deps" in package or "--index-url" in package:
+            special_deps.append(package)
+        else:
+            deps.append(package)
+
     if build_config.image_type == ImageType.docker.value:
         script = pkg_resources.resource_filename(
             "llama_stack", "distribution/build_container.sh"
@@ -75,7 +83,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
             build_config.name,
             package_deps.docker_image,
             str(build_file_path),
-            " ".join(package_deps.pip_packages),
+            " ".join(deps),
         ]
     else:
         script = pkg_resources.resource_filename(
@@ -84,14 +92,17 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
         args = [
             script,
             build_config.name,
-            " ".join(package_deps.pip_packages),
+            " ".join(deps),
         ]

+    if special_deps:
+        args.append("#".join(special_deps))
+
     return_code = run_with_pty(args)
     if return_code != 0:
         cprint(
             f"Failed to build target {build_config.name} with return code {return_code}",
             color="red",
         )

     return return_code
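
Note: the hunks above introduce a convention for "special" pip dependencies (anything carrying --no-deps or --index-url): they are separated from the regular dependencies and passed to the build scripts as a single extra argument joined with "#". A minimal, self-contained sketch of that convention, using a hypothetical package list rather than the real package_deps object:

# Sketch only: hypothetical inputs illustrating the special-dependency convention above.
pip_packages = [
    "blobfile",
    "torch --index-url https://download.pytorch.org/whl/cpu",
    "sentence-transformers --no-deps",
]

special_deps = []
deps = []
for package in pip_packages:
    # Packages that need extra pip flags are set aside as "special".
    if "--no-deps" in package or "--index-url" in package:
        special_deps.append(package)
    else:
        deps.append(package)

# Regular deps travel as one space-joined argument; special deps as one '#'-joined
# argument, which the shell scripts later split again with IFS='#'.
args = ["build_conda_env.sh", "mybuild", " ".join(deps)]
if special_deps:
    args.append("#".join(special_deps))
print(args)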
@@ -17,14 +17,16 @@ if [ -n "$LLAMA_MODELS_DIR" ]; then
   echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
 fi

-set -euo pipefail
-
-if [ "$#" -ne 2 ]; then
-  echo "Usage: $0 <distribution_type> <build_name> <pip_dependencies>" >&2
+if [ "$#" -lt 2 ]; then
+  echo "Usage: $0 <distribution_type> <build_name> <pip_dependencies> [<special_pip_deps>]" >&2
   echo "Example: $0 <distribution_type> mybuild 'numpy pandas scipy'" >&2
   exit 1
 fi

+special_pip_deps="$3"
+
+set -euo pipefail
+
 build_name="$1"
 env_name="llamastack-$build_name"
 pip_dependencies="$2"
@@ -43,6 +45,7 @@ source "$SCRIPT_DIR/common.sh"
 ensure_conda_env_python310() {
   local env_name="$1"
   local pip_dependencies="$2"
+  local special_pip_deps="$3"
   local python_version="3.10"

   # Check if conda command is available
@@ -78,7 +81,12 @@ ensure_conda_env_python310() {
   if [ -n "$TEST_PYPI_VERSION" ]; then
     # these packages are damaged in test-pypi, so install them first
     $CONDA_PREFIX/bin/pip install fastapi libcst
-    $CONDA_PREFIX/bin/pip install --extra-index-url https://test.pypi.org/simple/ llama-models==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION $pip_dependencies
+    $CONDA_PREFIX/bin/pip install --extra-index-url https://test.pypi.org/simple/ \
+      llama-models==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION \
+      $pip_dependencies
+    if [ -n "$special_pip_deps" ]; then
+      $CONDA_PREFIX/bin/pip install --no-deps "$special_pip_deps"
+    fi
   else
     # Re-installing llama-stack in the new conda environment
     if [ -n "$LLAMA_STACK_DIR" ]; then
@@ -105,11 +113,16 @@ ensure_conda_env_python310() {
     fi

     # Install pip dependencies
-    if [ -n "$pip_dependencies" ]; then
-      printf "Installing pip dependencies\n"
-      $CONDA_PREFIX/bin/pip install $pip_dependencies
+    printf "Installing pip dependencies: $pip_dependencies\n"
+    $CONDA_PREFIX/bin/pip install $pip_dependencies
+    if [ -n "$special_pip_deps" ]; then
+      IFS='#' read -ra parts <<< "$special_pip_deps"
+      for part in "${parts[@]}"; do
+        echo "$part"
+        $CONDA_PREFIX/bin/pip install $part
+      done
     fi
   fi
 }

-ensure_conda_env_python310 "$env_name" "$pip_dependencies"
+ensure_conda_env_python310 "$env_name" "$pip_dependencies" "$special_pip_deps"
@@ -4,12 +4,16 @@ LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
 LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
 TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}

-if [ "$#" -ne 4 ]; then
-  echo "Usage: $0 <build_name> <docker_base> <pip_dependencies>" >&2
-  echo "Example: $0 my-fastapi-app python:3.9-slim 'fastapi uvicorn'" >&2
+if [ "$#" -lt 4 ]; then
+  echo "Usage: $0 <build_name> <docker_base> <pip_dependencies> [<special_pip_deps>]" >&2
+  echo "Example: $0 my-fastapi-app python:3.9-slim 'fastapi uvicorn' " >&2
   exit 1
 fi

+special_pip_deps="$5"
+
+set -euo pipefail
+
 build_name="$1"
 image_name="llamastack-$build_name"
 docker_base=$2
@@ -21,8 +25,6 @@ RED='\033[0;31m'
 GREEN='\033[0;32m'
 NC='\033[0m' # No Color

-set -euo pipefail
-
 SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
 REPO_DIR=$(dirname $(dirname "$SCRIPT_DIR"))
 DOCKER_BINARY=${DOCKER_BINARY:-docker}
@@ -85,6 +87,13 @@ if [ -n "$pip_dependencies" ]; then
   add_to_docker "RUN pip install $pip_dependencies"
 fi

+if [ -n "$special_pip_deps" ]; then
+  IFS='#' read -ra parts <<< "$special_pip_deps"
+  for part in "${parts[@]}"; do
+    add_to_docker "RUN pip install $part"
+  done
+fi
+
 add_to_docker <<EOF

 # This would be good in production but for debugging flexibility lets not add it right now
@@ -103,8 +103,7 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        # TODO: we need to fix streaming response to align provider implementations with Protocol.
-        async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
+        params = dict(
             model=model,
             messages=messages,
             sampling_params=sampling_params,
@@ -113,6 +112,10 @@ class InferenceRouter(Inference):
             tool_prompt_format=tool_prompt_format,
             stream=stream,
             logprobs=logprobs,
+        )
+        # TODO: we need to fix streaming response to align provider implementations with Protocol.
+        async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
+            **params
         ):
             yield chunk
@@ -33,8 +33,10 @@ class CommonRoutingTableImpl(RoutingTable):
         for p in self.providers.values():
             await p.shutdown()

-    def get_provider_impl(self, routing_key: str) -> Optional[Any]:
-        return self.providers.get(routing_key)
+    def get_provider_impl(self, routing_key: str) -> Any:
+        if routing_key not in self.providers:
+            raise ValueError(f"Could not find provider for {routing_key}")
+        return self.providers[routing_key]

     def get_routing_keys(self) -> List[str]:
         return self.routing_keys
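
Note: callers of get_provider_impl now get a hard failure for an unknown routing key instead of a silent None. A minimal stand-alone sketch of the new lookup behavior, with a plain dict standing in for the routing table's providers:

# Sketch only: a dict stands in for CommonRoutingTableImpl.providers.
providers = {"meta-reference": object()}


def get_provider_impl(routing_key: str):
    if routing_key not in providers:
        raise ValueError(f"Could not find provider for {routing_key}")
    return providers[routing_key]


get_provider_impl("meta-reference")      # returns the registered provider
# get_provider_impl("unknown-provider") # would raise ValueError instead of returning None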
@@ -368,17 +368,19 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
         providers = all_providers[info.router_api]

         inner_specs = []
+        inner_deps = []
         for rt_entry in routing_table:
             if rt_entry.provider_id not in providers:
                 raise ValueError(
                     f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
                 )
             inner_specs.append(providers[rt_entry.provider_id])
+            inner_deps.extend(providers[rt_entry.provider_id].api_dependencies)

         specs[source_api] = RoutingTableProviderSpec(
             api=source_api,
             module="llama_stack.distribution.routers",
-            api_dependencies=[],
+            api_dependencies=inner_deps,
             inner_specs=inner_specs,
         )
         configs[source_api] = routing_table
@@ -119,7 +119,7 @@ class TGIAdapter(Inference):
         )
         stop_reason = None
         if response.details.finish_reason:
-            if response.details.finish_reason == "stop":
+            if response.details.finish_reason in ["stop", "eos_token"]:
                 stop_reason = StopReason.end_of_turn
             elif response.details.finish_reason == "length":
                 stop_reason = StopReason.out_of_tokens
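
Note: TGI can report either "stop" or "eos_token" when generation ends normally, and both now map to end_of_turn. A minimal sketch of the mapping, using plain strings in place of the StopReason enum:

# Sketch only: strings stand in for the StopReason enum values.
def map_finish_reason(finish_reason):
    if finish_reason in ["stop", "eos_token"]:
        return "end_of_turn"
    elif finish_reason == "length":
        return "out_of_tokens"
    return None


print(map_finish_reason("eos_token"))  # end_of_turn
print(map_finish_reason("length"))     # out_of_tokens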
@@ -7,11 +7,11 @@
 from .config import SafetyConfig


-async def get_provider_impl(config: SafetyConfig, _deps):
+async def get_provider_impl(config: SafetyConfig, deps):
     from .safety import MetaReferenceSafetyImpl

     assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"

-    impl = MetaReferenceSafetyImpl(config)
+    impl = MetaReferenceSafetyImpl(config, deps)
     await impl.initialize()
     return impl
@@ -7,8 +7,10 @@
 from llama_models.sku_list import resolve_model

 from llama_stack.distribution.utils.model_utils import model_local_dir
+from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
 from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.distribution.datatypes import Api

 from llama_stack.providers.impls.meta_reference.safety.shields.base import (
     OnViolationAction,
@@ -34,20 +36,11 @@ def resolve_and_get_path(model_name: str) -> str:


 class MetaReferenceSafetyImpl(Safety):
-    def __init__(self, config: SafetyConfig) -> None:
+    def __init__(self, config: SafetyConfig, deps) -> None:
         self.config = config
+        self.inference_api = deps[Api.inference]

     async def initialize(self) -> None:
-        shield_cfg = self.config.llama_guard_shield
-        if shield_cfg is not None:
-            model_dir = resolve_and_get_path(shield_cfg.model)
-            _ = LlamaGuardShield.instance(
-                model_dir=model_dir,
-                excluded_categories=shield_cfg.excluded_categories,
-                disable_input_check=shield_cfg.disable_input_check,
-                disable_output_check=shield_cfg.disable_output_check,
-            )
-
         shield_cfg = self.config.prompt_guard_shield
         if shield_cfg is not None:
             model_dir = resolve_and_get_path(shield_cfg.model)
@@ -91,11 +84,18 @@ class MetaReferenceSafetyImpl(Safety):
     def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
         cfg = self.config
         if typ == MetaReferenceShieldType.llama_guard:
+            cfg = cfg.llama_guard_shield
             assert (
-                cfg.llama_guard_shield is not None
+                cfg is not None
             ), "Cannot use LlamaGuardShield since not present in config"
-            model_dir = resolve_and_get_path(cfg.llama_guard_shield.model)
-            return LlamaGuardShield.instance(model_dir=model_dir)
+            return LlamaGuardShield(
+                model=cfg.model,
+                inference_api=self.inference_api,
+                excluded_categories=cfg.excluded_categories,
+                disable_input_check=cfg.disable_input_check,
+                disable_output_check=cfg.disable_output_check,
+            )
         elif typ == MetaReferenceShieldType.jailbreak_shield:
             assert (
                 cfg.prompt_guard_shield is not None
@@ -9,9 +9,8 @@ import re
 from string import Template
 from typing import List, Optional

-import torch
 from llama_models.llama3.api.datatypes import Message, Role
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from llama_stack.apis.inference import *  # noqa: F403

 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
@@ -100,39 +99,17 @@ PROMPT_TEMPLATE = Template(


 class LlamaGuardShield(ShieldBase):
-    @staticmethod
-    def instance(
-        on_violation_action=OnViolationAction.RAISE,
-        model_dir: str = None,
-        excluded_categories: List[str] = None,
-        disable_input_check: bool = False,
-        disable_output_check: bool = False,
-    ) -> "LlamaGuardShield":
-        global _INSTANCE
-        if _INSTANCE is None:
-            _INSTANCE = LlamaGuardShield(
-                on_violation_action,
-                model_dir,
-                excluded_categories,
-                disable_input_check,
-                disable_output_check,
-            )
-        return _INSTANCE
-
     def __init__(
         self,
-        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
-        model_dir: str = None,
+        model: str,
+        inference_api: Inference,
         excluded_categories: List[str] = None,
         disable_input_check: bool = False,
         disable_output_check: bool = False,
+        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
     ):
         super().__init__(on_violation_action)

-        dtype = torch.bfloat16
-
-        assert model_dir is not None, "Llama Guard model_dir is None"
-
         if excluded_categories is None:
             excluded_categories = []
@@ -140,18 +117,12 @@ class LlamaGuardShield(ShieldBase):
             x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
         ), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"

-        self.device = "cuda"
+        self.model = model
+        self.inference_api = inference_api
         self.excluded_categories = excluded_categories
         self.disable_input_check = disable_input_check
         self.disable_output_check = disable_output_check

-        # load model
-        torch_dtype = torch.bfloat16
-        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_dir, torch_dtype=torch_dtype, device_map=self.device
-        )
-
     def check_unsafe_response(self, response: str) -> Optional[str]:
         match = re.match(r"^unsafe\n(.*)$", response)
         if match:
@@ -212,26 +183,21 @@ class LlamaGuardShield(ShieldBase):
             )
         else:
             prompt = self.build_prompt(messages)
-            llama_guard_input = {
-                "role": "user",
-                "content": prompt,
-            }
-            input_ids = self.tokenizer.apply_chat_template(
-                [llama_guard_input], return_tensors="pt", tokenize=True
-            ).to(self.device)
-            prompt_len = input_ids.shape[1]
-            output = self.model.generate(
-                input_ids=input_ids,
-                max_new_tokens=20,
-                output_scores=True,
-                return_dict_in_generate=True,
-                pad_token_id=0,
-            )
-            generated_tokens = output.sequences[:, prompt_len:]

-            response = self.tokenizer.decode(
-                generated_tokens[0], skip_special_tokens=True
-            )
-            response = response.strip()
-            shield_response = self.get_shield_response(response)
+            # TODO: llama-stack inference protocol has issues with non-streaming inference code
+            content = ""
+            async for chunk in self.inference_api.chat_completion(
+                model=self.model,
+                messages=[
+                    UserMessage(content=prompt),
+                ],
+                stream=True,
+            ):
+                event = chunk.event
+                if event.event_type == ChatCompletionResponseEventType.progress:
+                    assert isinstance(event.delta, str)
+                    content += event.delta
+
+            content = content.strip()
+            shield_response = self.get_shield_response(content)
             return shield_response
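
Note: after this change the shield no longer loads its own Hugging Face model; it drives the Inference API and accumulates streamed deltas into one response string. A minimal, self-contained sketch of that accumulation pattern, with a stand-in async generator in place of the real inference_api.chat_completion (the event shape mirrors the hunk above, but the fake client and its output are purely illustrative):

# Sketch only: a fake streaming client stands in for inference_api.chat_completion.
import asyncio
from dataclasses import dataclass


@dataclass
class _Event:
    event_type: str  # stand-in for ChatCompletionResponseEventType
    delta: str


@dataclass
class _Chunk:
    event: _Event


async def fake_chat_completion(model, messages, stream=True):
    # Yields "progress" events the way a streaming chat_completion would.
    for piece in ["unsafe", "\n", "S1"]:
        yield _Chunk(event=_Event(event_type="progress", delta=piece))


async def run_guard(prompt: str) -> str:
    content = ""
    async for chunk in fake_chat_completion(model="llama-guard", messages=[prompt]):
        event = chunk.event
        if event.event_type == "progress":
            assert isinstance(event.delta, str)
            content += event.delta
    return content.strip()


print(repr(asyncio.run(run_guard("is this request safe?"))))  # 'unsafe\nS1'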
@@ -8,11 +8,25 @@ from typing import List

 from llama_stack.distribution.datatypes import *  # noqa: F403


 EMBEDDING_DEPS = [
     "blobfile",
     "chardet",
     "pypdf",
-    "sentence-transformers",
+    "tqdm",
+    "numpy",
+    "scikit-learn",
+    "scipy",
+    "nltk",
+    "sentencepiece",
+    "transformers",
+    # this happens to work because special dependencies are always installed last
+    # so if there was a regular torch installed first, this would be ignored
+    # we need a better way to do this to identify potential conflicts, etc.
+    # for now, this lets us significantly reduce the size of the container which
+    # does not have any "local" inference code (and hence does not need GPU-enabled torch)
+    "torch --index-url https://download.pytorch.org/whl/cpu",
+    "sentence-transformers --no-deps",
 ]
@@ -15,13 +15,15 @@ def available_providers() -> List[ProviderSpec]:
             api=Api.safety,
             provider_id="meta-reference",
             pip_packages=[
-                "accelerate",
                 "codeshield",
-                "torch",
                 "transformers",
+                "torch --index-url https://download.pytorch.org/whl/cpu",
             ],
             module="llama_stack.providers.impls.meta_reference.safety",
             config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig",
+            api_dependencies=[
+                Api.inference,
+            ],
         ),
         remote_provider_spec(
             api=Api.safety,
@@ -25,20 +25,22 @@ from llama_stack.apis.memory import *  # noqa: F403

 ALL_MINILM_L6_V2_DIMENSION = 384

-EMBEDDING_MODEL = None
+EMBEDDING_MODELS = {}


-def get_embedding_model() -> "SentenceTransformer":
-    global EMBEDDING_MODEL
+def get_embedding_model(model: str) -> "SentenceTransformer":
+    global EMBEDDING_MODELS

-    if EMBEDDING_MODEL is None:
-        print("Loading sentence transformer")
+    loaded_model = EMBEDDING_MODELS.get(model)
+    if loaded_model is not None:
+        return loaded_model

-        from sentence_transformers import SentenceTransformer
+    print(f"Loading sentence transformer for {model}...")
+    from sentence_transformers import SentenceTransformer

-        EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
-
-    return EMBEDDING_MODEL
+    loaded_model = SentenceTransformer(model)
+    EMBEDDING_MODELS[model] = loaded_model
+    return loaded_model


 def parse_data_url(data_url: str):
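
Note: embedding models are now cached per model name instead of in a single global. A minimal sketch of the memoization pattern, with a dummy loader standing in for SentenceTransformer:

# Sketch only: object() stands in for SentenceTransformer(model).
EMBEDDING_MODELS = {}


def get_embedding_model(model: str):
    loaded_model = EMBEDDING_MODELS.get(model)
    if loaded_model is not None:
        return loaded_model

    print(f"Loading sentence transformer for {model}...")
    loaded_model = object()  # would be SentenceTransformer(model)
    EMBEDDING_MODELS[model] = loaded_model
    return loaded_model


a = get_embedding_model("all-MiniLM-L6-v2")
b = get_embedding_model("all-MiniLM-L6-v2")
assert a is b  # the second call reuses the cached instance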
@@ -151,7 +153,7 @@ class BankWithIndex:
         self,
         documents: List[MemoryBankDocument],
     ) -> None:
-        model = get_embedding_model()
+        model = get_embedding_model(self.bank.config.embedding_model)
         for doc in documents:
             content = await content_from_doc(doc)
             chunks = make_overlapped_chunks(
@@ -187,6 +189,6 @@ class BankWithIndex:
         else:
             query_str = _process(query)

-        model = get_embedding_model()
+        model = get_embedding_model(self.bank.config.embedding_model)
         query_vector = model.encode([query_str])[0].astype(np.float32)
         return await self.index.query(query_vector, k)