forked from phoenix-oss/llama-stack-mirror

Make bedrock "just" work

This commit is contained in:
parent d6fcdefec7
commit c39a3777b5

3 changed files with 75 additions and 325 deletions
@@ -35,6 +35,8 @@ The following models are available by default:
 - `meta-llama/Llama-3.1-8B-Instruct (meta.llama3-1-8b-instruct-v1:0)`
 - `meta-llama/Llama-3.1-70B-Instruct (meta.llama3-1-70b-instruct-v1:0)`
 - `meta-llama/Llama-3.1-405B-Instruct-FP8 (meta.llama3-1-405b-instruct-v1:0)`
+- `meta-llama/Llama-3.2-3B-Instruct (meta.llama3-2-3b-instruct-v1:0)`
+- `meta-llama/Llama-3.2-1B-Instruct (meta.llama3-2-1b-instruct-v1:0)`


 ### Prerequisite: API Keys
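The two new Llama 3.2 aliases above only resolve if the corresponding Bedrock models are enabled for your AWS account and region. A minimal sketch (not part of this commit, assuming boto3 credentials are already configured) for checking that:

```python
# Sketch only: verify the Bedrock model IDs listed above are visible to your
# account/region. Assumes AWS credentials are already configured for boto3.
import boto3

bedrock = boto3.client("bedrock", region_name="us-east-1")  # region is a placeholder
available = {m["modelId"] for m in bedrock.list_foundation_models()["modelSummaries"]}

for model_id in (
    "meta.llama3-1-8b-instruct-v1:0",
    "meta.llama3-2-3b-instruct-v1:0",
    "meta.llama3-2-1b-instruct-v1:0",
):
    print(model_id, "available" if model_id in available else "not enabled")
```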
@@ -6,20 +6,25 @@

 from typing import *  # noqa: F403
 import json
-import uuid

 from botocore.client import BaseClient
 from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat

-from llama_models.llama3.api.datatypes import ToolParamDefinition
 from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,
 )
+from llama_stack.providers.utils.inference.openai_compat import (
+    OpenAICompatCompletionChoice,
+    OpenAICompatCompletionResponse,
+    process_chat_completion_response,
+    process_chat_completion_stream_response,
+)
 from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
     content_has_media,
     interleaved_content_as_str,
 )
@@ -43,10 +48,17 @@ MODEL_ALIASES = [
         "meta.llama3-1-405b-instruct-v1:0",
         CoreModelId.llama3_1_405b_instruct.value,
     ),
+    build_model_alias(
+        "meta.llama3-2-3b-instruct-v1:0",
+        CoreModelId.llama3_2_3b_instruct.value,
+    ),
+    build_model_alias(
+        "meta.llama3-2-1b-instruct-v1:0",
+        CoreModelId.llama3_2_1b_instruct.value,
+    ),
 ]


-# NOTE: this is not quite tested after the recent refactors
 class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: BedrockConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ALIASES)
@@ -76,232 +88,6 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> AsyncGenerator:
         raise NotImplementedError()

-    @staticmethod
-    def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason:
-        if bedrock_stop_reason == "max_tokens":
-            return StopReason.out_of_tokens
-        return StopReason.end_of_turn
-
-    @staticmethod
-    def _builtin_tool_name_to_enum(tool_name_str: str) -> Union[BuiltinTool, str]:
-        for builtin_tool in BuiltinTool:
-            if builtin_tool.value == tool_name_str:
-                return builtin_tool
-        else:
-            return tool_name_str
-
-    @staticmethod
-    def _bedrock_message_to_message(converse_api_res: Dict) -> Message:
-        stop_reason = BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
-            converse_api_res["stopReason"]
-        )
-
-        bedrock_message = converse_api_res["output"]["message"]
-
-        role = bedrock_message["role"]
-        contents = bedrock_message["content"]
-
-        tool_calls = []
-        text_content = ""
-        for content in contents:
-            if "toolUse" in content:
-                tool_use = content["toolUse"]
-                tool_calls.append(
-                    ToolCall(
-                        tool_name=BedrockInferenceAdapter._builtin_tool_name_to_enum(
-                            tool_use["name"]
-                        ),
-                        arguments=tool_use["input"] if "input" in tool_use else None,
-                        call_id=tool_use["toolUseId"],
-                    )
-                )
-            elif "text" in content:
-                text_content += content["text"]
-
-        return CompletionMessage(
-            role=role,
-            content=text_content,
-            stop_reason=stop_reason,
-            tool_calls=tool_calls,
-        )
-
-    @staticmethod
-    def _messages_to_bedrock_messages(
-        messages: List[Message],
-    ) -> Tuple[List[Dict], Optional[List[Dict]]]:
-        bedrock_messages = []
-        system_bedrock_messages = []
-
-        user_contents = []
-        assistant_contents = None
-        for message in messages:
-            role = message.role
-            content_list = (
-                message.content
-                if isinstance(message.content, list)
-                else [message.content]
-            )
-            if role == "ipython" or role == "user":
-                if not user_contents:
-                    user_contents = []
-
-                if role == "ipython":
-                    user_contents.extend(
-                        [
-                            {
-                                "toolResult": {
-                                    "toolUseId": message.call_id or str(uuid.uuid4()),
-                                    "content": [
-                                        {"text": content} for content in content_list
-                                    ],
-                                }
-                            }
-                        ]
-                    )
-                else:
-                    user_contents.extend(
-                        [{"text": content} for content in content_list]
-                    )
-
-                if assistant_contents:
-                    bedrock_messages.append(
-                        {"role": "assistant", "content": assistant_contents}
-                    )
-                    assistant_contents = None
-            elif role == "system":
-                system_bedrock_messages.extend(
-                    [{"text": content} for content in content_list]
-                )
-            elif role == "assistant":
-                if not assistant_contents:
-                    assistant_contents = []
-
-                assistant_contents.extend(
-                    [
-                        {
-                            "text": content,
-                        }
-                        for content in content_list
-                    ]
-                    + [
-                        {
-                            "toolUse": {
-                                "input": tool_call.arguments,
-                                "name": (
-                                    tool_call.tool_name
-                                    if isinstance(tool_call.tool_name, str)
-                                    else tool_call.tool_name.value
-                                ),
-                                "toolUseId": tool_call.call_id,
-                            }
-                        }
-                        for tool_call in message.tool_calls
-                    ]
-                )
-
-                if user_contents:
-                    bedrock_messages.append({"role": "user", "content": user_contents})
-                    user_contents = None
-            else:
-                # Unknown role
-                pass
-
-        if user_contents:
-            bedrock_messages.append({"role": "user", "content": user_contents})
-        if assistant_contents:
-            bedrock_messages.append(
-                {"role": "assistant", "content": assistant_contents}
-            )
-
-        if system_bedrock_messages:
-            return bedrock_messages, system_bedrock_messages
-
-        return bedrock_messages, None
-
-    @staticmethod
-    def get_bedrock_inference_config(sampling_params: Optional[SamplingParams]) -> Dict:
-        inference_config = {}
-        if sampling_params:
-            param_mapping = {
-                "max_tokens": "maxTokens",
-                "temperature": "temperature",
-                "top_p": "topP",
-            }
-
-            for k, v in param_mapping.items():
-                if getattr(sampling_params, k):
-                    inference_config[v] = getattr(sampling_params, k)
-
-        return inference_config
-
-    @staticmethod
-    def _tool_parameters_to_input_schema(
-        tool_parameters: Optional[Dict[str, ToolParamDefinition]],
-    ) -> Dict:
-        input_schema = {"type": "object"}
-        if not tool_parameters:
-            return input_schema
-
-        json_properties = {}
-        required = []
-        for name, param in tool_parameters.items():
-            json_property = {
-                "type": param.param_type,
-            }
-
-            if param.description:
-                json_property["description"] = param.description
-            if param.required:
-                required.append(name)
-            json_properties[name] = json_property
-
-        input_schema["properties"] = json_properties
-        if required:
-            input_schema["required"] = required
-        return input_schema
-
-    @staticmethod
-    def _tools_to_tool_config(
-        tools: Optional[List[ToolDefinition]], tool_choice: Optional[ToolChoice]
-    ) -> Optional[Dict]:
-        if not tools:
-            return None
-
-        bedrock_tools = []
-        for tool in tools:
-            tool_name = (
-                tool.tool_name
-                if isinstance(tool.tool_name, str)
-                else tool.tool_name.value
-            )
-
-            tool_spec = {
-                "toolSpec": {
-                    "name": tool_name,
-                    "inputSchema": {
-                        "json": BedrockInferenceAdapter._tool_parameters_to_input_schema(
-                            tool.parameters
-                        ),
-                    },
-                }
-            }
-
-            if tool.description:
-                tool_spec["toolSpec"]["description"] = tool.description
-
-            bedrock_tools.append(tool_spec)
-        tool_config = {
-            "tools": bedrock_tools,
-        }
-
-        if tool_choice:
-            tool_config["toolChoice"] = (
-                {"any": {}}
-                if tool_choice.value == ToolChoice.required
-                else {"auto": {}}
-            )
-        return tool_config
-
     async def chat_completion(
         self,
         model_id: str,
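All of the helpers removed above existed only to assemble a Converse-API request. For orientation, the shape they produced was roughly the sketch below (field names taken from the deleted code; nothing builds this after the commit):

```python
# Orientation only: the request shape the deleted helpers used to assemble for
# client.converse(...) / client.converse_stream(...). Not used after this commit.
converse_api_params = {
    "modelId": "meta.llama3-1-8b-instruct-v1:0",
    "messages": [{"role": "user", "content": [{"text": "Hello"}]}],
    "system": [{"text": "You are a helpful assistant."}],  # from system messages
    "inferenceConfig": {"maxTokens": 128, "temperature": 0.7, "topP": 0.9},
    # "toolConfig" was added only for non-streaming requests that declared tools
}
```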
@@ -337,118 +123,70 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        params = self._get_params_for_chat_completion(request)
-        converse_api_res = self.client.converse(**params)
+        params = await self._get_params_for_chat_completion(request)
+        res = self.client.invoke_model(**params)
+        chunk = next(res["body"])
+        result = json.loads(chunk.decode("utf-8"))

-        output_message = BedrockInferenceAdapter._bedrock_message_to_message(
-            converse_api_res
+        choice = OpenAICompatCompletionChoice(
+            finish_reason=result["stop_reason"],
+            text=result["generation"],
         )

-        return ChatCompletionResponse(
-            completion_message=output_message,
-            logprobs=None,
-        )
+        response = OpenAICompatCompletionResponse(choices=[choice])
+        return process_chat_completion_response(response, self.formatter)

     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
-        params = self._get_params_for_chat_completion(request)
-        converse_stream_api_res = self.client.converse_stream(**params)
-        event_stream = converse_stream_api_res["stream"]
+        params = await self._get_params_for_chat_completion(request)
+        res = self.client.invoke_model_with_response_stream(**params)
+        event_stream = res["body"]

-        for chunk in event_stream:
-            if "messageStart" in chunk:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.start,
-                        delta="",
-                    )
-                )
-            elif "contentBlockStart" in chunk:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content=ToolCall(
-                                tool_name=chunk["contentBlockStart"]["toolUse"]["name"],
-                                call_id=chunk["contentBlockStart"]["toolUse"][
-                                    "toolUseId"
-                                ],
-                            ),
-                            parse_status=ToolCallParseStatus.started,
-                        ),
-                    )
-                )
-            elif "contentBlockDelta" in chunk:
-                if "text" in chunk["contentBlockDelta"]["delta"]:
-                    delta = chunk["contentBlockDelta"]["delta"]["text"]
-                else:
-                    delta = ToolCallDelta(
-                        content=ToolCall(
-                            arguments=chunk["contentBlockDelta"]["delta"]["toolUse"][
-                                "input"
-                            ]
-                        ),
-                        parse_status=ToolCallParseStatus.success,
-                    )
-
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=delta,
-                    )
-                )
-            elif "contentBlockStop" in chunk:
-                # Ignored
-                pass
-            elif "messageStop" in chunk:
-                stop_reason = (
-                    BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
-                        chunk["messageStop"]["stopReason"]
-                    )
-                )
-
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.complete,
-                        delta="",
-                        stop_reason=stop_reason,
-                    )
-                )
-            elif "metadata" in chunk:
-                # Ignored
-                pass
-            else:
-                # Ignored
-                pass
-
-    def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
-        bedrock_model = request.model
-        inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
-            request.sampling_params
-        )
-
-        tool_config = BedrockInferenceAdapter._tools_to_tool_config(
-            request.tools, request.tool_choice
-        )
-        bedrock_messages, system_bedrock_messages = (
-            BedrockInferenceAdapter._messages_to_bedrock_messages(request.messages)
-        )
-
-        converse_api_params = {
-            "modelId": bedrock_model,
-            "messages": bedrock_messages,
-        }
-        if inference_config:
-            converse_api_params["inferenceConfig"] = inference_config
-
-        # Tool use is not supported in streaming mode
-        if tool_config and not request.stream:
-            converse_api_params["toolConfig"] = tool_config
-        if system_bedrock_messages:
-            converse_api_params["system"] = system_bedrock_messages
-
-        return converse_api_params
+        async def _generate_and_convert_to_openai_compat():
+            for chunk in event_stream:
+                chunk = chunk["chunk"]["bytes"]
+                result = json.loads(chunk.decode("utf-8"))
+                choice = OpenAICompatCompletionChoice(
+                    finish_reason=result["stop_reason"],
+                    text=result["generation"],
+                )
+                yield OpenAICompatCompletionResponse(choices=[choice])
+
+        stream = _generate_and_convert_to_openai_compat()
+        async for chunk in process_chat_completion_stream_response(
+            stream, self.formatter
+        ):
+            yield chunk
+
+    async def _get_params_for_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> Dict:
+        bedrock_model = request.model
+
+        inference_config = {}
+        param_mapping = {
+            "max_tokens": "max_gen_len",
+            "temperature": "temperature",
+            "top_p": "top_p",
+        }
+
+        for k, v in param_mapping.items():
+            if getattr(request.sampling_params, k):
+                inference_config[v] = getattr(request.sampling_params, k)
+
+        prompt = await chat_completion_request_to_prompt(
+            request, self.get_llama_model(request.model), self.formatter
+        )
+        return {
+            "modelId": bedrock_model,
+            "body": json.dumps(
+                {
+                    "prompt": prompt,
+                    **inference_config,
+                }
+            ),
+        }

     async def embeddings(
         self,
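The rewritten adapter now defers prompt formatting to `chat_completion_request_to_prompt` and calls the raw `invoke_model` / `invoke_model_with_response_stream` APIs. A standalone sketch of that call path, assuming the `prompt` / `max_gen_len` request and `generation` / `stop_reason` response fields that the code above writes and reads for `meta.llama3-*` models; region and model ID are placeholders:

```python
# Sketch only: the raw Bedrock calls the adapter now makes, outside llama-stack.
# Body fields mirror _get_params_for_chat_completion and the completion paths.
import json

import boto3

client = boto3.client("bedrock-runtime", region_name="us-east-1")
body = json.dumps(
    {
        "prompt": "<|begin_of_text|>...",  # the adapter renders this from chat messages
        "max_gen_len": 128,  # mapped from max_tokens above
        "temperature": 0.7,
        "top_p": 0.9,
    }
)

# Non-streaming, as in _nonstream_chat_completion
res = client.invoke_model(modelId="meta.llama3-1-8b-instruct-v1:0", body=body)
result = json.loads(res["body"].read())
print(result["generation"], result["stop_reason"])

# Streaming, as in _stream_chat_completion
res = client.invoke_model_with_response_stream(
    modelId="meta.llama3-1-8b-instruct-v1:0", body=body
)
for event in res["body"]:
    chunk = json.loads(event["chunk"]["bytes"].decode("utf-8"))
    print(chunk.get("generation", ""), end="")
```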
@@ -85,6 +85,16 @@ models:
   provider_id: bedrock
   provider_model_id: meta.llama3-1-405b-instruct-v1:0
   model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-3.2-3B-Instruct
+  provider_id: bedrock
+  provider_model_id: meta.llama3-2-3b-instruct-v1:0
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/Llama-3.2-1B-Instruct
+  provider_id: bedrock
+  provider_model_id: meta.llama3-2-1b-instruct-v1:0
+  model_type: llm
 shields: []
 memory_banks: []
 datasets: []