diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py
index 18c83aada..e9af6f17c 100644
--- a/llama_stack/providers/adapters/inference/together/together.py
+++ b/llama_stack/providers/adapters/inference/together/together.py
@@ -23,9 +23,15 @@ from llama_stack.distribution.request_headers import get_request_provider_data
 from .config import TogetherImplConfig
 
 TOGETHER_SUPPORTED_MODELS = {
-    "Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct-Turbo",
-    "Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct-Turbo",
-    "Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-Turbo",
+    "Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
+    "Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
+    "Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
+    "Llama3.2-1B-Instruct": "meta-llama/Meta-Llama-3.2-1B-Instruct-Turbo",
+    "Llama3.2-3B-Instruct": "meta-llama/Meta-Llama-3.2-3B-Instruct-Turbo",
+    "Llama3.2-11B-Vision": "meta-llama/Meta-Llama-3.2-11B-Vision-Turbo",
+    "Llama3.2-90B-Vision": "meta-llama/Meta-Llama-3.2-90B-Vision-Turbo",
+    "Llama3.2-11B-Vision-Instruct": "meta-llama/Meta-Llama-3.2-11B-Vision-Turbo",
+    "Llama3.2-90B-Vision-Instruct": "meta-llama/Meta-Llama-3.2-90B-Vision-Turbo",
 }
 
 
diff --git a/llama_stack/providers/adapters/safety/together/together.py b/llama_stack/providers/adapters/safety/together/together.py
index 15b6bb3a1..6c243a724 100644
--- a/llama_stack/providers/adapters/safety/together/together.py
+++ b/llama_stack/providers/adapters/safety/together/together.py
@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-
+from llama_models.sku_list import resolve_model
 from together import Together
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
@@ -17,6 +17,27 @@ from llama_stack.distribution.request_headers import get_request_provider_data
 
 from .config import TogetherSafetyConfig
 
+SAFETY_SHIELD_TYPES = {
+    "Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
+}
+
+
+def shield_type_to_model_name(shield_type: str) -> str:
+    if shield_type == "llama_guard":
+        shield_type = "Llama-Guard-3-8B"
+
+    model = resolve_model(shield_type)
+    if (
+        model is None
+        or not model.descriptor(shorten_default_variant=True) in SAFETY_SHIELD_TYPES
+        or model.model_family is not ModelFamily.safety
+    ):
+        raise ValueError(
+            f"{shield_type} is not supported, please use one of {','.join(SAFETY_SHIELD_TYPES.keys())}"
+        )
+
+    return SAFETY_SHIELD_TYPES.get(model.descriptor(shorten_default_variant=True))
+
 
 class TogetherSafetyImpl(Safety):
     def __init__(self, config: TogetherSafetyConfig) -> None:
@@ -28,8 +49,6 @@ class TogetherSafetyImpl(Safety):
     async def run_shield(
         self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
     ) -> RunShieldResponse:
-        if shield_type != "llama_guard":
-            raise ValueError(f"shield type {shield_type} is not supported")
 
         together_api_key = None
         provider_data = get_request_provider_data()
@@ -39,23 +58,25 @@ class TogetherSafetyImpl(Safety):
             )
         together_api_key = provider_data.together_api_key
 
+        model_name = shield_type_to_model_name(shield_type)
+
         # messages can have role assistant or user
         api_messages = []
         for message in messages:
             if message.role in (Role.user.value, Role.assistant.value):
                 api_messages.append({"role": message.role, "content": message.content})
 
-        violation = await get_safety_response(together_api_key, api_messages)
+        violation = await get_safety_response(
+            together_api_key, model_name, api_messages
+        )
         return RunShieldResponse(violation=violation)
 
 
 async def get_safety_response(
-    api_key: str, messages: List[Dict[str, str]]
+    api_key: str, model_name: str, messages: List[Dict[str, str]]
 ) -> Optional[SafetyViolation]:
     client = Together(api_key=api_key)
-    response = client.chat.completions.create(
-        messages=messages, model="meta-llama/Meta-Llama-Guard-3-8B"
-    )
+    response = client.chat.completions.create(messages=messages, model=model_name)
 
     if len(response.choices) == 0:
         return None
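With this change, `run_shield` no longer hard-codes the `llama_guard` shield type: the new `shield_type_to_model_name` helper resolves either the generic alias or the concrete model descriptor to a Together model id, and rejects everything else. A minimal usage sketch of the expected behaviour (not part of the diff, assuming `llama_models` is installed so `resolve_model` works as imported above):

```python
# Hypothetical usage of the helper added in the safety adapter above.
from llama_stack.providers.adapters.safety.together.together import (
    shield_type_to_model_name,
)

# The generic alias and the concrete descriptor resolve to the same Together model id.
assert shield_type_to_model_name("llama_guard") == "meta-llama/Meta-Llama-Guard-3-8B"
assert shield_type_to_model_name("Llama-Guard-3-8B") == "meta-llama/Meta-Llama-Guard-3-8B"

# Anything not listed in SAFETY_SHIELD_TYPES (or not a safety model) raises ValueError.
try:
    shield_type_to_model_name("Llama3.1-8B-Instruct")
except ValueError as err:
    print(err)
```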
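For reference, a rough sketch of what `get_safety_response` now sends to Together once the shield type has been resolved; the model id comes from `SAFETY_SHIELD_TYPES`, while the API key handling and the prompt below are placeholders for illustration, not values from the PR:

```python
import os

from together import Together

# Same call shape as get_safety_response above, with the resolved model name passed through.
client = Together(api_key=os.environ["TOGETHER_API_KEY"])
response = client.chat.completions.create(
    messages=[{"role": "user", "content": "Tell me how to pick a lock."}],
    model="meta-llama/Meta-Llama-Guard-3-8B",
)
if response.choices:
    # Llama Guard replies with "safe", or "unsafe" followed by the violated category.
    print(response.choices[0].message.content)
```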