Merge branch 'meta-llama:main' into main

This commit is contained in:
Zain Hasan 2024-09-24 17:09:55 -07:00 committed by GitHub
commit 3ee415dc35
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 140 additions and 116 deletions

View file

@ -119,7 +119,7 @@ class TGIAdapter(Inference):
)
stop_reason = None
if response.details.finish_reason:
if response.details.finish_reason == "stop":
if response.details.finish_reason in ["stop", "eos_token"]:
stop_reason = StopReason.end_of_turn
elif response.details.finish_reason == "length":
stop_reason = StopReason.out_of_tokens

View file

@ -7,11 +7,11 @@
from .config import SafetyConfig
async def get_provider_impl(config: SafetyConfig, _deps):
async def get_provider_impl(config: SafetyConfig, deps):
from .safety import MetaReferenceSafetyImpl
assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
impl = MetaReferenceSafetyImpl(config)
impl = MetaReferenceSafetyImpl(config, deps)
await impl.initialize()
return impl

View file

@ -7,8 +7,10 @@
from llama_models.sku_list import resolve_model
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
OnViolationAction,
@ -34,20 +36,11 @@ def resolve_and_get_path(model_name: str) -> str:
class MetaReferenceSafetyImpl(Safety):
def __init__(self, config: SafetyConfig) -> None:
def __init__(self, config: SafetyConfig, deps) -> None:
self.config = config
self.inference_api = deps[Api.inference]
async def initialize(self) -> None:
shield_cfg = self.config.llama_guard_shield
if shield_cfg is not None:
model_dir = resolve_and_get_path(shield_cfg.model)
_ = LlamaGuardShield.instance(
model_dir=model_dir,
excluded_categories=shield_cfg.excluded_categories,
disable_input_check=shield_cfg.disable_input_check,
disable_output_check=shield_cfg.disable_output_check,
)
shield_cfg = self.config.prompt_guard_shield
if shield_cfg is not None:
model_dir = resolve_and_get_path(shield_cfg.model)
@ -91,11 +84,18 @@ class MetaReferenceSafetyImpl(Safety):
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
cfg = self.config
if typ == MetaReferenceShieldType.llama_guard:
cfg = cfg.llama_guard_shield
assert (
cfg.llama_guard_shield is not None
cfg is not None
), "Cannot use LlamaGuardShield since not present in config"
model_dir = resolve_and_get_path(cfg.llama_guard_shield.model)
return LlamaGuardShield.instance(model_dir=model_dir)
return LlamaGuardShield(
model=cfg.model,
inference_api=self.inference_api,
excluded_categories=cfg.excluded_categories,
disable_input_check=cfg.disable_input_check,
disable_output_check=cfg.disable_output_check,
)
elif typ == MetaReferenceShieldType.jailbreak_shield:
assert (
cfg.prompt_guard_shield is not None

View file

@ -9,9 +9,8 @@ import re
from string import Template
from typing import List, Optional
import torch
from llama_models.llama3.api.datatypes import Message, Role
from transformers import AutoModelForCausalLM, AutoTokenizer
from llama_stack.apis.inference import * # noqa: F403
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
@ -100,39 +99,17 @@ PROMPT_TEMPLATE = Template(
class LlamaGuardShield(ShieldBase):
@staticmethod
def instance(
on_violation_action=OnViolationAction.RAISE,
model_dir: str = None,
excluded_categories: List[str] = None,
disable_input_check: bool = False,
disable_output_check: bool = False,
) -> "LlamaGuardShield":
global _INSTANCE
if _INSTANCE is None:
_INSTANCE = LlamaGuardShield(
on_violation_action,
model_dir,
excluded_categories,
disable_input_check,
disable_output_check,
)
return _INSTANCE
def __init__(
self,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
model_dir: str = None,
model: str,
inference_api: Inference,
excluded_categories: List[str] = None,
disable_input_check: bool = False,
disable_output_check: bool = False,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(on_violation_action)
dtype = torch.bfloat16
assert model_dir is not None, "Llama Guard model_dir is None"
if excluded_categories is None:
excluded_categories = []
@ -140,18 +117,12 @@ class LlamaGuardShield(ShieldBase):
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
self.device = "cuda"
self.model = model
self.inference_api = inference_api
self.excluded_categories = excluded_categories
self.disable_input_check = disable_input_check
self.disable_output_check = disable_output_check
# load model
torch_dtype = torch.bfloat16
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForCausalLM.from_pretrained(
model_dir, torch_dtype=torch_dtype, device_map=self.device
)
def check_unsafe_response(self, response: str) -> Optional[str]:
match = re.match(r"^unsafe\n(.*)$", response)
if match:
@ -212,26 +183,21 @@ class LlamaGuardShield(ShieldBase):
)
else:
prompt = self.build_prompt(messages)
llama_guard_input = {
"role": "user",
"content": prompt,
}
input_ids = self.tokenizer.apply_chat_template(
[llama_guard_input], return_tensors="pt", tokenize=True
).to(self.device)
prompt_len = input_ids.shape[1]
output = self.model.generate(
input_ids=input_ids,
max_new_tokens=20,
output_scores=True,
return_dict_in_generate=True,
pad_token_id=0,
)
generated_tokens = output.sequences[:, prompt_len:]
response = self.tokenizer.decode(
generated_tokens[0], skip_special_tokens=True
)
response = response.strip()
shield_response = self.get_shield_response(response)
# TODO: llama-stack inference protocol has issues with non-streaming inference code
content = ""
async for chunk in self.inference_api.chat_completion(
model=self.model,
messages=[
UserMessage(content=prompt),
],
stream=True,
):
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.progress:
assert isinstance(event.delta, str)
content += event.delta
content = content.strip()
shield_response = self.get_shield_response(content)
return shield_response

View file

@ -8,11 +8,25 @@ from typing import List
from llama_stack.distribution.datatypes import * # noqa: F403
EMBEDDING_DEPS = [
"blobfile",
"chardet",
"pypdf",
"sentence-transformers",
"tqdm",
"numpy",
"scikit-learn",
"scipy",
"nltk",
"sentencepiece",
"transformers",
# this happens to work because special dependencies are always installed last
# so if there was a regular torch installed first, this would be ignored
# we need a better way to do this to identify potential conflicts, etc.
# for now, this lets us significantly reduce the size of the container which
# does not have any "local" inference code (and hence does not need GPU-enabled torch)
"torch --index-url https://download.pytorch.org/whl/cpu",
"sentence-transformers --no-deps",
]

View file

@ -15,13 +15,15 @@ def available_providers() -> List[ProviderSpec]:
api=Api.safety,
provider_id="meta-reference",
pip_packages=[
"accelerate",
"codeshield",
"torch",
"transformers",
"torch --index-url https://download.pytorch.org/whl/cpu",
],
module="llama_stack.providers.impls.meta_reference.safety",
config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig",
api_dependencies=[
Api.inference,
],
),
remote_provider_spec(
api=Api.safety,

View file

@ -25,20 +25,22 @@ from llama_stack.apis.memory import * # noqa: F403
ALL_MINILM_L6_V2_DIMENSION = 384
EMBEDDING_MODEL = None
EMBEDDING_MODELS = {}
def get_embedding_model() -> "SentenceTransformer":
global EMBEDDING_MODEL
def get_embedding_model(model: str) -> "SentenceTransformer":
global EMBEDDING_MODELS
if EMBEDDING_MODEL is None:
print("Loading sentence transformer")
loaded_model = EMBEDDING_MODELS.get(model)
if loaded_model is not None:
return loaded_model
from sentence_transformers import SentenceTransformer
print(f"Loading sentence transformer for {model}...")
from sentence_transformers import SentenceTransformer
EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
return EMBEDDING_MODEL
loaded_model = SentenceTransformer(model)
EMBEDDING_MODELS[model] = loaded_model
return loaded_model
def parse_data_url(data_url: str):
@ -151,7 +153,7 @@ class BankWithIndex:
self,
documents: List[MemoryBankDocument],
) -> None:
model = get_embedding_model()
model = get_embedding_model(self.bank.config.embedding_model)
for doc in documents:
content = await content_from_doc(doc)
chunks = make_overlapped_chunks(
@ -187,6 +189,6 @@ class BankWithIndex:
else:
query_str = _process(query)
model = get_embedding_model()
model = get_embedding_model(self.bank.config.embedding_model)
query_vector = model.encode([query_str])[0].astype(np.float32)
return await self.index.query(query_vector, k)