mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-06-29 19:34:19 +00:00

fix bedrock impl (#359)

* fix bedrock impl
* fix linter errors
* fix return type and remove debug print

This commit is contained in:
parent bf4f97a2e1
commit ac93dd89cf

1 changed file with 119 additions and 93 deletions
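
Before this commit, chat_completion built the Converse API parameters and handled both the streaming and non-streaming paths inline behind `if not stream:`. Because that body contained `yield` statements, Python treated the whole method as an async generator, so the non-streaming path could never return a plain ChatCompletionResponse, which is the "fix return type" item above. The new version packs its arguments into a ChatCompletionRequest, then dispatches to _nonstream_chat_completion or _stream_chat_completion, with parameter assembly factored into _get_params_for_chat_completion. A minimal, self-contained sketch of that dispatch pattern, using simplified stand-in types rather than llama-stack's real ones:

    # Minimal sketch of the stream/non-stream dispatch pattern this commit
    # introduces. Request and the string "responses" are simplified stand-ins,
    # not llama-stack's real ChatCompletionRequest/Response classes.
    import asyncio
    from dataclasses import dataclass, field
    from typing import AsyncIterator, List, Union


    @dataclass
    class Request:
        model: str
        messages: List[str] = field(default_factory=list)
        stream: bool = False


    class Adapter:
        async def chat_completion(
            self, model: str, messages: List[str], stream: bool = False
        ) -> Union[str, AsyncIterator[str]]:
            request = Request(model=model, messages=messages, stream=stream)
            if stream:
                # Not awaited: calling an async-generator method just returns
                # the generator, which the caller drives with `async for`.
                return self._stream_chat_completion(request)
            return await self._nonstream_chat_completion(request)

        async def _nonstream_chat_completion(self, request: Request) -> str:
            return f"reply to {request.messages[-1]!r}"

        async def _stream_chat_completion(
            self, request: Request
        ) -> AsyncIterator[str]:
            for token in ["re", "ply"]:
                yield token


    async def main() -> None:
        adapter = Adapter()
        print(await adapter.chat_completion("m", ["hi"]))
        chunks = await adapter.chat_completion("m", ["hi"], stream=True)
        async for token in chunks:
            print(token)


    asyncio.run(main())

The same shape explains the `Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]` annotation in the second hunk below.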
@@ -55,7 +55,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
+    ) -> AsyncGenerator:
         raise NotImplementedError()

     @staticmethod
@@ -290,23 +290,130 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
-        # zero-shot tool definitions as input to the model
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> (
-        AsyncGenerator
-    ):  # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
-        bedrock_model = self.map_to_provider_model(model)
-        inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
-            sampling_params
+    ) -> Union[
+        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
+    ]:
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools or [],
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            response_format=response_format,
+            stream=stream,
+            logprobs=logprobs,
         )

-        tool_config = BedrockInferenceAdapter._tools_to_tool_config(tools, tool_choice)
+        if stream:
+            return self._stream_chat_completion(request)
+        else:
+            return await self._nonstream_chat_completion(request)
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
+        params = self._get_params_for_chat_completion(request)
+        converse_api_res = self.client.converse(**params)
+
+        output_message = BedrockInferenceAdapter._bedrock_message_to_message(
+            converse_api_res
+        )
+
+        return ChatCompletionResponse(
+            completion_message=output_message,
+            logprobs=None,
+        )
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
+        params = self._get_params_for_chat_completion(request)
+        converse_stream_api_res = self.client.converse_stream(**params)
+        event_stream = converse_stream_api_res["stream"]
+
+        for chunk in event_stream:
+            if "messageStart" in chunk:
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.start,
+                        delta="",
+                    )
+                )
+            elif "contentBlockStart" in chunk:
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=ToolCallDelta(
+                            content=ToolCall(
+                                tool_name=chunk["contentBlockStart"]["toolUse"]["name"],
+                                call_id=chunk["contentBlockStart"]["toolUse"][
+                                    "toolUseId"
+                                ],
+                            ),
+                            parse_status=ToolCallParseStatus.started,
+                        ),
+                    )
+                )
+            elif "contentBlockDelta" in chunk:
+                if "text" in chunk["contentBlockDelta"]["delta"]:
+                    delta = chunk["contentBlockDelta"]["delta"]["text"]
+                else:
+                    delta = ToolCallDelta(
+                        content=ToolCall(
+                            arguments=chunk["contentBlockDelta"]["delta"]["toolUse"][
+                                "input"
+                            ]
+                        ),
+                        parse_status=ToolCallParseStatus.success,
+                    )
+
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=delta,
+                    )
+                )
+            elif "contentBlockStop" in chunk:
+                # Ignored
+                pass
+            elif "messageStop" in chunk:
+                stop_reason = (
+                    BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
+                        chunk["messageStop"]["stopReason"]
+                    )
+                )
+
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.complete,
+                        delta="",
+                        stop_reason=stop_reason,
+                    )
+                )
+            elif "metadata" in chunk:
+                # Ignored
+                pass
+            else:
+                # Ignored
+                pass
+
+    def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
+        bedrock_model = self.map_to_provider_model(request.model)
+        inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
+            request.sampling_params
+        )
+
+        tool_config = BedrockInferenceAdapter._tools_to_tool_config(
+            request.tools, request.tool_choice
+        )
         bedrock_messages, system_bedrock_messages = (
-            BedrockInferenceAdapter._messages_to_bedrock_messages(messages)
+            BedrockInferenceAdapter._messages_to_bedrock_messages(request.messages)
        )

         converse_api_params = {
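
The event keys that _stream_chat_completion branches on come straight from boto3's bedrock-runtime Converse API: each event in the response's "stream" is a dict with exactly one of the keys messageStart, contentBlockStart, contentBlockDelta, contentBlockStop, messageStop, or metadata. A minimal standalone consumer, assuming AWS credentials, region, and model access are already configured (the model ID is only an example):

    # Minimal sketch of consuming the Bedrock ConverseStream event stream
    # directly with boto3, independent of llama-stack.
    import boto3

    client = boto3.client("bedrock-runtime", region_name="us-west-2")

    response = client.converse_stream(
        modelId="meta.llama3-1-8b-instruct-v1:0",  # example model ID
        messages=[{"role": "user", "content": [{"text": "Say hi."}]}],
        inferenceConfig={"maxTokens": 64, "temperature": 0.5},
    )

    for event in response["stream"]:
        # Each event carries exactly one of the Converse event keys, which is
        # why the adapter dispatches with `if "messageStart" in chunk: ...`.
        if "contentBlockDelta" in event:
            print(event["contentBlockDelta"]["delta"].get("text", ""), end="")
        elif "messageStop" in event:
            print("\nstop reason:", event["messageStop"]["stopReason"])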
@@ -317,93 +424,12 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         converse_api_params["inferenceConfig"] = inference_config

         # Tool use is not supported in streaming mode
-        if tool_config and not stream:
+        if tool_config and not request.stream:
             converse_api_params["toolConfig"] = tool_config
         if system_bedrock_messages:
             converse_api_params["system"] = system_bedrock_messages

-        if not stream:
-            converse_api_res = self.client.converse(**converse_api_params)
-
-            output_message = BedrockInferenceAdapter._bedrock_message_to_message(
-                converse_api_res
-            )
-
-            yield ChatCompletionResponse(
-                completion_message=output_message,
-                logprobs=None,
-            )
-        else:
-            converse_stream_api_res = self.client.converse_stream(**converse_api_params)
-            event_stream = converse_stream_api_res["stream"]
-
-            for chunk in event_stream:
-                if "messageStart" in chunk:
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.start,
-                            delta="",
-                        )
-                    )
-                elif "contentBlockStart" in chunk:
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=ToolCallDelta(
-                                content=ToolCall(
-                                    tool_name=chunk["contentBlockStart"]["toolUse"][
-                                        "name"
-                                    ],
-                                    call_id=chunk["contentBlockStart"]["toolUse"][
-                                        "toolUseId"
-                                    ],
-                                ),
-                                parse_status=ToolCallParseStatus.started,
-                            ),
-                        )
-                    )
-                elif "contentBlockDelta" in chunk:
-                    if "text" in chunk["contentBlockDelta"]["delta"]:
-                        delta = chunk["contentBlockDelta"]["delta"]["text"]
-                    else:
-                        delta = ToolCallDelta(
-                            content=ToolCall(
-                                arguments=chunk["contentBlockDelta"]["delta"][
-                                    "toolUse"
-                                ]["input"]
-                            ),
-                            parse_status=ToolCallParseStatus.success,
-                        )
-
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=delta,
-                        )
-                    )
-                elif "contentBlockStop" in chunk:
-                    # Ignored
-                    pass
-                elif "messageStop" in chunk:
-                    stop_reason = (
-                        BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
-                            chunk["messageStop"]["stopReason"]
-                        )
-                    )
-
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.complete,
-                            delta="",
-                            stop_reason=stop_reason,
-                        )
-                    )
-                elif "metadata" in chunk:
-                    # Ignored
-                    pass
-                else:
-                    # Ignored
-                    pass
+        return converse_api_params

     async def embeddings(
         self,
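
_get_params_for_chat_completion returns the keyword arguments for `client.converse(**params)` and `client.converse_stream(**params)`, so its keys follow the Converse API: modelId, messages, system, inferenceConfig, and (when not streaming) toolConfig. A hedged sketch of the shape such a dict takes; the values are illustrative, and the adapter's exact assembly, including how _tools_to_tool_config renders ToolDefinition objects, is in the diff above:

    # Illustrative shape of Converse API params, as assembled by a helper
    # like _get_params_for_chat_completion. Values are examples, not output
    # captured from the adapter.
    converse_api_params = {
        "modelId": "meta.llama3-1-8b-instruct-v1:0",  # example model ID
        "messages": [
            {"role": "user", "content": [{"text": "What is 2 + 2?"}]}
        ],
        "system": [{"text": "You are a terse assistant."}],
        "inferenceConfig": {"maxTokens": 128, "temperature": 0.0},
        # Omitted when streaming: tool use is not supported in streaming mode.
        "toolConfig": {
            "tools": [
                {
                    "toolSpec": {
                        "name": "get_weather",
                        "description": "Look up current weather for a city.",
                        "inputSchema": {
                            "json": {
                                "type": "object",
                                "properties": {"city": {"type": "string"}},
                                "required": ["city"],
                            }
                        },
                    }
                }
            ],
            "toolChoice": {"auto": {}},
        },
    }

    # Non-streaming call; the assistant message lives under output.message.
    # result = client.converse(**converse_api_params)
    # print(result["output"]["message"]["content"])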