Merge branch 'meta-llama:main' into main

commit cd64371b2e
Pixee OSS Assistant, 2024-09-29 07:57:31 -04:00 (committed by GitHub)
28 changed files with 286 additions and 283 deletions

View file

@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import TogetherImplConfig, TogetherHeaderExtractor
from .config import TogetherImplConfig
async def get_adapter_impl(config: TogetherImplConfig, _deps):

View file

@@ -4,17 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel, Field
from llama_models.schema_utils import json_schema_type
from llama_stack.distribution.request_headers import annotate_header
class TogetherHeaderExtractor(BaseModel):
api_key: annotate_header(
"X-LlamaStack-Together-ApiKey", str, "The API Key for the request"
)
from pydantic import BaseModel, Field
@json_schema_type

View file

@@ -15,6 +15,7 @@ from llama_models.sku_list import resolve_model
from together import Together
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.request_headers import get_request_provider_data
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
@@ -22,9 +23,12 @@ from llama_stack.providers.utils.inference.augment_messages import (
from .config import TogetherImplConfig
TOGETHER_SUPPORTED_MODELS = {
"Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct-Turbo",
"Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct-Turbo",
"Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-Turbo",
"Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
"Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
"Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
"Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
"Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
"Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
}
@@ -97,6 +101,16 @@ class TogetherInferenceAdapter(Inference):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
together_api_key = None
provider_data = get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
client = Together(api_key=together_api_key)
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = ChatCompletionRequest(
model=model,
@@ -116,7 +130,7 @@ class TogetherInferenceAdapter(Inference):
if not request.stream:
# TODO: might need to add back an async here
r = self.client.chat.completions.create(
r = client.chat.completions.create(
model=together_model,
messages=self._messages_to_together_messages(messages),
stream=False,
@@ -151,7 +165,7 @@ class TogetherInferenceAdapter(Inference):
ipython = False
stop_reason = None
for chunk in self.client.chat.completions.create(
for chunk in client.chat.completions.create(
model=together_model,
messages=self._messages_to_together_messages(messages),
stream=True,
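
A rough caller-side sketch of this change, assuming a locally running llama-stack server: the Together key now travels with each request in the X-LlamaStack-ProviderData header (the JSON shape is the one named in the error message above) instead of living in static provider config.

import json

# Hypothetical header block a client would attach to its request to the llama-stack
# server; the adapter recovers the key via get_request_provider_data().
provider_data_headers = {
    "X-LlamaStack-ProviderData": json.dumps({"together_api_key": "<your api key>"}),
}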

View file

@@ -3,12 +3,41 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.sku_list import resolve_model
from together import Together
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
SafetyViolation,
ViolationLevel,
)
from llama_stack.distribution.request_headers import get_request_provider_data
from .config import TogetherProviderDataValidator, TogetherSafetyConfig
from .config import TogetherSafetyConfig
SAFETY_SHIELD_TYPES = {
"Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
}
def shield_type_to_model_name(shield_type: str) -> str:
if shield_type == "llama_guard":
shield_type = "Llama-Guard-3-8B"
model = resolve_model(shield_type)
if (
model is None
or model.descriptor(shorten_default_variant=True) not in SAFETY_SHIELD_TYPES
or model.model_family is not ModelFamily.safety
):
raise ValueError(
f"{shield_type} is not supported, please use of {','.join(SAFETY_SHIELD_TYPES.keys())}"
)
return SAFETY_SHIELD_TYPES.get(model.descriptor(shorten_default_variant=True))
class TogetherSafetyImpl(Safety):
@@ -21,24 +50,16 @@ class TogetherSafetyImpl(Safety):
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
if shield_type != "llama_guard":
raise ValueError(f"shield type {shield_type} is not supported")
provider_data = get_request_provider_data()
together_api_key = None
if provider_data is not None:
if not isinstance(provider_data, TogetherProviderDataValidator):
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
)
provider_data = get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
together_api_key = provider_data.together_api_key
if not together_api_key:
together_api_key = self.config.api_key
if not together_api_key:
raise ValueError("The API key must be provider in the header or config")
model_name = shield_type_to_model_name(shield_type)
# messages can have role assistant or user
api_messages = []
@@ -46,17 +67,17 @@ class TogetherSafetyImpl(Safety):
if message.role in (Role.user.value, Role.assistant.value):
api_messages.append({"role": message.role, "content": message.content})
violation = await get_safety_response(together_api_key, api_messages)
violation = await get_safety_response(
together_api_key, model_name, api_messages
)
return RunShieldResponse(violation=violation)
async def get_safety_response(
api_key: str, messages: List[Dict[str, str]]
api_key: str, model_name: str, messages: List[Dict[str, str]]
) -> Optional[SafetyViolation]:
client = Together(api_key=api_key)
response = client.chat.completions.create(
messages=messages, model="meta-llama/Meta-Llama-Guard-3-8B"
)
response = client.chat.completions.create(messages=messages, model=model_name)
if len(response.choices) == 0:
return None
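
A quick illustration of the new mapping helper in this adapter (assuming shield_type_to_model_name is imported from this module): the generic "llama_guard" shield type is normalized and resolved against SAFETY_SHIELD_TYPES, and anything outside that table raises a ValueError.

# "llama_guard" is first widened to "Llama-Guard-3-8B", then mapped to the Together
# model id; "Llama-Guard-3-11B-Vision" maps the same way per the table above.
assert shield_type_to_model_name("llama_guard") == "meta-llama/Meta-Llama-Guard-3-8B"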

View file

@@ -7,12 +7,13 @@
from typing import Optional
from llama_models.datatypes import * # noqa: F403
from llama_models.sku_list import all_registered_models, resolve_model
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import * # noqa: F401, F403
from pydantic import BaseModel, Field, field_validator
from llama_stack.providers.utils.inference import supported_inference_models
class MetaReferenceImplConfig(BaseModel):
model: str = Field(
@@ -27,12 +28,7 @@ class MetaReferenceImplConfig(BaseModel):
@field_validator("model")
@classmethod
def validate_model(cls, model: str) -> str:
permitted_models = [
m.descriptor()
for m in all_registered_models()
if m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2}
or m.core_model_id == CoreModelId.llama_guard_3_8b
]
permitted_models = supported_inference_models()
if model not in permitted_models:
model_list = "\n\t".join(permitted_models)
raise ValueError(

View file

@@ -52,7 +52,7 @@ def model_checkpoint_dir(model) -> str:
checkpoint_dir = checkpoint_dir / "original"
assert checkpoint_dir.exists(), (
f"Could not find checkpoint dir: {checkpoint_dir}."
f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
f"Please download model using `llama download --model-id {model.descriptor()}`"
)
return str(checkpoint_dir)

View file

@@ -14,6 +14,10 @@ import torch
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from llama_models.llama3.api.model import Transformer, TransformerBlock
from termcolor import cprint
from torch import Tensor
from llama_stack.apis.inference import QuantizationType
from llama_stack.apis.inference.config import (
@@ -21,9 +25,6 @@ from llama_stack.apis.inference.config import (
MetaReferenceImplConfig,
)
from termcolor import cprint
from torch import Tensor
def is_fbgemm_available() -> bool:
try:

View file

@@ -88,10 +88,10 @@ class MetaReferenceSafetyImpl(Safety):
assert (
cfg is not None
), "Cannot use LlamaGuardShield since not present in config"
model_dir = resolve_and_get_path(cfg.model)
return LlamaGuardShield(
model_dir=model_dir,
model=cfg.model,
inference_api=self.inference_api,
excluded_categories=cfg.excluded_categories,
disable_input_check=cfg.disable_input_check,
disable_output_check=cfg.disable_output_check,

View file

@@ -9,17 +9,10 @@ import re
from string import Template
from typing import List, Optional
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
MllamaForConditionalGeneration,
MllamaProcessor
)
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
from llama_models.llama3.api.datatypes import * # noqa: F403
SAFE_RESPONSE = "safe"
@@ -75,20 +68,17 @@ DEFAULT_LG_V3_SAFETY_CATEGORIES = [
CAT_ELECTIONS,
]
# model names
LG_3_8B = "Llama-Guard-3-8B"
LG_3_1B = "Llama-Guard-3-1B"
LG_3_11B_VISION = "Llama-Guard-3-11B-Vision"
MODEL_TO_SAFETY_CATEGORIES_MAP = {
LG_3_8B: DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE],
LG_3_1B: DEFAULT_LG_V3_SAFETY_CATEGORIES,
LG_3_11B_VISION: DEFAULT_LG_V3_SAFETY_CATEGORIES,
CoreModelId.llama_guard_3_8b.value: (
DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]
),
CoreModelId.llama_guard_3_1b.value: DEFAULT_LG_V3_SAFETY_CATEGORIES,
CoreModelId.llama_guard_3_11b_vision.value: DEFAULT_LG_V3_SAFETY_CATEGORIES,
}
PROMPT_TASK = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nTask: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories."
PROMPT_TASK = "Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories."
SAFETY_CATEGORIES = """
@@ -109,7 +99,7 @@ $conversations
PROMPT_INSTRUCTIONS = """
Provide your safety assessment for ONLY THE LAST $agent_type message in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. <|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
- If unsafe, a second line must include a comma-separated list of violated categories."""
PROMPT_TEMPLATE = Template(
@@ -120,7 +110,8 @@ PROMPT_TEMPLATE = Template(
class LlamaGuardShield(ShieldBase):
def __init__(
self,
model_dir: str,
model: str,
inference_api: Inference,
excluded_categories: List[str] = None,
disable_input_check: bool = False,
disable_output_check: bool = False,
@@ -128,12 +119,6 @@ class LlamaGuardShield(ShieldBase):
):
super().__init__(on_violation_action)
dtype = torch.bfloat16
self.model_dir = model_dir
self.device = "cuda"
assert self.model_dir is not None, "Llama Guard model_dir is None"
if excluded_categories is None:
excluded_categories = []
@@ -141,27 +126,15 @@ class LlamaGuardShield(ShieldBase):
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
if model not in MODEL_TO_SAFETY_CATEGORIES_MAP:
raise ValueError(f"Unsupported model: {model}")
self.model = model
self.inference_api = inference_api
self.excluded_categories = excluded_categories
self.disable_input_check = disable_input_check
self.disable_output_check = disable_output_check
torch_dtype = torch.bfloat16
self.model_dir = f"meta-llama/{self.get_model_name()}"
if self.is_lg_vision():
self.model = MllamaForConditionalGeneration.from_pretrained(
self.model_dir, device_map=self.device, torch_dtype=torch_dtype
)
self.processor = MllamaProcessor.from_pretrained(self.model_dir)
else:
self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
self.model = AutoModelForCausalLM.from_pretrained(
self.model_dir, torch_dtype=torch_dtype, device_map=self.device
)
def check_unsafe_response(self, response: str) -> Optional[str]:
match = re.match(r"^unsafe\n(.*)$", response)
if match:
@@ -177,7 +150,8 @@ class LlamaGuardShield(ShieldBase):
excluded_categories = []
final_categories = []
all_categories = MODEL_TO_SAFETY_CATEGORIES_MAP[self.get_model_name()]
all_categories = MODEL_TO_SAFETY_CATEGORIES_MAP[self.model]
for cat in all_categories:
cat_code = SAFETY_CATEGORIES_TO_CODE_MAP[cat]
if cat_code in excluded_categories:
@@ -186,11 +160,99 @@ class LlamaGuardShield(ShieldBase):
return final_categories
def validate_messages(self, messages: List[Message]) -> List[Message]:
if len(messages) == 0:
raise ValueError("Messages must not be empty")
if messages[0].role != Role.user.value:
raise ValueError("Messages must start with user")
if len(messages) >= 2 and (
messages[0].role == Role.user.value and messages[1].role == Role.user.value
):
messages = messages[1:]
for i in range(1, len(messages)):
if messages[i].role == messages[i - 1].role:
raise ValueError(
f"Messages must alternate between user and assistant. Message {i} has the same role as message {i-1}"
)
return messages
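
A small, hypothetical walkthrough of the validation above (shield is an instantiated LlamaGuardShield; the message types are the usual llama_models datatypes): a duplicated leading user turn is dropped, after which strict user/assistant alternation is enforced.

#   shield.validate_messages([
#       UserMessage(content="hello"),
#       UserMessage(content="hello again"),  # duplicated leading user turn -> dropped
#       CompletionMessage(content="hi!", stop_reason=StopReason.end_of_turn),
#   ])
#   # -> [UserMessage("hello again"), CompletionMessage("hi!")]
#   shield.validate_messages([])  # -> ValueError("Messages must not be empty")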
async def run(self, messages: List[Message]) -> ShieldResponse:
messages = self.validate_messages(messages)
if self.disable_input_check and messages[-1].role == Role.user.value:
return ShieldResponse(is_violation=False)
elif self.disable_output_check and messages[-1].role == Role.assistant.value:
return ShieldResponse(
is_violation=False,
)
if self.model == CoreModelId.llama_guard_3_11b_vision.value:
shield_input_message = self.build_vision_shield_input(messages)
else:
shield_input_message = self.build_text_shield_input(messages)
# TODO: llama-stack inference protocol has issues with non-streaming inference code
content = ""
async for chunk in self.inference_api.chat_completion(
model=self.model,
messages=[shield_input_message],
stream=True,
):
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.progress:
assert isinstance(event.delta, str)
content += event.delta
content = content.strip()
shield_response = self.get_shield_response(content)
return shield_response
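
For orientation, a hedged sketch of driving the refactored shield end to end (the shield instance and message content are illustrative): generation now flows through whichever Inference provider was injected, rather than locally loaded Hugging Face weights.

#   response = await shield.run([UserMessage(content="<content to check>")])
#   if response.is_violation:
#       ...  # surface the violation to the caller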
def build_text_shield_input(self, messages: List[Message]) -> UserMessage:
return UserMessage(content=self.build_prompt(messages))
def build_vision_shield_input(self, messages: List[Message]) -> UserMessage:
conversation = []
most_recent_img = None
for m in messages[::-1]:
if isinstance(m.content, str):
conversation.append(m)
elif isinstance(m.content, ImageMedia):
if most_recent_img is None and m.role == Role.user.value:
most_recent_img = m.content
conversation.append(m)
elif isinstance(m.content, list):
content = []
for c in m.content:
if isinstance(c, str):
content.append(c)
elif isinstance(c, ImageMedia):
if most_recent_img is None and m.role == Role.user.value:
most_recent_img = c
content.append(c)
else:
raise ValueError(f"Unknown content type: {c}")
conversation.append(UserMessage(content=content))
else:
raise ValueError(f"Unknown content type: {m.content}")
prompt = []
if most_recent_img is not None:
prompt.append(most_recent_img)
prompt.append(self.build_prompt(conversation[::-1]))
return UserMessage(content=prompt)
def build_prompt(self, messages: List[Message]) -> str:
categories = self.get_safety_categories()
categories_str = "\n".join(categories)
conversations_str = "\n\n".join(
[f"{m.role.capitalize()}: {m.content}" for m in messages]
[
f"{m.role.capitalize()}: {interleaved_text_media_as_str(m.content)}"
for m in messages
]
)
return PROMPT_TEMPLATE.substitute(
agent_type=messages[-1].role.capitalize(),
@@ -214,134 +276,3 @@ class LlamaGuardShield(ShieldBase):
)
raise ValueError(f"Unexpected response: {response}")
def build_mm_prompt(self, messages: List[Message]) -> str:
conversation = []
most_recent_img = None
for m in messages[::-1]:
if isinstance(m.content, str):
conversation.append(
{
"role": m.role,
"content": [{"type": "text", "text": m.content}],
}
)
elif isinstance(m.content, ImageMedia):
if most_recent_img is None and m.role == Role.user.value:
most_recent_img = m.content
conversation.append(
{
"role": m.role,
"content": [{"type": "image"}],
}
)
elif isinstance(m.content, list):
content = []
for c in m.content:
if isinstance(c, str):
content.append({"type": "text", "text": c})
elif isinstance(c, ImageMedia):
if most_recent_img is None and m.role == Role.user.value:
most_recent_img = c
content.append({"type": "image"})
else:
raise ValueError(f"Unknown content type: {c}")
conversation.append(
{
"role": m.role,
"content": content,
}
)
else:
raise ValueError(f"Unknown content type: {m.content}")
return conversation[::-1], most_recent_img
async def run_lg_mm(self, messages: List[Message]) -> ShieldResponse:
formatted_messages, most_recent_img = self.build_mm_prompt(messages)
raw_image = None
if most_recent_img:
raw_image = interleaved_text_media_localize(most_recent_img)
raw_image = raw_image.image
llama_guard_input_templ_applied = self.processor.apply_chat_template(
formatted_messages,
add_generation_prompt=True,
tokenize=False,
skip_special_tokens=False,
)
inputs = self.processor(
text=llama_guard_input_templ_applied, images=raw_image, return_tensors="pt"
).to(self.device)
output = self.model.generate(**inputs, do_sample=False, max_new_tokens=50)
response = self.processor.decode(
output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
)
shield_response = self.get_shield_response(response)
return shield_response
async def run_lg_text(self, messages: List[Message]):
prompt = self.build_prompt(messages)
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
prompt_len = input_ids.shape[1]
output = self.model.generate(
input_ids=input_ids,
max_new_tokens=20,
output_scores=True,
return_dict_in_generate=True,
pad_token_id=0,
)
generated_tokens = output.sequences[:, prompt_len:]
response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
shield_response = self.get_shield_response(response)
return shield_response
def get_model_name(self):
return self.model_dir.split("/")[-1]
def is_lg_vision(self):
model_name = self.get_model_name()
return model_name == LG_3_11B_VISION
def validate_messages(self, messages: List[Message]) -> None:
if len(messages) == 0:
raise ValueError("Messages must not be empty")
if messages[0].role != Role.user.value:
raise ValueError("Messages must start with user")
if len(messages) >= 2 and (
messages[0].role == Role.user.value and messages[1].role == Role.user.value
):
messages = messages[1:]
for i in range(1, len(messages)):
if messages[i].role == messages[i - 1].role:
raise ValueError(
f"Messages must alternate between user and assistant. Message {i} has the same role as message {i-1}"
)
return messages
async def run(self, messages: List[Message]) -> ShieldResponse:
messages = self.validate_messages(messages)
if self.disable_input_check and messages[-1].role == Role.user.value:
return ShieldResponse(is_violation=False)
elif self.disable_output_check and messages[-1].role == Role.assistant.value:
return ShieldResponse(
is_violation=False,
)
else:
if self.is_lg_vision():
shield_response = await self.run_lg_mm(messages)
else:
shield_response = await self.run_lg_text(messages)
return shield_response

View file

@@ -91,7 +91,7 @@ def available_providers() -> List[ProviderSpec]:
],
module="llama_stack.providers.adapters.inference.together",
config_class="llama_stack.providers.adapters.inference.together.TogetherImplConfig",
header_extractor_class="llama_stack.providers.adapters.inference.together.TogetherHeaderExtractor",
provider_data_validator="llama_stack.providers.adapters.safety.together.TogetherProviderDataValidator",
),
),
]

View file

@@ -21,10 +21,9 @@ def available_providers() -> List[ProviderSpec]:
api=Api.safety,
provider_id="meta-reference",
pip_packages=[
"accelerate",
"codeshield",
"torch",
"transformers",
"torch --index-url https://download.pytorch.org/whl/cpu",
],
module="llama_stack.providers.impls.meta_reference.safety",
config_class="llama_stack.providers.impls.meta_reference.safety.SafetyConfig",

View file

@@ -3,3 +3,31 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_models.datatypes import * # noqa: F403
from llama_models.sku_list import all_registered_models
def is_supported_safety_model(model: Model) -> bool:
if model.quantization_format != CheckpointQuantizationFormat.bf16:
return False
model_id = model.core_model_id
return model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_3_1b,
CoreModelId.llama_guard_3_11b_vision,
]
def supported_inference_models() -> List[str]:
return [
m.descriptor()
for m in all_registered_models()
if (
m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2}
or is_supported_safety_model(m)
)
]
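
As a rough illustration of what the new helper returns (the descriptors shown are examples taken from the model tables elsewhere in this diff):

# Prints descriptors such as "Llama3.1-8B-Instruct", "Llama3.2-3B-Instruct",
# plus the bf16 Llama Guard safety models, e.g. "Llama-Guard-3-8B".
for descriptor in sorted(supported_inference_models()):
    print(descriptor)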

View file

@@ -16,6 +16,8 @@ from llama_models.llama3.prompt_templates import (
)
from llama_models.sku_list import resolve_model
from llama_stack.providers.utils.inference import supported_inference_models
def augment_messages_for_tools(request: ChatCompletionRequest) -> List[Message]:
"""Reads chat completion request and augments the messages to handle tools.
@@ -27,8 +29,8 @@ def augment_messages_for_tools(request: ChatCompletionRequest) -> List[Message]:
cprint(f"Could not resolve model {request.model}", color="red")
return request.messages
if model.model_family not in [ModelFamily.llama3_1, ModelFamily.llama3_2]:
cprint(f"Model family {model.model_family} not llama 3_1 or 3_2", color="red")
if model.descriptor() not in supported_inference_models():
cprint(f"Unsupported inference model? {model.descriptor()}", color="red")
return request.messages
if model.model_family == ModelFamily.llama3_1 or (