Merge branch 'meta-llama:main' into main

Authored by Zain Hasan on 2024-09-24 17:09:55 -07:00; committed by GitHub
commit 3ee415dc35
16 changed files with 140 additions and 116 deletions

View file

@@ -190,7 +190,7 @@ class Inference(Protocol):
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = list,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
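
For context, a minimal sketch (hypothetical simplified signature, not from this diff) of why the old default was a bug: '= list' binds the built-in type object as the default, while '= None' plus a runtime check gives callers a real empty list.

from typing import List, Optional

def old_api(tools=list):
    # the "default" here is the built-in type object list, not an empty list
    return tools

def new_api(tools: Optional[List[str]] = None):
    # None is normalized to a usable empty list at call time
    return tools if tools is not None else []

assert old_api() is list   # callers silently received the class itself
assert new_api() == []     # callers now get an actual empty list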

View file

@@ -13,9 +13,9 @@ from typing import Any, Dict, List, Optional
import fire
import httpx
from termcolor import cprint
from llama_stack.distribution.datatypes import RemoteProviderConfig
from termcolor import cprint
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.utils.memory.file_utils import data_url_from_file
@@ -120,7 +120,7 @@ async def run_main(host: str, port: int, stream: bool):
name="test_bank",
config=VectorMemoryBankConfig(
bank_id="test_bank",
embedding_model="dragon-roberta-query-2",
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
@@ -129,7 +129,7 @@ async def run_main(host: str, port: int, stream: bool):
retrieved_bank = await client.get_memory_bank(bank.bank_id)
assert retrieved_bank is not None
assert retrieved_bank.config.embedding_model == "dragon-roberta-query-2"
assert retrieved_bank.config.embedding_model == "all-MiniLM-L6-v2"
urls = [
"memory_optimizations.rst",

View file

@@ -9,12 +9,12 @@ import json
from llama_models.sku_list import resolve_model
from termcolor import colored
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
from llama_stack.distribution.utils.serialize import EnumEncoder
from termcolor import colored
class ModelDescribe(Subcommand):
"""Show details about a model"""
@@ -52,7 +52,7 @@ class ModelDescribe(Subcommand):
colored(model.descriptor(), "white", attrs=["bold"]),
),
("HuggingFace ID", model.huggingface_repo or "<Not Available>"),
("Description", model.description_markdown),
("Description", model.description),
("Context Length", f"{model.max_seq_length // 1024}K tokens"),
("Weights format", model.quantization_format.value),
("Model params.json", json.dumps(model.arch_args, indent=4)),

View file

@@ -66,6 +66,14 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
if provider_spec.docker_image:
raise ValueError("A stack's dependencies cannot have a docker image")
special_deps = []
deps = []
for package in package_deps.pip_packages:
if "--no-deps" in package or "--index-url" in package:
special_deps.append(package)
else:
deps.append(package)
if build_config.image_type == ImageType.docker.value:
script = pkg_resources.resource_filename(
"llama_stack", "distribution/build_container.sh"
@@ -75,7 +83,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
build_config.name,
package_deps.docker_image,
str(build_file_path),
" ".join(package_deps.pip_packages),
" ".join(deps),
]
else:
script = pkg_resources.resource_filename(
@@ -84,9 +92,12 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
args = [
script,
build_config.name,
" ".join(package_deps.pip_packages),
" ".join(deps),
]
if special_deps:
args.append("#".join(special_deps))
return_code = run_with_pty(args)
if return_code != 0:
cprint(
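
A standalone sketch (the sample package list is made up) of the split performed above: packages carrying pip flags are collected separately and joined with '#' so they reach the build script as a single shell argument.

pip_packages = [
    "blobfile",
    "torch --index-url https://download.pytorch.org/whl/cpu",
    "sentence-transformers --no-deps",
]

special_deps, deps = [], []
for package in pip_packages:
    if "--no-deps" in package or "--index-url" in package:
        special_deps.append(package)
    else:
        deps.append(package)

args = ["build_script.sh", "mybuild", " ".join(deps)]
if special_deps:
    args.append("#".join(special_deps))

print(args)
# ['build_script.sh', 'mybuild', 'blobfile',
#  'torch --index-url https://download.pytorch.org/whl/cpu#sentence-transformers --no-deps']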

View file

@@ -17,14 +17,16 @@ if [ -n "$LLAMA_MODELS_DIR" ]; then
echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
fi
set -euo pipefail
if [ "$#" -ne 2 ]; then
echo "Usage: $0 <distribution_type> <build_name> <pip_dependencies>" >&2
if [ "$#" -lt 2 ]; then
echo "Usage: $0 <distribution_type> <build_name> <pip_dependencies> [<special_pip_deps>]" >&2
echo "Example: $0 <distribution_type> mybuild 'numpy pandas scipy'" >&2
exit 1
fi
special_pip_deps="$3"
set -euo pipefail
build_name="$1"
env_name="llamastack-$build_name"
pip_dependencies="$2"
@@ -43,6 +45,7 @@ source "$SCRIPT_DIR/common.sh"
ensure_conda_env_python310() {
local env_name="$1"
local pip_dependencies="$2"
local special_pip_deps="$3"
local python_version="3.10"
# Check if conda command is available
@@ -78,7 +81,12 @@ ensure_conda_env_python310() {
if [ -n "$TEST_PYPI_VERSION" ]; then
# these packages are damaged in test-pypi, so install them first
$CONDA_PREFIX/bin/pip install fastapi libcst
$CONDA_PREFIX/bin/pip install --extra-index-url https://test.pypi.org/simple/ llama-models==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION $pip_dependencies
$CONDA_PREFIX/bin/pip install --extra-index-url https://test.pypi.org/simple/ \
llama-models==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION \
$pip_dependencies
if [ -n "$special_pip_deps" ]; then
$CONDA_PREFIX/bin/pip install --no-deps "$special_pip_deps"
fi
else
# Re-installing llama-stack in the new conda environment
if [ -n "$LLAMA_STACK_DIR" ]; then
@@ -105,11 +113,16 @@ ensure_conda_env_python310() {
fi
# Install pip dependencies
if [ -n "$pip_dependencies" ]; then
printf "Installing pip dependencies: $pip_dependencies\n"
printf "Installing pip dependencies\n"
$CONDA_PREFIX/bin/pip install $pip_dependencies
if [ -n "$special_pip_deps" ]; then
IFS='#' read -ra parts <<< "$special_pip_deps"
for part in "${parts[@]}"; do
echo "$part"
$CONDA_PREFIX/bin/pip install $part
done
fi
fi
}
ensure_conda_env_python310 "$env_name" "$pip_dependencies"
ensure_conda_env_python310 "$env_name" "$pip_dependencies" "$special_pip_deps"

View file

@@ -4,12 +4,16 @@ LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
if [ "$#" -ne 4 ]; then
echo "Usage: $0 <build_name> <docker_base> <pip_dependencies>
echo "Example: $0 my-fastapi-app python:3.9-slim 'fastapi uvicorn'
if [ "$#" -lt 4 ]; then
echo "Usage: $0 <build_name> <docker_base> <pip_dependencies> [<special_pip_deps>]" >&2
echo "Example: $0 my-fastapi-app python:3.9-slim 'fastapi uvicorn' " >&2
exit 1
fi
special_pip_deps="$5"
set -euo pipefail
build_name="$1"
image_name="llamastack-$build_name"
docker_base=$2
@@ -21,8 +25,6 @@ RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color
set -euo pipefail
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
REPO_DIR=$(dirname $(dirname "$SCRIPT_DIR"))
DOCKER_BINARY=${DOCKER_BINARY:-docker}
@@ -85,6 +87,13 @@ if [ -n "$pip_dependencies" ]; then
add_to_docker "RUN pip install $pip_dependencies"
fi
if [ -n "$special_pip_deps" ]; then
IFS='#' read -ra parts <<< "$special_pip_deps"
for part in "${parts[@]}"; do
add_to_docker "RUN pip install $part"
done
fi
add_to_docker <<EOF
# This would be good in production but for debugging flexibility lets not add it right now
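
Equivalently for the container path, a small sketch (output only; the real script goes through its add_to_docker helper) of the Dockerfile lines the new loop emits:

special_pip_deps = (
    "torch --index-url https://download.pytorch.org/whl/cpu"
    "#sentence-transformers --no-deps"
)

dockerfile_lines = [f"RUN pip install {part}" for part in special_pip_deps.split("#")]
print("\n".join(dockerfile_lines))
# RUN pip install torch --index-url https://download.pytorch.org/whl/cpu
# RUN pip install sentence-transformers --no-deps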

View file

@@ -103,8 +103,7 @@ class InferenceRouter(Inference):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
# TODO: we need to fix streaming response to align provider implementations with Protocol.
async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
params = dict(
model=model,
messages=messages,
sampling_params=sampling_params,
@@ -113,6 +112,10 @@ class InferenceRouter(Inference):
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
# TODO: we need to fix streaming response to align provider implementations with Protocol.
async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
**params
):
yield chunk
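
A reduced, runnable sketch (FakeProvider stands in for the routed provider implementation) of the pattern above: the kwargs are assembled once into params and then forwarded to the provider's streaming chat_completion.

import asyncio

class FakeProvider:
    async def chat_completion(self, model, messages, stream=False):
        # stands in for a concrete provider's streaming implementation
        for token in ["hello", " ", "world"]:
            yield token

async def route(model, messages, stream=False):
    # build the kwargs once, then forward them to the selected provider
    params = dict(model=model, messages=messages, stream=stream)
    provider = FakeProvider()   # stands in for routing_table.get_provider_impl(model)
    async for chunk in provider.chat_completion(**params):
        yield chunk

async def main():
    async for chunk in route("my-model", [{"role": "user", "content": "hi"}], stream=True):
        print(chunk, end="")
    print()

asyncio.run(main())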

View file

@@ -33,8 +33,10 @@ class CommonRoutingTableImpl(RoutingTable):
for p in self.providers.values():
await p.shutdown()
def get_provider_impl(self, routing_key: str) -> Optional[Any]:
return self.providers.get(routing_key)
def get_provider_impl(self, routing_key: str) -> Any:
if routing_key not in self.providers:
raise ValueError(f"Could not find provider for {routing_key}")
return self.providers[routing_key]
def get_routing_keys(self) -> List[str]:
return self.routing_keys
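
In isolation (the provider registry here is a stand-in dict), the behavioral change is simply fail-fast lookup: unknown routing keys now raise instead of silently returning None.

providers = {"meta-reference": object()}

def get_provider_impl(routing_key: str):
    if routing_key not in providers:
        raise ValueError(f"Could not find provider for {routing_key}")
    return providers[routing_key]

get_provider_impl("meta-reference")            # ok
try:
    get_provider_impl("missing-provider")
except ValueError as e:
    print(e)                                   # Could not find provider for missing-provider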

View file

@@ -368,17 +368,19 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
providers = all_providers[info.router_api]
inner_specs = []
inner_deps = []
for rt_entry in routing_table:
if rt_entry.provider_id not in providers:
raise ValueError(
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
)
inner_specs.append(providers[rt_entry.provider_id])
inner_deps.extend(providers[rt_entry.provider_id].api_dependencies)
specs[source_api] = RoutingTableProviderSpec(
api=source_api,
module="llama_stack.distribution.routers",
api_dependencies=[],
api_dependencies=inner_deps,
inner_specs=inner_specs,
)
configs[source_api] = routing_table
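
A trimmed sketch (made-up provider specs) of the fix above: the routing table spec now carries the union of its inner providers' API dependencies instead of an empty list, so those APIs are resolved before the router is built.

providers = {
    "meta-reference": {"api_dependencies": ["inference"]},
    "remote::tgi": {"api_dependencies": []},
}
routing_table = [{"provider_id": "meta-reference"}, {"provider_id": "remote::tgi"}]

inner_specs, inner_deps = [], []
for rt_entry in routing_table:
    provider_id = rt_entry["provider_id"]
    if provider_id not in providers:
        raise ValueError(f"Unknown provider `{provider_id}`")
    inner_specs.append(providers[provider_id])
    inner_deps.extend(providers[provider_id]["api_dependencies"])

print(inner_deps)   # ['inference'] -> passed as api_dependencies on the routing table spec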

View file

@@ -119,7 +119,7 @@ class TGIAdapter(Inference):
)
stop_reason = None
if response.details.finish_reason:
if response.details.finish_reason == "stop":
if response.details.finish_reason in ["stop", "eos_token"]:
stop_reason = StopReason.end_of_turn
elif response.details.finish_reason == "length":
stop_reason = StopReason.out_of_tokens
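
A condensed sketch (local StopReason stand-in, not the llama_models enum) of the widened mapping: TGI may report either "stop" or "eos_token" for a normal end of generation, and both now map to end_of_turn.

from enum import Enum
from typing import Optional

class StopReason(Enum):
    end_of_turn = "end_of_turn"
    out_of_tokens = "out_of_tokens"

def map_finish_reason(finish_reason: str) -> Optional[StopReason]:
    if finish_reason in ["stop", "eos_token"]:
        return StopReason.end_of_turn      # both signal a normal end of generation
    if finish_reason == "length":
        return StopReason.out_of_tokens    # hit the token budget
    return None

assert map_finish_reason("eos_token") is StopReason.end_of_turn
assert map_finish_reason("length") is StopReason.out_of_tokens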

View file

@@ -7,11 +7,11 @@
from .config import SafetyConfig
async def get_provider_impl(config: SafetyConfig, _deps):
async def get_provider_impl(config: SafetyConfig, deps):
from .safety import MetaReferenceSafetyImpl
assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
impl = MetaReferenceSafetyImpl(config)
impl = MetaReferenceSafetyImpl(config, deps)
await impl.initialize()
return impl
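
A hedged sketch (Api and FakeSafetyImpl are stand-ins for the real llama_stack types) of the dependency injection introduced here: the factory now receives the resolved API implementations, and the safety impl keeps a handle on the inference API.

import asyncio
from enum import Enum

class Api(Enum):
    inference = "inference"

class FakeSafetyImpl:
    # stands in for MetaReferenceSafetyImpl, which now keeps the inference API
    def __init__(self, config, deps):
        self.config = config
        self.inference_api = deps[Api.inference]

async def get_provider_impl(config, deps):
    impl = FakeSafetyImpl(config, deps)
    # the real factory also awaits impl.initialize() before returning
    return impl

impl = asyncio.run(get_provider_impl(config={}, deps={Api.inference: object()}))
assert impl.inference_api is not None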

View file

@@ -7,8 +7,10 @@
from llama_models.sku_list import resolve_model
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
OnViolationAction,
@@ -34,20 +36,11 @@ def resolve_and_get_path(model_name: str) -> str:
class MetaReferenceSafetyImpl(Safety):
def __init__(self, config: SafetyConfig) -> None:
def __init__(self, config: SafetyConfig, deps) -> None:
self.config = config
self.inference_api = deps[Api.inference]
async def initialize(self) -> None:
shield_cfg = self.config.llama_guard_shield
if shield_cfg is not None:
model_dir = resolve_and_get_path(shield_cfg.model)
_ = LlamaGuardShield.instance(
model_dir=model_dir,
excluded_categories=shield_cfg.excluded_categories,
disable_input_check=shield_cfg.disable_input_check,
disable_output_check=shield_cfg.disable_output_check,
)
shield_cfg = self.config.prompt_guard_shield
if shield_cfg is not None:
model_dir = resolve_and_get_path(shield_cfg.model)
@@ -91,11 +84,18 @@ class MetaReferenceSafetyImpl(Safety):
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
cfg = self.config
if typ == MetaReferenceShieldType.llama_guard:
cfg = cfg.llama_guard_shield
assert (
cfg.llama_guard_shield is not None
cfg is not None
), "Cannot use LlamaGuardShield since not present in config"
model_dir = resolve_and_get_path(cfg.llama_guard_shield.model)
return LlamaGuardShield.instance(model_dir=model_dir)
return LlamaGuardShield(
model=cfg.model,
inference_api=self.inference_api,
excluded_categories=cfg.excluded_categories,
disable_input_check=cfg.disable_input_check,
disable_output_check=cfg.disable_output_check,
)
elif typ == MetaReferenceShieldType.jailbreak_shield:
assert (
cfg.prompt_guard_shield is not None

View file

@@ -9,9 +9,8 @@ import re
from string import Template
from typing import List, Optional
import torch
from llama_models.llama3.api.datatypes import Message, Role
from transformers import AutoModelForCausalLM, AutoTokenizer
from llama_stack.apis.inference import * # noqa: F403
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
@@ -100,39 +99,17 @@ PROMPT_TEMPLATE = Template(
class LlamaGuardShield(ShieldBase):
@staticmethod
def instance(
on_violation_action=OnViolationAction.RAISE,
model_dir: str = None,
excluded_categories: List[str] = None,
disable_input_check: bool = False,
disable_output_check: bool = False,
) -> "LlamaGuardShield":
global _INSTANCE
if _INSTANCE is None:
_INSTANCE = LlamaGuardShield(
on_violation_action,
model_dir,
excluded_categories,
disable_input_check,
disable_output_check,
)
return _INSTANCE
def __init__(
self,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
model_dir: str = None,
model: str,
inference_api: Inference,
excluded_categories: List[str] = None,
disable_input_check: bool = False,
disable_output_check: bool = False,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(on_violation_action)
dtype = torch.bfloat16
assert model_dir is not None, "Llama Guard model_dir is None"
if excluded_categories is None:
excluded_categories = []
@@ -140,18 +117,12 @@ class LlamaGuardShield(ShieldBase):
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
self.device = "cuda"
self.model = model
self.inference_api = inference_api
self.excluded_categories = excluded_categories
self.disable_input_check = disable_input_check
self.disable_output_check = disable_output_check
# load model
torch_dtype = torch.bfloat16
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForCausalLM.from_pretrained(
model_dir, torch_dtype=torch_dtype, device_map=self.device
)
def check_unsafe_response(self, response: str) -> Optional[str]:
match = re.match(r"^unsafe\n(.*)$", response)
if match:
@@ -212,26 +183,21 @@ class LlamaGuardShield(ShieldBase):
)
else:
prompt = self.build_prompt(messages)
llama_guard_input = {
"role": "user",
"content": prompt,
}
input_ids = self.tokenizer.apply_chat_template(
[llama_guard_input], return_tensors="pt", tokenize=True
).to(self.device)
prompt_len = input_ids.shape[1]
output = self.model.generate(
input_ids=input_ids,
max_new_tokens=20,
output_scores=True,
return_dict_in_generate=True,
pad_token_id=0,
)
generated_tokens = output.sequences[:, prompt_len:]
response = self.tokenizer.decode(
generated_tokens[0], skip_special_tokens=True
)
response = response.strip()
shield_response = self.get_shield_response(response)
# TODO: llama-stack inference protocol has issues with non-streaming inference code
content = ""
async for chunk in self.inference_api.chat_completion(
model=self.model,
messages=[
UserMessage(content=prompt),
],
stream=True,
):
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.progress:
assert isinstance(event.delta, str)
content += event.delta
content = content.strip()
shield_response = self.get_shield_response(content)
return shield_response
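
A reduced sketch (FakeInferenceApi and the model name are stand-ins) of the new flow: instead of loading Llama Guard with transformers, the shield streams a chat completion through the inference API and accumulates the progress deltas before parsing the verdict.

import asyncio
from dataclasses import dataclass

@dataclass
class Event:
    event_type: str
    delta: str

@dataclass
class Chunk:
    event: Event

class FakeInferenceApi:
    async def chat_completion(self, model, messages, stream=True):
        # pretend the guard model streams its verdict token by token
        for delta in ["sa", "fe"]:
            yield Chunk(Event(event_type="progress", delta=delta))

async def run_shield(prompt: str) -> str:
    content = ""
    async for chunk in FakeInferenceApi().chat_completion(
        model="llama-guard", messages=[{"role": "user", "content": prompt}], stream=True
    ):
        if chunk.event.event_type == "progress":
            content += chunk.event.delta   # accumulate streamed text deltas
    return content.strip()                 # e.g. "safe" or "unsafe\nS1"

print(asyncio.run(run_shield("is this request safe?")))   # -> safe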

View file

@@ -8,11 +8,25 @@ from typing import List
from llama_stack.distribution.datatypes import * # noqa: F403
EMBEDDING_DEPS = [
"blobfile",
"chardet",
"pypdf",
"sentence-transformers",
"tqdm",
"numpy",
"scikit-learn",
"scipy",
"nltk",
"sentencepiece",
"transformers",
# this happens to work because special dependencies are always installed last
# so if there was a regular torch installed first, this would be ignored
# we need a better way to do this to identify potential conflicts, etc.
# for now, this lets us significantly reduce the size of the container which
# does not have any "local" inference code (and hence does not need GPU-enabled torch)
"torch --index-url https://download.pytorch.org/whl/cpu",
"sentence-transformers --no-deps",
]

View file

@@ -15,13 +15,15 @@ def available_providers() -> List[ProviderSpec]:
api=Api.safety,
provider_id="meta-reference",
pip_packages=[
"accelerate",
"codeshield",
"torch",
"transformers",
"torch --index-url https://download.pytorch.org/whl/cpu",
],
module="llama_stack.providers.impls.meta_reference.safety",
config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig",
api_dependencies=[
Api.inference,
],
),
remote_provider_spec(
api=Api.safety,

View file

@@ -25,20 +25,22 @@ from llama_stack.apis.memory import * # noqa: F403
ALL_MINILM_L6_V2_DIMENSION = 384
EMBEDDING_MODEL = None
EMBEDDING_MODELS = {}
def get_embedding_model() -> "SentenceTransformer":
global EMBEDDING_MODEL
def get_embedding_model(model: str) -> "SentenceTransformer":
global EMBEDDING_MODELS
if EMBEDDING_MODEL is None:
print("Loading sentence transformer")
loaded_model = EMBEDDING_MODELS.get(model)
if loaded_model is not None:
return loaded_model
print(f"Loading sentence transformer for {model}...")
from sentence_transformers import SentenceTransformer
EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
return EMBEDDING_MODEL
loaded_model = SentenceTransformer(model)
EMBEDDING_MODELS[model] = loaded_model
return loaded_model
def parse_data_url(data_url: str):
@@ -151,7 +153,7 @@ class BankWithIndex:
self,
documents: List[MemoryBankDocument],
) -> None:
model = get_embedding_model()
model = get_embedding_model(self.bank.config.embedding_model)
for doc in documents:
content = await content_from_doc(doc)
chunks = make_overlapped_chunks(
@@ -187,6 +189,6 @@ class BankWithIndex:
else:
query_str = _process(query)
model = get_embedding_model()
model = get_embedding_model(self.bank.config.embedding_model)
query_vector = model.encode([query_str])[0].astype(np.float32)
return await self.index.query(query_vector, k)
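
As a usage note (requires sentence-transformers; this mirrors the code above rather than extending it), the cache loads each embedding model once and reuses it across banks:

EMBEDDING_MODELS = {}

def get_embedding_model(model: str):
    loaded = EMBEDDING_MODELS.get(model)
    if loaded is not None:
        return loaded
    from sentence_transformers import SentenceTransformer  # deferred heavy import
    loaded = SentenceTransformer(model)
    EMBEDDING_MODELS[model] = loaded
    return loaded

m1 = get_embedding_model("all-MiniLM-L6-v2")
m2 = get_embedding_model("all-MiniLM-L6-v2")
assert m1 is m2   # loaded once, shared by every bank configured with this model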