mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-09 21:18:38 +00:00
Merge branch 'meta-llama:main' into main
This commit is contained in:
commit
c13b2f06af
88 changed files with 4367 additions and 784 deletions
17
llama_stack/providers/adapters/inference/bedrock/__init__.py
Normal file
17
llama_stack/providers/adapters/inference/bedrock/__init__.py
Normal file
|
@ -0,0 +1,17 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from .bedrock import BedrockInferenceAdapter
|
||||
from .config import BedrockConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: BedrockConfig, _deps):
|
||||
assert isinstance(config, BedrockConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = BedrockInferenceAdapter(config)
|
||||
|
||||
await impl.initialize()
|
||||
|
||||
return impl
|
457
llama_stack/providers/adapters/inference/bedrock/bedrock.py
Normal file
457
llama_stack/providers/adapters/inference/bedrock/bedrock.py
Normal file
|
@ -0,0 +1,457 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import * # noqa: F403
|
||||
|
||||
import boto3
|
||||
from botocore.client import BaseClient
|
||||
from botocore.config import Config
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
|
||||
|
||||
# mapping of Model SKUs to ollama models
|
||||
BEDROCK_SUPPORTED_MODELS = {
|
||||
"Meta-Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
|
||||
"Meta-Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
|
||||
"Meta-Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
|
||||
}
|
||||
|
||||
|
||||
class BedrockInferenceAdapter(Inference):
|
||||
|
||||
@staticmethod
|
||||
def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
|
||||
retries_config = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
total_max_attempts=config.total_max_attempts,
|
||||
mode=config.retry_mode,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
config_args = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
region_name=config.region_name,
|
||||
retries=retries_config if retries_config else None,
|
||||
connect_timeout=config.connect_timeout,
|
||||
read_timeout=config.read_timeout,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
boto3_config = Config(**config_args)
|
||||
|
||||
session_args = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
aws_access_key_id=config.aws_access_key_id,
|
||||
aws_secret_access_key=config.aws_secret_access_key,
|
||||
aws_session_token=config.aws_session_token,
|
||||
region_name=config.region_name,
|
||||
profile_name=config.profile_name,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
boto3_session = boto3.session.Session(**session_args)
|
||||
|
||||
return boto3_session.client("bedrock-runtime", config=boto3_config)
|
||||
|
||||
def __init__(self, config: BedrockConfig) -> None:
|
||||
self._config = config
|
||||
|
||||
self._client = BedrockInferenceAdapter._create_bedrock_client(config)
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
|
||||
@property
|
||||
def client(self) -> BaseClient:
|
||||
return self._client
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.client.close()
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def resolve_bedrock_model(model_name: str) -> str:
|
||||
model = resolve_model(model_name)
|
||||
assert (
|
||||
model is not None
|
||||
and model.descriptor(shorten_default_variant=True)
|
||||
in BEDROCK_SUPPORTED_MODELS
|
||||
), (
|
||||
f"Unsupported model: {model_name}, use one of the supported models: "
|
||||
f"{','.join(BEDROCK_SUPPORTED_MODELS.keys())}"
|
||||
)
|
||||
|
||||
return BEDROCK_SUPPORTED_MODELS.get(
|
||||
model.descriptor(shorten_default_variant=True)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason:
|
||||
if bedrock_stop_reason == "max_tokens":
|
||||
return StopReason.out_of_tokens
|
||||
return StopReason.end_of_turn
|
||||
|
||||
@staticmethod
|
||||
def _builtin_tool_name_to_enum(tool_name_str: str) -> Union[BuiltinTool, str]:
|
||||
for builtin_tool in BuiltinTool:
|
||||
if builtin_tool.value == tool_name_str:
|
||||
return builtin_tool
|
||||
else:
|
||||
return tool_name_str
|
||||
|
||||
@staticmethod
|
||||
def _bedrock_message_to_message(converse_api_res: Dict) -> Message:
|
||||
stop_reason = BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
|
||||
converse_api_res["stopReason"]
|
||||
)
|
||||
|
||||
bedrock_message = converse_api_res["output"]["message"]
|
||||
|
||||
role = bedrock_message["role"]
|
||||
contents = bedrock_message["content"]
|
||||
|
||||
tool_calls = []
|
||||
text_content = []
|
||||
for content in contents:
|
||||
if "toolUse" in content:
|
||||
tool_use = content["toolUse"]
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
tool_name=BedrockInferenceAdapter._builtin_tool_name_to_enum(
|
||||
tool_use["name"]
|
||||
),
|
||||
arguments=tool_use["input"] if "input" in tool_use else None,
|
||||
call_id=tool_use["toolUseId"],
|
||||
)
|
||||
)
|
||||
elif "text" in content:
|
||||
text_content.append(content["text"])
|
||||
|
||||
return CompletionMessage(
|
||||
role=role,
|
||||
content=text_content,
|
||||
stop_reason=stop_reason,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _messages_to_bedrock_messages(
|
||||
messages: List[Message],
|
||||
) -> Tuple[List[Dict], Optional[List[Dict]]]:
|
||||
bedrock_messages = []
|
||||
system_bedrock_messages = []
|
||||
|
||||
user_contents = []
|
||||
assistant_contents = None
|
||||
for message in messages:
|
||||
role = message.role
|
||||
content_list = (
|
||||
message.content
|
||||
if isinstance(message.content, list)
|
||||
else [message.content]
|
||||
)
|
||||
if role == "ipython" or role == "user":
|
||||
if not user_contents:
|
||||
user_contents = []
|
||||
|
||||
if role == "ipython":
|
||||
user_contents.extend(
|
||||
[
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": message.call_id,
|
||||
"content": [
|
||||
{"text": content} for content in content_list
|
||||
],
|
||||
}
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
user_contents.extend(
|
||||
[{"text": content} for content in content_list]
|
||||
)
|
||||
|
||||
if assistant_contents:
|
||||
bedrock_messages.append(
|
||||
{"role": "assistant", "content": assistant_contents}
|
||||
)
|
||||
assistant_contents = None
|
||||
elif role == "system":
|
||||
system_bedrock_messages.extend(
|
||||
[{"text": content} for content in content_list]
|
||||
)
|
||||
elif role == "assistant":
|
||||
if not assistant_contents:
|
||||
assistant_contents = []
|
||||
|
||||
assistant_contents.extend(
|
||||
[
|
||||
{
|
||||
"text": content,
|
||||
}
|
||||
for content in content_list
|
||||
]
|
||||
+ [
|
||||
{
|
||||
"toolUse": {
|
||||
"input": tool_call.arguments,
|
||||
"name": (
|
||||
tool_call.tool_name
|
||||
if isinstance(tool_call.tool_name, str)
|
||||
else tool_call.tool_name.value
|
||||
),
|
||||
"toolUseId": tool_call.call_id,
|
||||
}
|
||||
}
|
||||
for tool_call in message.tool_calls
|
||||
]
|
||||
)
|
||||
|
||||
if user_contents:
|
||||
bedrock_messages.append({"role": "user", "content": user_contents})
|
||||
user_contents = None
|
||||
else:
|
||||
# Unknown role
|
||||
pass
|
||||
|
||||
if user_contents:
|
||||
bedrock_messages.append({"role": "user", "content": user_contents})
|
||||
if assistant_contents:
|
||||
bedrock_messages.append(
|
||||
{"role": "assistant", "content": assistant_contents}
|
||||
)
|
||||
|
||||
if system_bedrock_messages:
|
||||
return bedrock_messages, system_bedrock_messages
|
||||
|
||||
return bedrock_messages, None
|
||||
|
||||
@staticmethod
|
||||
def get_bedrock_inference_config(sampling_params: Optional[SamplingParams]) -> Dict:
|
||||
inference_config = {}
|
||||
if sampling_params:
|
||||
param_mapping = {
|
||||
"max_tokens": "maxTokens",
|
||||
"temperature": "temperature",
|
||||
"top_p": "topP",
|
||||
}
|
||||
|
||||
for k, v in param_mapping.items():
|
||||
if getattr(sampling_params, k):
|
||||
inference_config[v] = getattr(sampling_params, k)
|
||||
|
||||
return inference_config
|
||||
|
||||
@staticmethod
|
||||
def _tool_parameters_to_input_schema(
|
||||
tool_parameters: Optional[Dict[str, ToolParamDefinition]]
|
||||
) -> Dict:
|
||||
input_schema = {"type": "object"}
|
||||
if not tool_parameters:
|
||||
return input_schema
|
||||
|
||||
json_properties = {}
|
||||
required = []
|
||||
for name, param in tool_parameters.items():
|
||||
json_property = {
|
||||
"type": param.param_type,
|
||||
}
|
||||
|
||||
if param.description:
|
||||
json_property["description"] = param.description
|
||||
if param.required:
|
||||
required.append(name)
|
||||
json_properties[name] = json_property
|
||||
|
||||
input_schema["properties"] = json_properties
|
||||
if required:
|
||||
input_schema["required"] = required
|
||||
return input_schema
|
||||
|
||||
@staticmethod
|
||||
def _tools_to_tool_config(
|
||||
tools: Optional[List[ToolDefinition]], tool_choice: Optional[ToolChoice]
|
||||
) -> Optional[Dict]:
|
||||
if not tools:
|
||||
return None
|
||||
|
||||
bedrock_tools = []
|
||||
for tool in tools:
|
||||
tool_name = (
|
||||
tool.tool_name
|
||||
if isinstance(tool.tool_name, str)
|
||||
else tool.tool_name.value
|
||||
)
|
||||
|
||||
tool_spec = {
|
||||
"toolSpec": {
|
||||
"name": tool_name,
|
||||
"inputSchema": {
|
||||
"json": BedrockInferenceAdapter._tool_parameters_to_input_schema(
|
||||
tool.parameters
|
||||
),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if tool.description:
|
||||
tool_spec["toolSpec"]["description"] = tool.description
|
||||
|
||||
bedrock_tools.append(tool_spec)
|
||||
tool_config = {
|
||||
"tools": bedrock_tools,
|
||||
}
|
||||
|
||||
if tool_choice:
|
||||
tool_config["toolChoice"] = (
|
||||
{"any": {}}
|
||||
if tool_choice.value == ToolChoice.required
|
||||
else {"auto": {}}
|
||||
)
|
||||
return tool_config
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
# zero-shot tool definitions as input to the model
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> (
|
||||
AsyncGenerator
|
||||
): # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
|
||||
bedrock_model = BedrockInferenceAdapter.resolve_bedrock_model(model)
|
||||
inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
|
||||
sampling_params
|
||||
)
|
||||
|
||||
tool_config = BedrockInferenceAdapter._tools_to_tool_config(tools, tool_choice)
|
||||
bedrock_messages, system_bedrock_messages = (
|
||||
BedrockInferenceAdapter._messages_to_bedrock_messages(messages)
|
||||
)
|
||||
|
||||
converse_api_params = {
|
||||
"modelId": bedrock_model,
|
||||
"messages": bedrock_messages,
|
||||
}
|
||||
if inference_config:
|
||||
converse_api_params["inferenceConfig"] = inference_config
|
||||
|
||||
# Tool use is not supported in streaming mode
|
||||
if tool_config and not stream:
|
||||
converse_api_params["toolConfig"] = tool_config
|
||||
if system_bedrock_messages:
|
||||
converse_api_params["system"] = system_bedrock_messages
|
||||
|
||||
if not stream:
|
||||
converse_api_res = self.client.converse(**converse_api_params)
|
||||
|
||||
output_message = BedrockInferenceAdapter._bedrock_message_to_message(
|
||||
converse_api_res
|
||||
)
|
||||
|
||||
yield ChatCompletionResponse(
|
||||
completion_message=output_message,
|
||||
logprobs=None,
|
||||
)
|
||||
else:
|
||||
converse_stream_api_res = self.client.converse_stream(**converse_api_params)
|
||||
event_stream = converse_stream_api_res["stream"]
|
||||
|
||||
for chunk in event_stream:
|
||||
if "messageStart" in chunk:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta="",
|
||||
)
|
||||
)
|
||||
elif "contentBlockStart" in chunk:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
content=ToolCall(
|
||||
tool_name=chunk["contentBlockStart"]["toolUse"][
|
||||
"name"
|
||||
],
|
||||
call_id=chunk["contentBlockStart"]["toolUse"][
|
||||
"toolUseId"
|
||||
],
|
||||
),
|
||||
parse_status=ToolCallParseStatus.started,
|
||||
),
|
||||
)
|
||||
)
|
||||
elif "contentBlockDelta" in chunk:
|
||||
if "text" in chunk["contentBlockDelta"]["delta"]:
|
||||
delta = chunk["contentBlockDelta"]["delta"]["text"]
|
||||
else:
|
||||
delta = ToolCallDelta(
|
||||
content=ToolCall(
|
||||
arguments=chunk["contentBlockDelta"]["delta"][
|
||||
"toolUse"
|
||||
]["input"]
|
||||
),
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=delta,
|
||||
)
|
||||
)
|
||||
elif "contentBlockStop" in chunk:
|
||||
# Ignored
|
||||
pass
|
||||
elif "messageStop" in chunk:
|
||||
stop_reason = (
|
||||
BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
|
||||
chunk["messageStop"]["stopReason"]
|
||||
)
|
||||
)
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
elif "metadata" in chunk:
|
||||
# Ignored
|
||||
pass
|
||||
else:
|
||||
# Ignored
|
||||
pass
|
55
llama_stack/providers/adapters/inference/bedrock/config.py
Normal file
55
llama_stack/providers/adapters/inference/bedrock/config.py
Normal file
|
@ -0,0 +1,55 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import * # noqa: F403
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BedrockConfig(BaseModel):
|
||||
aws_access_key_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
|
||||
)
|
||||
aws_secret_access_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY",
|
||||
)
|
||||
aws_session_token: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN",
|
||||
)
|
||||
region_name: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The default AWS Region to use, for example, us-west-1 or us-west-2."
|
||||
"Default use environment variable: AWS_DEFAULT_REGION",
|
||||
)
|
||||
profile_name: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The profile name that contains credentials to use."
|
||||
"Default use environment variable: AWS_PROFILE",
|
||||
)
|
||||
total_max_attempts: Optional[int] = Field(
|
||||
default=None,
|
||||
description="An integer representing the maximum number of attempts that will be made for a single request, "
|
||||
"including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS",
|
||||
)
|
||||
retry_mode: Optional[str] = Field(
|
||||
default=None,
|
||||
description="A string representing the type of retries Boto3 will perform."
|
||||
"Default use environment variable: AWS_RETRY_MODE",
|
||||
)
|
||||
connect_timeout: Optional[float] = Field(
|
||||
default=60,
|
||||
description="The time in seconds till a timeout exception is thrown when attempting to make a connection. "
|
||||
"The default is 60 seconds.",
|
||||
)
|
||||
read_timeout: Optional[float] = Field(
|
||||
default=60,
|
||||
description="The time in seconds till a timeout exception is thrown when attempting to read from a connection."
|
||||
"The default is 60 seconds.",
|
||||
)
|
|
@ -15,14 +15,16 @@ from llama_models.llama3.api.tokenizer import Tokenizer
|
|||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
|
||||
from llama_stack.providers.utils.inference.augment_messages import (
|
||||
augment_messages_for_tools,
|
||||
)
|
||||
|
||||
from .config import FireworksImplConfig
|
||||
|
||||
FIREWORKS_SUPPORTED_MODELS = {
|
||||
"Meta-Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
|
||||
"Meta-Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
|
||||
"Meta-Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct",
|
||||
"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
|
||||
"Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
|
||||
"Llama3.1-405B-Instruct": "fireworks/llama-v3p1-405b-instruct",
|
||||
}
|
||||
|
||||
|
||||
|
@ -106,7 +108,7 @@ class FireworksInferenceAdapter(Inference):
|
|||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
messages = prepare_messages(request)
|
||||
messages = augment_messages_for_tools(request)
|
||||
|
||||
# accumulate sampling params and other options to pass to fireworks
|
||||
options = self.get_fireworks_chat_options(request)
|
||||
|
|
|
@ -16,14 +16,16 @@ from llama_models.sku_list import resolve_model
|
|||
from ollama import AsyncClient
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
|
||||
from llama_stack.providers.utils.inference.augment_messages import (
|
||||
augment_messages_for_tools,
|
||||
)
|
||||
|
||||
# TODO: Eventually this will move to the llama cli model list command
|
||||
# mapping of Model SKUs to ollama models
|
||||
OLLAMA_SUPPORTED_SKUS = {
|
||||
# "Meta-Llama3.1-8B-Instruct": "llama3.1",
|
||||
"Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
|
||||
"Meta-Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
|
||||
# "Llama3.1-8B-Instruct": "llama3.1",
|
||||
"Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
|
||||
"Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
|
||||
}
|
||||
|
||||
|
||||
|
@ -115,7 +117,7 @@ class OllamaInferenceAdapter(Inference):
|
|||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
messages = prepare_messages(request)
|
||||
messages = augment_messages_for_tools(request)
|
||||
# accumulate sampling params and other options to pass to ollama
|
||||
options = self.get_ollama_chat_options(request)
|
||||
ollama_model = self.resolve_ollama_model(request.model)
|
||||
|
|
|
@ -4,21 +4,26 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import TGIImplConfig
|
||||
from .tgi import InferenceEndpointAdapter, TGIAdapter
|
||||
from typing import Union
|
||||
|
||||
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
||||
from .tgi import InferenceAPIAdapter, InferenceEndpointAdapter, TGIAdapter
|
||||
|
||||
|
||||
async def get_adapter_impl(config: TGIImplConfig, _deps):
|
||||
assert isinstance(config, TGIImplConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
if config.url is not None:
|
||||
impl = TGIAdapter(config)
|
||||
elif config.is_inference_endpoint():
|
||||
impl = InferenceEndpointAdapter(config)
|
||||
async def get_adapter_impl(
|
||||
config: Union[InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig],
|
||||
_deps,
|
||||
):
|
||||
if isinstance(config, TGIImplConfig):
|
||||
impl = TGIAdapter()
|
||||
elif isinstance(config, InferenceAPIImplConfig):
|
||||
impl = InferenceAPIAdapter()
|
||||
elif isinstance(config, InferenceEndpointImplConfig):
|
||||
impl = InferenceEndpointAdapter()
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid configuration. Specify either an URL or HF Inference Endpoint details (namespace and endpoint name)."
|
||||
f"Invalid configuration. Expected 'TGIAdapter', 'InferenceAPIImplConfig' or 'InferenceEndpointImplConfig'. Got {type(config)}."
|
||||
)
|
||||
|
||||
await impl.initialize()
|
||||
await impl.initialize(config)
|
||||
return impl
|
||||
|
|
|
@ -12,18 +12,32 @@ from pydantic import BaseModel, Field
|
|||
|
||||
@json_schema_type
|
||||
class TGIImplConfig(BaseModel):
|
||||
url: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The URL for the local TGI endpoint (e.g., http://localhost:8080)",
|
||||
url: str = Field(
|
||||
description="The URL for the TGI endpoint (e.g. 'http://localhost:8080')",
|
||||
)
|
||||
api_token: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The HF token for Hugging Face Inference Endpoints (will default to locally saved token if not provided)",
|
||||
)
|
||||
hf_endpoint_name: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The name of the Hugging Face Inference Endpoint : can be either in the format of '{namespace}/{endpoint_name}' (namespace can be the username or organization name) or just '{endpoint_name}' if logged into the same account as the namespace",
|
||||
description="A bearer token if your TGI endpoint is protected.",
|
||||
)
|
||||
|
||||
def is_inference_endpoint(self) -> bool:
|
||||
return self.hf_endpoint_name is not None
|
||||
|
||||
@json_schema_type
|
||||
class InferenceEndpointImplConfig(BaseModel):
|
||||
endpoint_name: str = Field(
|
||||
description="The name of the Hugging Face Inference Endpoint in the format of '{namespace}/{endpoint_name}' (e.g. 'my-cool-org/meta-llama-3-1-8b-instruct-rce'). Namespace is optional and will default to the user account if not provided.",
|
||||
)
|
||||
api_token: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class InferenceAPIImplConfig(BaseModel):
|
||||
model_id: str = Field(
|
||||
description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
|
||||
)
|
||||
api_token: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
|
||||
)
|
||||
|
|
|
@ -5,52 +5,33 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict
|
||||
import logging
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import requests
|
||||
|
||||
from huggingface_hub import HfApi, InferenceClient
|
||||
from huggingface_hub import AsyncInferenceClient, HfApi
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.datatypes import StopReason
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
|
||||
from llama_stack.providers.utils.inference.augment_messages import (
|
||||
augment_messages_for_tools,
|
||||
)
|
||||
|
||||
from .config import TGIImplConfig
|
||||
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TGIAdapter(Inference):
|
||||
def __init__(self, config: TGIImplConfig) -> None:
|
||||
self.config = config
|
||||
class _HfAdapter(Inference):
|
||||
client: AsyncInferenceClient
|
||||
max_tokens: int
|
||||
model_id: str
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(self.tokenizer)
|
||||
|
||||
@property
|
||||
def client(self) -> InferenceClient:
|
||||
return InferenceClient(model=self.config.url, token=self.config.api_token)
|
||||
|
||||
def _get_endpoint_info(self) -> Dict[str, Any]:
|
||||
return {
|
||||
**self.client.get_endpoint_info(),
|
||||
"inference_url": self.config.url,
|
||||
}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
try:
|
||||
info = self._get_endpoint_info()
|
||||
if "model_id" not in info:
|
||||
raise RuntimeError("Missing model_id in model info")
|
||||
if "max_total_tokens" not in info:
|
||||
raise RuntimeError("Missing max_total_tokens in model info")
|
||||
self.max_tokens = info["max_total_tokens"]
|
||||
|
||||
self.inference_url = info["inference_url"]
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise RuntimeError(f"Error initializing TGIAdapter: {e}") from e
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
|
@ -95,7 +76,7 @@ class TGIAdapter(Inference):
|
|||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
messages = prepare_messages(request)
|
||||
messages = augment_messages_for_tools(request)
|
||||
model_input = self.formatter.encode_dialog_prompt(messages)
|
||||
prompt = self.tokenizer.decode(model_input.tokens)
|
||||
|
||||
|
@ -109,7 +90,7 @@ class TGIAdapter(Inference):
|
|||
|
||||
options = self.get_chat_options(request)
|
||||
if not request.stream:
|
||||
response = self.client.text_generation(
|
||||
response = await self.client.text_generation(
|
||||
prompt=prompt,
|
||||
stream=False,
|
||||
details=True,
|
||||
|
@ -145,7 +126,7 @@ class TGIAdapter(Inference):
|
|||
stop_reason = None
|
||||
tokens = []
|
||||
|
||||
for response in self.client.text_generation(
|
||||
async for response in await self.client.text_generation(
|
||||
prompt=prompt,
|
||||
stream=True,
|
||||
details=True,
|
||||
|
@ -237,46 +218,36 @@ class TGIAdapter(Inference):
|
|||
)
|
||||
|
||||
|
||||
class InferenceEndpointAdapter(TGIAdapter):
|
||||
def __init__(self, config: TGIImplConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.config.url = self._construct_endpoint_url()
|
||||
class TGIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: TGIImplConfig) -> None:
|
||||
self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
|
||||
endpoint_info = await self.client.get_endpoint_info()
|
||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||
self.model_id = endpoint_info["model_id"]
|
||||
|
||||
def _construct_endpoint_url(self) -> str:
|
||||
hf_endpoint_name = self.config.hf_endpoint_name
|
||||
assert hf_endpoint_name.count("/") <= 1, (
|
||||
"Endpoint name must be in the format of 'namespace/endpoint_name' "
|
||||
"or 'endpoint_name'"
|
||||
|
||||
class InferenceAPIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: InferenceAPIImplConfig) -> None:
|
||||
self.client = AsyncInferenceClient(
|
||||
model=config.model_id, token=config.api_token
|
||||
)
|
||||
if "/" not in hf_endpoint_name:
|
||||
hf_namespace: str = self.get_namespace()
|
||||
endpoint_path = f"{hf_namespace}/{hf_endpoint_name}"
|
||||
else:
|
||||
endpoint_path = hf_endpoint_name
|
||||
return f"https://api.endpoints.huggingface.cloud/v2/endpoint/{endpoint_path}"
|
||||
endpoint_info = await self.client.get_endpoint_info()
|
||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||
self.model_id = endpoint_info["model_id"]
|
||||
|
||||
def get_namespace(self) -> str:
|
||||
return HfApi().whoami()["name"]
|
||||
|
||||
@property
|
||||
def client(self) -> InferenceClient:
|
||||
return InferenceClient(model=self.inference_url, token=self.config.api_token)
|
||||
class InferenceEndpointAdapter(_HfAdapter):
|
||||
async def initialize(self, config: InferenceEndpointImplConfig) -> None:
|
||||
# Get the inference endpoint details
|
||||
api = HfApi(token=config.api_token)
|
||||
endpoint = api.get_inference_endpoint(config.endpoint_name)
|
||||
|
||||
def _get_endpoint_info(self) -> Dict[str, Any]:
|
||||
headers = {
|
||||
"accept": "application/json",
|
||||
"authorization": f"Bearer {self.config.api_token}",
|
||||
}
|
||||
response = requests.get(self.config.url, headers=headers)
|
||||
response.raise_for_status()
|
||||
endpoint_info = response.json()
|
||||
return {
|
||||
"inference_url": endpoint_info["status"]["url"],
|
||||
"model_id": endpoint_info["model"]["repository"],
|
||||
"max_total_tokens": int(
|
||||
endpoint_info["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"]
|
||||
),
|
||||
}
|
||||
# Wait for the endpoint to be ready (if not already)
|
||||
endpoint.wait(timeout=60)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await super().initialize()
|
||||
# Initialize the adapter
|
||||
self.client = endpoint.async_client
|
||||
self.model_id = endpoint.repository
|
||||
self.max_tokens = int(
|
||||
endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"]
|
||||
)
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import TogetherImplConfig, TogetherHeaderExtractor
|
||||
from .config import TogetherImplConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: TogetherImplConfig, _deps):
|
||||
|
|
|
@ -4,17 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
|
||||
from llama_stack.distribution.request_headers import annotate_header
|
||||
|
||||
|
||||
class TogetherHeaderExtractor(BaseModel):
|
||||
api_key: annotate_header(
|
||||
"X-LlamaStack-Together-ApiKey", str, "The API Key for the request"
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -15,14 +15,20 @@ from llama_models.sku_list import resolve_model
|
|||
from together import Together
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
|
||||
from llama_stack.distribution.request_headers import get_request_provider_data
|
||||
from llama_stack.providers.utils.inference.augment_messages import (
|
||||
augment_messages_for_tools,
|
||||
)
|
||||
|
||||
from .config import TogetherImplConfig
|
||||
|
||||
TOGETHER_SUPPORTED_MODELS = {
|
||||
"Meta-Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
|
||||
"Meta-Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
|
||||
"Meta-Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
|
||||
"Llama3.1-8B-Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
|
||||
"Llama3.1-70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
|
||||
"Llama3.1-405B-Instruct": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
|
||||
"Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
|
||||
"Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
|
||||
"Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
|
||||
}
|
||||
|
||||
|
||||
|
@ -95,6 +101,16 @@ class TogetherInferenceAdapter(Inference):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
|
||||
together_api_key = None
|
||||
provider_data = get_request_provider_data()
|
||||
if provider_data is None or not provider_data.together_api_key:
|
||||
raise ValueError(
|
||||
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
|
||||
)
|
||||
together_api_key = provider_data.together_api_key
|
||||
|
||||
client = Together(api_key=together_api_key)
|
||||
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
|
@ -110,11 +126,11 @@ class TogetherInferenceAdapter(Inference):
|
|||
# accumulate sampling params and other options to pass to together
|
||||
options = self.get_together_chat_options(request)
|
||||
together_model = self.resolve_together_model(request.model)
|
||||
messages = prepare_messages(request)
|
||||
messages = augment_messages_for_tools(request)
|
||||
|
||||
if not request.stream:
|
||||
# TODO: might need to add back an async here
|
||||
r = self.client.chat.completions.create(
|
||||
r = client.chat.completions.create(
|
||||
model=together_model,
|
||||
messages=self._messages_to_together_messages(messages),
|
||||
stream=False,
|
||||
|
@ -149,7 +165,7 @@ class TogetherInferenceAdapter(Inference):
|
|||
ipython = False
|
||||
stop_reason = None
|
||||
|
||||
for chunk in self.client.chat.completions.create(
|
||||
for chunk in client.chat.completions.create(
|
||||
model=together_model,
|
||||
messages=self._messages_to_together_messages(messages),
|
||||
stream=True,
|
||||
|
|
18
llama_stack/providers/adapters/safety/bedrock/__init__.py
Normal file
18
llama_stack/providers/adapters/safety/bedrock/__init__.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import BedrockSafetyConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: BedrockSafetyConfig, _deps) -> Any:
|
||||
from .bedrock import BedrockSafetyAdapter
|
||||
|
||||
impl = BedrockSafetyAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
109
llama_stack/providers/adapters/safety/bedrock/bedrock.py
Normal file
109
llama_stack/providers/adapters/safety/bedrock/bedrock.py
Normal file
|
@ -0,0 +1,109 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import traceback
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from .config import BedrockSafetyConfig
|
||||
from llama_stack.apis.safety import * # noqa
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
import json
|
||||
import logging
|
||||
|
||||
import boto3
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BedrockSafetyAdapter(Safety):
|
||||
def __init__(self, config: BedrockSafetyConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
if not self.config.aws_profile:
|
||||
raise RuntimeError(
|
||||
f"Missing boto_client aws_profile in model info::{self.config}"
|
||||
)
|
||||
|
||||
try:
|
||||
print(f"initializing with profile --- > {self.config}::")
|
||||
self.boto_client_profile = self.config.aws_profile
|
||||
self.boto_client = boto3.Session(
|
||||
profile_name=self.boto_client_profile
|
||||
).client("bedrock-runtime")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error initializing BedrockSafetyAdapter: {e}") from e
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def run_shield(
|
||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
) -> RunShieldResponse:
|
||||
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
|
||||
```content = [
|
||||
{
|
||||
"text": {
|
||||
"text": "Is the AB503 Product a better investment than the S&P 500?"
|
||||
}
|
||||
}
|
||||
]```
|
||||
However the incoming messages are of this type UserMessage(content=....) coming from
|
||||
https://github.com/meta-llama/llama-models/blob/main/models/llama3/api/datatypes.py
|
||||
|
||||
They contain content, role . For now we will extract the content and default the "qualifiers": ["query"]
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"run_shield::{params}::messages={messages}")
|
||||
if "guardrailIdentifier" not in params:
|
||||
raise RuntimeError(
|
||||
"Error running request for BedrockGaurdrails:Missing GuardrailID in request"
|
||||
)
|
||||
|
||||
if "guardrailVersion" not in params:
|
||||
raise RuntimeError(
|
||||
"Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
|
||||
)
|
||||
|
||||
# - convert the messages into format Bedrock expects
|
||||
content_messages = []
|
||||
for message in messages:
|
||||
content_messages.append({"text": {"text": message.content}})
|
||||
logger.debug(
|
||||
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
|
||||
)
|
||||
|
||||
response = self.boto_client.apply_guardrail(
|
||||
guardrailIdentifier=params.get("guardrailIdentifier"),
|
||||
guardrailVersion=params.get("guardrailVersion"),
|
||||
source="OUTPUT", # or 'INPUT' depending on your use case
|
||||
content=content_messages,
|
||||
)
|
||||
logger.debug(f"run_shield:: response: {response}::")
|
||||
if response["action"] == "GUARDRAIL_INTERVENED":
|
||||
user_message = ""
|
||||
metadata = {}
|
||||
for output in response["outputs"]:
|
||||
# guardrails returns a list - however for this implementation we will leverage the last values
|
||||
user_message = output["text"]
|
||||
for assessment in response["assessments"]:
|
||||
# guardrails returns a list - however for this implementation we will leverage the last values
|
||||
metadata = dict(assessment)
|
||||
return SafetyViolation(
|
||||
user_message=user_message,
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
error_str = traceback.format_exc()
|
||||
logger.error(
|
||||
f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!"
|
||||
)
|
||||
|
||||
return None
|
16
llama_stack/providers/adapters/safety/bedrock/config.py
Normal file
16
llama_stack/providers/adapters/safety/bedrock/config.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BedrockSafetyConfig(BaseModel):
|
||||
"""Configuration information for a guardrail that you want to use in the request."""
|
||||
|
||||
aws_profile: str = Field(
|
||||
default="default",
|
||||
description="The profile on the machine having valid aws credentials. This will ensure separation of creation to invocation",
|
||||
)
|
18
llama_stack/providers/adapters/safety/together/__init__.py
Normal file
18
llama_stack/providers/adapters/safety/together/__init__.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import TogetherProviderDataValidator, TogetherSafetyConfig # noqa: F401
|
||||
|
||||
|
||||
async def get_adapter_impl(config: TogetherSafetyConfig, _deps):
|
||||
from .together import TogetherSafetyImpl
|
||||
|
||||
assert isinstance(
|
||||
config, TogetherSafetyConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
impl = TogetherSafetyImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
26
llama_stack/providers/adapters/safety/together/config.py
Normal file
26
llama_stack/providers/adapters/safety/together/config.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TogetherProviderDataValidator(BaseModel):
|
||||
together_api_key: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TogetherSafetyConfig(BaseModel):
|
||||
url: str = Field(
|
||||
default="https://api.together.xyz/v1",
|
||||
description="The URL for the Together AI server",
|
||||
)
|
||||
api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The Together AI API Key (default for the distribution, if any)",
|
||||
)
|
99
llama_stack/providers/adapters/safety/together/together.py
Normal file
99
llama_stack/providers/adapters/safety/together/together.py
Normal file
|
@ -0,0 +1,99 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from llama_models.sku_list import resolve_model
|
||||
from together import Together
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
SafetyViolation,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import get_request_provider_data
|
||||
|
||||
from .config import TogetherSafetyConfig
|
||||
|
||||
SAFETY_SHIELD_TYPES = {
|
||||
"Llama-Guard-3-8B": "meta-llama/Meta-Llama-Guard-3-8B",
|
||||
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
|
||||
}
|
||||
|
||||
|
||||
def shield_type_to_model_name(shield_type: str) -> str:
|
||||
if shield_type == "llama_guard":
|
||||
shield_type = "Llama-Guard-3-8B"
|
||||
|
||||
model = resolve_model(shield_type)
|
||||
if (
|
||||
model is None
|
||||
or not model.descriptor(shorten_default_variant=True) in SAFETY_SHIELD_TYPES
|
||||
or model.model_family is not ModelFamily.safety
|
||||
):
|
||||
raise ValueError(
|
||||
f"{shield_type} is not supported, please use of {','.join(SAFETY_SHIELD_TYPES.keys())}"
|
||||
)
|
||||
|
||||
return SAFETY_SHIELD_TYPES.get(model.descriptor(shorten_default_variant=True))
|
||||
|
||||
|
||||
class TogetherSafetyImpl(Safety):
|
||||
def __init__(self, config: TogetherSafetyConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def run_shield(
|
||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
) -> RunShieldResponse:
|
||||
|
||||
together_api_key = None
|
||||
provider_data = get_request_provider_data()
|
||||
if provider_data is None or not provider_data.together_api_key:
|
||||
raise ValueError(
|
||||
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
|
||||
)
|
||||
together_api_key = provider_data.together_api_key
|
||||
|
||||
model_name = shield_type_to_model_name(shield_type)
|
||||
|
||||
# messages can have role assistant or user
|
||||
api_messages = []
|
||||
for message in messages:
|
||||
if message.role in (Role.user.value, Role.assistant.value):
|
||||
api_messages.append({"role": message.role, "content": message.content})
|
||||
|
||||
violation = await get_safety_response(
|
||||
together_api_key, model_name, api_messages
|
||||
)
|
||||
return RunShieldResponse(violation=violation)
|
||||
|
||||
|
||||
async def get_safety_response(
|
||||
api_key: str, model_name: str, messages: List[Dict[str, str]]
|
||||
) -> Optional[SafetyViolation]:
|
||||
client = Together(api_key=api_key)
|
||||
response = client.chat.completions.create(messages=messages, model=model_name)
|
||||
if len(response.choices) == 0:
|
||||
return None
|
||||
|
||||
response_text = response.choices[0].message.content
|
||||
if response_text == "safe":
|
||||
return None
|
||||
|
||||
parts = response_text.split("\n")
|
||||
if len(parts) != 2:
|
||||
return None
|
||||
|
||||
if parts[0] == "unsafe":
|
||||
return SafetyViolation(
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
user_message="unsafe",
|
||||
metadata={"violation_type": parts[1]},
|
||||
)
|
||||
|
||||
return None
|
Loading…
Add table
Add a link
Reference in a new issue