From 940968ee3f2960bfc623ea95c9645101db8eeba1 Mon Sep 17 00:00:00 2001
From: Yogish Baliga
Date: Sat, 28 Sep 2024 15:45:38 -0700
Subject: [PATCH] =?UTF-8?q?fixing=20safety=20inference=20and=20safety=20ad?=
 =?UTF-8?q?apter=20for=20new=20API=20spec.=20Pinned=20t=E2=80=A6=20(#105)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* fixing safety inference and safety adapter for new API spec. Pinned the llama_models version to 0.0.24 as the latest version 0.0.35 has the model descriptor name changed. I was getting the missing package error during runtime as well, hence added the dependency to requirements.txt

* support Llama 3.2 models in Together inference adapter and cleanup Together safety adapter

* fixing model names

* adding vision guard to Together safety
---
 .../adapters/inference/together/__init__.py |  2 +-
 .../adapters/inference/together/config.py   | 11 +--
 .../adapters/inference/together/together.py | 24 +++++--
 .../adapters/safety/together/together.py    | 69 ++++++++++++-------
 llama_stack/providers/registry/inference.py |  2 +-
 5 files changed, 68 insertions(+), 40 deletions(-)

diff --git a/llama_stack/providers/adapters/inference/together/__init__.py b/llama_stack/providers/adapters/inference/together/__init__.py
index c964ddffb..05ea91e58 100644
--- a/llama_stack/providers/adapters/inference/together/__init__.py
+++ b/llama_stack/providers/adapters/inference/together/__init__.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .config import TogetherImplConfig, TogetherHeaderExtractor
+from .config import TogetherImplConfig
 
 
 async def get_adapter_impl(config: TogetherImplConfig, _deps):
diff --git a/llama_stack/providers/adapters/inference/together/config.py b/llama_stack/providers/adapters/inference/together/config.py
index c58f722bc..03ee047d2 100644
--- a/llama_stack/providers/adapters/inference/together/config.py
+++ b/llama_stack/providers/adapters/inference/together/config.py
@@ -4,17 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from pydantic import BaseModel, Field
-
 from llama_models.schema_utils import json_schema_type
-
-from llama_stack.distribution.request_headers import annotate_header
-
-
-class TogetherHeaderExtractor(BaseModel):
-    api_key: annotate_header(
-        "X-LlamaStack-Together-ApiKey", str, "The API Key for the request"
-    )
+from pydantic import BaseModel, Field
 
 
 @json_schema_type
diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py
index cafca3fdf..a56b18d7d 100644
--- a/llama_stack/providers/adapters/inference/together/together.py
+++ b/llama_stack/providers/adapters/inference/together/together.py
@@ -18,13 +18,17 @@ from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
+from llama_stack.distribution.request_headers import get_request_provider_data
 
 from .config import TogetherImplConfig
 
 TOGETHER_SUPPORTED_MODELS = {
-    "Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct-Turbo",
-    "Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct-Turbo",
-    "Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-Turbo",
+    "Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
+    "Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
+    "Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
+    "Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
+    "Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
+    "Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
 }
 
 
@@ -97,6 +101,16 @@ class TogetherInferenceAdapter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+
+        together_api_key = None
+        provider_data = get_request_provider_data()
+        if provider_data is None or not provider_data.together_api_key:
+            raise ValueError(
+                'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": }'
+            )
+        together_api_key = provider_data.together_api_key
+
+        client = Together(api_key=together_api_key)
         # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = ChatCompletionRequest(
             model=model,
@@ -116,7 +130,7 @@ class TogetherInferenceAdapter(Inference):
 
         if not request.stream:
             # TODO: might need to add back an async here
-            r = self.client.chat.completions.create(
+            r = client.chat.completions.create(
                 model=together_model,
                 messages=self._messages_to_together_messages(messages),
                 stream=False,
@@ -151,7 +165,7 @@ class TogetherInferenceAdapter(Inference):
         ipython = False
         stop_reason = None
 
-        for chunk in self.client.chat.completions.create(
+        for chunk in client.chat.completions.create(
             model=together_model,
             messages=self._messages_to_together_messages(messages),
             stream=True,
diff --git a/llama_stack/providers/adapters/safety/together/together.py b/llama_stack/providers/adapters/safety/together/together.py
index 223377073..940d02861 100644
--- a/llama_stack/providers/adapters/safety/together/together.py
+++ b/llama_stack/providers/adapters/safety/together/together.py
@@ -3,12 +3,41 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-
+from llama_models.sku_list import resolve_model
 from together import Together
 
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.apis.safety import (
+    RunShieldResponse,
+    Safety,
+    SafetyViolation,
+    ViolationLevel,
+)
 from llama_stack.distribution.request_headers import get_request_provider_data
 
-from .config import TogetherProviderDataValidator, TogetherSafetyConfig
+from .config import TogetherSafetyConfig
+
+SAFETY_SHIELD_TYPES = {
+    "Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
+    "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
+}
+
+
+def shield_type_to_model_name(shield_type: str) -> str:
+    if shield_type == "llama_guard":
+        shield_type = "Llama-Guard-3-8B"
+
+    model = resolve_model(shield_type)
+    if (
+        model is None
+        or not model.descriptor(shorten_default_variant=True) in SAFETY_SHIELD_TYPES
+        or model.model_family is not ModelFamily.safety
+    ):
+        raise ValueError(
+            f"{shield_type} is not supported, please use one of {','.join(SAFETY_SHIELD_TYPES.keys())}"
+        )
+
+    return SAFETY_SHIELD_TYPES.get(model.descriptor(shorten_default_variant=True))
 
 
 class TogetherSafetyImpl(Safety):
@@ -21,24 +50,16 @@ class TogetherSafetyImpl(Safety):
     async def run_shield(
         self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
     ) -> RunShieldResponse:
-        if shield_type != "llama_guard":
-            raise ValueError(f"shield type {shield_type} is not supported")
-
-        provider_data = get_request_provider_data()
 
         together_api_key = None
-        if provider_data is not None:
-            if not isinstance(provider_data, TogetherProviderDataValidator):
-                raise ValueError(
-                    'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": }'
-                )
+        provider_data = get_request_provider_data()
+        if provider_data is None or not provider_data.together_api_key:
+            raise ValueError(
+                'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": }'
+            )
+        together_api_key = provider_data.together_api_key
 
-            together_api_key = provider_data.together_api_key
-        if not together_api_key:
-            together_api_key = self.config.api_key
-
-        if not together_api_key:
-            raise ValueError("The API key must be provider in the header or config")
+        model_name = shield_type_to_model_name(shield_type)
 
         # messages can have role assistant or user
         api_messages = []
@@ -46,23 +67,25 @@ class TogetherSafetyImpl(Safety):
             if message.role in (Role.user.value, Role.assistant.value):
                 api_messages.append({"role": message.role, "content": message.content})
 
-        violation = await get_safety_response(together_api_key, api_messages)
+        violation = await get_safety_response(
+            together_api_key, model_name, api_messages
+        )
         return RunShieldResponse(violation=violation)
 
 
 async def get_safety_response(
-    api_key: str, messages: List[Dict[str, str]]
+    api_key: str, model_name: str, messages: List[Dict[str, str]]
 ) -> Optional[SafetyViolation]:
     client = Together(api_key=api_key)
-    response = client.chat.completions.create(
-        messages=messages, model="meta-llama/Meta-Llama-Guard-3-8B"
-    )
+    response = client.chat.completions.create(messages=messages, model=model_name)
     if len(response.choices) == 0:
         return None
 
     response_text = response.choices[0].message.content
     if response_text == "safe":
-        return None
+        return SafetyViolation(
+            violation_level=ViolationLevel.INFO, user_message="safe", metadata={}
+        )
 
     parts = response_text.split("\n")
     if len(parts) != 2:
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index 31b3e2c2d..9e7ed90f7 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -91,7 +91,7 @@ def available_providers() -> List[ProviderSpec]:
             ],
             module="llama_stack.providers.adapters.inference.together",
             config_class="llama_stack.providers.adapters.inference.together.TogetherImplConfig",
-            header_extractor_class="llama_stack.providers.adapters.inference.together.TogetherHeaderExtractor",
+            provider_data_validator="llama_stack.providers.adapters.safety.together.TogetherProviderDataValidator",
         ),
     ),
 ]
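
Usage note (illustrative sketch, not part of the patch): both adapters now read the Together API key per request via get_request_provider_data() instead of a config value or the removed TogetherHeaderExtractor. The snippet below shows how a client might supply that key when calling the safety endpoint; the header name and the "together_api_key" field come from the error message in the diff, while the base URL, port, and route are assumptions about a locally running Llama Stack server.

import json

import requests  # assumes the requests package is installed

# Provider credentials travel in the X-LlamaStack-ProviderData header as JSON.
headers = {
    "X-LlamaStack-ProviderData": json.dumps(
        {"together_api_key": "YOUR_TOGETHER_API_KEY"}  # placeholder, not a real key
    )
}

# "llama_guard" is resolved to the Llama-Guard-3-8B Together model by
# shield_type_to_model_name() in the safety adapter above.
payload = {
    "shield_type": "llama_guard",
    "messages": [{"role": "user", "content": "Tell me a story about a dragon."}],
}

# Assumed local endpoint; adjust host, port, and route to your deployment.
response = requests.post(
    "http://localhost:5000/safety/run_shield",
    headers=headers,
    json=payload,
)
print(response.json())

The inference adapter reads the same header before creating its Together client, so the identical headers dict can accompany chat completion requests as well.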