Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 02:53:30 +00:00)
fixing safety inference and safety adapter for new API spec. Pinned t… (#105)
* fixing safety inference and safety adapter for new API spec. Pinned the llama_models version to 0.0.24 as the latest version 0.0.35 has the model descriptor name changed. I was getting the missing package error during runtime as well, hence added the dependency to requirements.txt
* support Llama 3.2 models in Together inference adapter and cleanup Together safety adapter
* fixing model names
* adding vision guard to Together safety
parent 0a3999a9a4
commit 940968ee3f
5 changed files with 68 additions and 40 deletions
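With this change both Together adapters read the API key from the incoming request instead of adapter config. A minimal client-side sketch, assuming a locally running Llama Stack server and a requests-based caller; the header name and the together_api_key field come from the error message added in this diff, while the endpoint path and request body are illustrative assumptions:

import json

import requests

headers = {
    "Content-Type": "application/json",
    # Provider credentials now travel per-request in this header.
    "X-LlamaStack-ProviderData": json.dumps({"together_api_key": "<your api key>"}),
}

# Assumed local endpoint and payload shape, shown only to illustrate the header usage.
resp = requests.post(
    "http://localhost:5000/inference/chat_completion",
    headers=headers,
    json={
        "model": "Llama3.2-3B-Instruct",  # one of the newly supported Together models
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": False,
    },
)
print(resp.json())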
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from .config import TogetherImplConfig, TogetherHeaderExtractor
+from .config import TogetherImplConfig
 
 
 async def get_adapter_impl(config: TogetherImplConfig, _deps):
@@ -4,17 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from pydantic import BaseModel, Field
-
 from llama_models.schema_utils import json_schema_type
-
-from llama_stack.distribution.request_headers import annotate_header
-
-
-class TogetherHeaderExtractor(BaseModel):
-    api_key: annotate_header(
-        "X-LlamaStack-Together-ApiKey", str, "The API Key for the request"
-    )
+from pydantic import BaseModel, Field
 
 
 @json_schema_type
@@ -18,13 +18,17 @@ from llama_stack.apis.inference import * # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
+from llama_stack.distribution.request_headers import get_request_provider_data
 
 from .config import TogetherImplConfig
 
 TOGETHER_SUPPORTED_MODELS = {
-    "Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct-Turbo",
-    "Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct-Turbo",
-    "Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-Turbo",
+    "Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
+    "Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
+    "Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
+    "Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
+    "Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
+    "Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
 }
 
 
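For reference, a minimal sketch of how an entry in the map above resolves a Llama Stack descriptor to the Together model id used as together_model later in this file; resolve_together_model is a hypothetical helper, not part of this commit:

# Hypothetical lookup helper; the adapter performs an equivalent mapping before
# calling the Together client.
TOGETHER_SUPPORTED_MODELS = {
    "Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
    "Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
    # ... remaining entries as in the hunk above
}


def resolve_together_model(model: str) -> str:
    if model not in TOGETHER_SUPPORTED_MODELS:
        raise ValueError(f"{model} is not supported by the Together adapter")
    return TOGETHER_SUPPORTED_MODELS[model]


print(resolve_together_model("Llama3.2-11B-Vision-Instruct"))
# -> meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo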
@@ -97,6 +101,16 @@ class TogetherInferenceAdapter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+
+        together_api_key = None
+        provider_data = get_request_provider_data()
+        if provider_data is None or not provider_data.together_api_key:
+            raise ValueError(
+                'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
+            )
+        together_api_key = provider_data.together_api_key
+
+        client = Together(api_key=together_api_key)
         # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = ChatCompletionRequest(
             model=model,
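The provider_data object consumed above comes from get_request_provider_data(), which the stack populates from the X-LlamaStack-ProviderData header. A hedged sketch of what the validator presumably looks like; the class name and field name are taken from this diff, while the exact definition is an assumption:

from pydantic import BaseModel


# Assumed shape of the per-request provider data for the Together adapters.
class TogetherProviderDataValidator(BaseModel):
    together_api_key: str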
@@ -116,7 +130,7 @@ class TogetherInferenceAdapter(Inference):
 
         if not request.stream:
             # TODO: might need to add back an async here
-            r = self.client.chat.completions.create(
+            r = client.chat.completions.create(
                 model=together_model,
                 messages=self._messages_to_together_messages(messages),
                 stream=False,
@@ -151,7 +165,7 @@ class TogetherInferenceAdapter(Inference):
         ipython = False
         stop_reason = None
 
-        for chunk in self.client.chat.completions.create(
+        for chunk in client.chat.completions.create(
             model=together_model,
             messages=self._messages_to_together_messages(messages),
             stream=True,
@@ -3,12 +3,41 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from llama_models.sku_list import resolve_model
 from together import Together
 
+from llama_models.llama3.api.datatypes import * # noqa: F403
+from llama_stack.apis.safety import (
+    RunShieldResponse,
+    Safety,
+    SafetyViolation,
+    ViolationLevel,
+)
 from llama_stack.distribution.request_headers import get_request_provider_data
 
-from .config import TogetherProviderDataValidator, TogetherSafetyConfig
+from .config import TogetherSafetyConfig
+
+SAFETY_SHIELD_TYPES = {
+    "Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
+    "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
+}
+
+
+def shield_type_to_model_name(shield_type: str) -> str:
+    if shield_type == "llama_guard":
+        shield_type = "Llama-Guard-3-8B"
+
+    model = resolve_model(shield_type)
+    if (
+        model is None
+        or not model.descriptor(shorten_default_variant=True) in SAFETY_SHIELD_TYPES
+        or model.model_family is not ModelFamily.safety
+    ):
+        raise ValueError(
+            f"{shield_type} is not supported, please use of {','.join(SAFETY_SHIELD_TYPES.keys())}"
+        )
+
+    return SAFETY_SHIELD_TYPES.get(model.descriptor(shorten_default_variant=True))
 
 
 class TogetherSafetyImpl(Safety):
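A quick usage sketch of the new shield-type resolution above, assuming the helper is importable from the safety adapter module and that resolve_model in the pinned llama_models 0.0.24 recognizes the Llama Guard descriptors:

# Assumed import path, shown for illustration only.
from llama_stack.providers.adapters.safety.together.together import (
    shield_type_to_model_name,
)

# "llama_guard" is aliased to the 8B guard model before lookup.
print(shield_type_to_model_name("llama_guard"))
# expected: meta-llama/Meta-Llama-Guard-3-8B

# The vision guard added in this commit maps to Together's Turbo variant.
print(shield_type_to_model_name("Llama-Guard-3-11B-Vision"))
# expected: meta-llama/Llama-Guard-3-11B-Vision-Turbo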
@@ -21,24 +50,16 @@ class TogetherSafetyImpl(Safety):
     async def run_shield(
         self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
     ) -> RunShieldResponse:
-        if shield_type != "llama_guard":
-            raise ValueError(f"shield type {shield_type} is not supported")
-
-        provider_data = get_request_provider_data()
-
         together_api_key = None
-        if provider_data is not None:
-            if not isinstance(provider_data, TogetherProviderDataValidator):
-                raise ValueError(
-                    'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
-                )
+        provider_data = get_request_provider_data()
+        if provider_data is None or not provider_data.together_api_key:
+            raise ValueError(
+                'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
+            )
 
-            together_api_key = provider_data.together_api_key
-        if not together_api_key:
-            together_api_key = self.config.api_key
+        together_api_key = provider_data.together_api_key
 
-        if not together_api_key:
-            raise ValueError("The API key must be provider in the header or config")
+        model_name = shield_type_to_model_name(shield_type)
 
         # messages can have role assistant or user
         api_messages = []
@@ -46,23 +67,25 @@ class TogetherSafetyImpl(Safety):
             if message.role in (Role.user.value, Role.assistant.value):
                 api_messages.append({"role": message.role, "content": message.content})
 
-        violation = await get_safety_response(together_api_key, api_messages)
+        violation = await get_safety_response(
+            together_api_key, model_name, api_messages
+        )
         return RunShieldResponse(violation=violation)
 
 
 async def get_safety_response(
-    api_key: str, messages: List[Dict[str, str]]
+    api_key: str, model_name: str, messages: List[Dict[str, str]]
 ) -> Optional[SafetyViolation]:
     client = Together(api_key=api_key)
-    response = client.chat.completions.create(
-        messages=messages, model="meta-llama/Meta-Llama-Guard-3-8B"
-    )
+    response = client.chat.completions.create(messages=messages, model=model_name)
     if len(response.choices) == 0:
         return None
 
     response_text = response.choices[0].message.content
     if response_text == "safe":
-        return None
+        return SafetyViolation(
+            violation_level=ViolationLevel.INFO, user_message="safe", metadata={}
+        )
 
     parts = response_text.split("\n")
     if len(parts) != 2:
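The truncated context above continues into the pre-existing Llama Guard response parsing: guard models typically answer either "safe" or "unsafe" followed by a category code on a second line, which is why the code splits on a newline and expects two parts. A small illustration under that assumption:

# Example response shape the parsing above expects (an assumption based on typical
# Llama Guard output, not taken verbatim from this diff).
response_text = "unsafe\nS2"

parts = response_text.split("\n")
if len(parts) == 2 and parts[0] == "unsafe":
    violated_category = parts[1]
    print(f"violation in category {violated_category}")  # violation in category S2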
@@ -91,7 +91,7 @@ def available_providers() -> List[ProviderSpec]:
             ],
             module="llama_stack.providers.adapters.inference.together",
             config_class="llama_stack.providers.adapters.inference.together.TogetherImplConfig",
-            header_extractor_class="llama_stack.providers.adapters.inference.together.TogetherHeaderExtractor",
+            provider_data_validator="llama_stack.providers.adapters.safety.together.TogetherProviderDataValidator",
         ),
     ),
 ]