support Llama 3.2 models in Together inference adapter and clean up Together safety adapter

Yogish Baliga 2024-09-25 17:51:42 -07:00
parent 9bb0c8f4fc
commit 2b568a462a
2 changed files with 38 additions and 11 deletions

Together inference adapter (together.py):

@@ -23,9 +23,15 @@ from llama_stack.distribution.request_headers import get_request_provider_data
 from .config import TogetherImplConfig
 
 TOGETHER_SUPPORTED_MODELS = {
-    "Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct-Turbo",
-    "Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct-Turbo",
-    "Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-Turbo",
+    "Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
+    "Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
+    "Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
+    "Llama3.2-1B-Instruct": "meta-llama/Meta-Llama-3.2-1B-Instruct-Turbo",
+    "Llama3.2-3B-Instruct": "meta-llama/Meta-Llama-3.2-3B-Instruct-Turbo",
+    "Llama3.2-11B-Vision": "meta-llama/Meta-Llama-3.2-11B-Vision-Turbo",
+    "Llama3.2-90B-Vision": "meta-llama/Meta-Llama-3.2-90B-Vision-Turbo",
+    "Llama3.2-11B-Vision-Instruct": "meta-llama/Meta-Llama-3.2-11B-Vision-Turbo",
+    "Llama3.2-90B-Vision-Instruct": "meta-llama/Meta-Llama-3.2-90B-Vision-Turbo",
 }
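For context, a minimal sketch of how such a mapping is typically consumed when dispatching a request (the resolve_together_model helper below is hypothetical and not part of this commit; the adapter's actual lookup lives elsewhere in the file):

def resolve_together_model(model_name: str) -> str:
    # Translate a llama-stack descriptor (e.g. "Llama3.2-3B-Instruct")
    # into the Together API model id
    # (e.g. "meta-llama/Meta-Llama-3.2-3B-Instruct-Turbo").
    if model_name not in TOGETHER_SUPPORTED_MODELS:
        raise ValueError(
            f"Unsupported model: {model_name}. "
            f"Use one of: {', '.join(TOGETHER_SUPPORTED_MODELS.keys())}"
        )
    return TOGETHER_SUPPORTED_MODELS[model_name]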

Together safety adapter (together.py):

@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-
+from llama_models.sku_list import resolve_model
 from together import Together
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
@@ -17,6 +17,27 @@ from llama_stack.distribution.request_headers import get_request_provider_data
 from .config import TogetherSafetyConfig
 
+SAFETY_SHIELD_TYPES = {
+    "Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
+}
+
+
+def shield_type_to_model_name(shield_type: str) -> str:
+    if shield_type == "llama_guard":
+        shield_type = "Llama-Guard-3-8B"
+
+    model = resolve_model(shield_type)
+    if (
+        model is None
+        or model.descriptor(shorten_default_variant=True) not in SAFETY_SHIELD_TYPES
+        or model.model_family is not ModelFamily.safety
+    ):
+        raise ValueError(
+            f"{shield_type} is not supported, please use one of {','.join(SAFETY_SHIELD_TYPES.keys())}"
+        )
+
+    return SAFETY_SHIELD_TYPES.get(model.descriptor(shorten_default_variant=True))
+
 
 class TogetherSafetyImpl(Safety):
     def __init__(self, config: TogetherSafetyConfig) -> None:
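To illustrate the new helper's behavior, a sketch of the expected outcomes (assuming llama_models resolves these descriptors as usual):

shield_type_to_model_name("llama_guard")       # alias -> "meta-llama/Meta-Llama-Guard-3-8B"
shield_type_to_model_name("Llama-Guard-3-8B")  # descriptor -> "meta-llama/Meta-Llama-Guard-3-8B"
shield_type_to_model_name("Llama3.1-8B-Instruct")  # raises ValueError: not a safety model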
@@ -28,8 +49,6 @@ class TogetherSafetyImpl(Safety):
     async def run_shield(
         self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
     ) -> RunShieldResponse:
-        if shield_type != "llama_guard":
-            raise ValueError(f"shield type {shield_type} is not supported")
-
         together_api_key = None
         provider_data = get_request_provider_data()
@@ -39,23 +58,25 @@ class TogetherSafetyImpl(Safety):
             )
             together_api_key = provider_data.together_api_key
 
+        model_name = shield_type_to_model_name(shield_type)
+
         # messages can have role assistant or user
         api_messages = []
         for message in messages:
             if message.role in (Role.user.value, Role.assistant.value):
                 api_messages.append({"role": message.role, "content": message.content})
 
-        violation = await get_safety_response(together_api_key, api_messages)
+        violation = await get_safety_response(
+            together_api_key, model_name, api_messages
+        )
         return RunShieldResponse(violation=violation)
 
 
 async def get_safety_response(
-    api_key: str, messages: List[Dict[str, str]]
+    api_key: str, model_name: str, messages: List[Dict[str, str]]
 ) -> Optional[SafetyViolation]:
     client = Together(api_key=api_key)
-    response = client.chat.completions.create(
-        messages=messages, model="meta-llama/Meta-Llama-Guard-3-8B"
-    )
+    response = client.chat.completions.create(messages=messages, model=model_name)
 
     if len(response.choices) == 0:
         return None
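Putting it together, the refactored flow looks roughly like this (a sketch, run from an async context; the together_api_key is assumed to arrive via request provider data, and UserMessage comes from the llama_models datatypes wildcard import):

impl = TogetherSafetyImpl(TogetherSafetyConfig())
response = await impl.run_shield(
    shield_type="llama_guard",  # resolved by the helper to "meta-llama/Meta-Llama-Guard-3-8B"
    messages=[UserMessage(content="How do I bake bread?")],
)
print(response.violation)  # None for a benign prompt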