diff --git a/llama_stack/providers/adapters/inference/bedrock/bedrock.py b/llama_stack/providers/adapters/inference/bedrock/bedrock.py
index 3800c0496..caf886c0b 100644
--- a/llama_stack/providers/adapters/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/adapters/inference/bedrock/bedrock.py
@@ -55,7 +55,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
+    ) -> AsyncGenerator:
         raise NotImplementedError()
 
     @staticmethod
@@ -290,23 +290,130 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
-        # zero-shot tool definitions as input to the model
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> (
-        AsyncGenerator
-    ):  # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
-        bedrock_model = self.map_to_provider_model(model)
-        inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
-            sampling_params
+    ) -> Union[
+        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
+    ]:
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools or [],
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            response_format=response_format,
+            stream=stream,
+            logprobs=logprobs,
         )
 
-        tool_config = BedrockInferenceAdapter._tools_to_tool_config(tools, tool_choice)
+        if stream:
+            return self._stream_chat_completion(request)
+        else:
+            return await self._nonstream_chat_completion(request)
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
+        params = self._get_params_for_chat_completion(request)
+        converse_api_res = self.client.converse(**params)
+
+        output_message = BedrockInferenceAdapter._bedrock_message_to_message(
+            converse_api_res
+        )
+
+        return ChatCompletionResponse(
+            completion_message=output_message,
+            logprobs=None,
+        )
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
+        params = self._get_params_for_chat_completion(request)
+        converse_stream_api_res = self.client.converse_stream(**params)
+        event_stream = converse_stream_api_res["stream"]
+
+        for chunk in event_stream:
+            if "messageStart" in chunk:
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.start,
+                        delta="",
+                    )
+                )
+            elif "contentBlockStart" in chunk:
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=ToolCallDelta(
+                            content=ToolCall(
+                                tool_name=chunk["contentBlockStart"]["toolUse"]["name"],
+                                call_id=chunk["contentBlockStart"]["toolUse"][
+                                    "toolUseId"
+                                ],
+                            ),
+                            parse_status=ToolCallParseStatus.started,
+                        ),
+                    )
+                )
+            elif "contentBlockDelta" in chunk:
+                if "text" in chunk["contentBlockDelta"]["delta"]:
+                    delta = chunk["contentBlockDelta"]["delta"]["text"]
+                else:
+                    delta = ToolCallDelta(
+                        content=ToolCall(
+                            arguments=chunk["contentBlockDelta"]["delta"]["toolUse"][
+                                "input"
+                            ]
+                        ),
+                        parse_status=ToolCallParseStatus.success,
+                    )
+
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.progress,
+                        delta=delta,
+                    )
+                )
+            elif "contentBlockStop" in chunk:
+                # Ignored
+                pass
+            elif "messageStop" in chunk:
+                stop_reason = (
+                    BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
+                        chunk["messageStop"]["stopReason"]
+                    )
+                )
+
+                yield ChatCompletionResponseStreamChunk(
+                    event=ChatCompletionResponseEvent(
+                        event_type=ChatCompletionResponseEventType.complete,
+                        delta="",
+                        stop_reason=stop_reason,
+                    )
+                )
+            elif "metadata" in chunk:
+                # Ignored
+                pass
+            else:
+                # Ignored
+                pass
+
+    def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
+        bedrock_model = self.map_to_provider_model(request.model)
+        inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
+            request.sampling_params
+        )
+
+        tool_config = BedrockInferenceAdapter._tools_to_tool_config(
+            request.tools, request.tool_choice
+        )
         bedrock_messages, system_bedrock_messages = (
-            BedrockInferenceAdapter._messages_to_bedrock_messages(messages)
+            BedrockInferenceAdapter._messages_to_bedrock_messages(request.messages)
         )
 
         converse_api_params = {
@@ -317,93 +424,12 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
             converse_api_params["inferenceConfig"] = inference_config
 
         # Tool use is not supported in streaming mode
-        if tool_config and not stream:
+        if tool_config and not request.stream:
             converse_api_params["toolConfig"] = tool_config
 
         if system_bedrock_messages:
             converse_api_params["system"] = system_bedrock_messages
-        if not stream:
-            converse_api_res = self.client.converse(**converse_api_params)
-
-            output_message = BedrockInferenceAdapter._bedrock_message_to_message(
-                converse_api_res
-            )
-
-            yield ChatCompletionResponse(
-                completion_message=output_message,
-                logprobs=None,
-            )
-        else:
-            converse_stream_api_res = self.client.converse_stream(**converse_api_params)
-            event_stream = converse_stream_api_res["stream"]
-
-            for chunk in event_stream:
-                if "messageStart" in chunk:
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.start,
-                            delta="",
-                        )
-                    )
-                elif "contentBlockStart" in chunk:
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=ToolCallDelta(
-                                content=ToolCall(
-                                    tool_name=chunk["contentBlockStart"]["toolUse"][
-                                        "name"
-                                    ],
-                                    call_id=chunk["contentBlockStart"]["toolUse"][
-                                        "toolUseId"
-                                    ],
-                                ),
-                                parse_status=ToolCallParseStatus.started,
-                            ),
-                        )
-                    )
-                elif "contentBlockDelta" in chunk:
-                    if "text" in chunk["contentBlockDelta"]["delta"]:
-                        delta = chunk["contentBlockDelta"]["delta"]["text"]
-                    else:
-                        delta = ToolCallDelta(
-                            content=ToolCall(
-                                arguments=chunk["contentBlockDelta"]["delta"][
-                                    "toolUse"
-                                ]["input"]
-                            ),
-                            parse_status=ToolCallParseStatus.success,
-                        )
-
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=delta,
-                        )
-                    )
-                elif "contentBlockStop" in chunk:
-                    # Ignored
-                    pass
-                elif "messageStop" in chunk:
-                    stop_reason = (
-                        BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
-                            chunk["messageStop"]["stopReason"]
-                        )
-                    )
-
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.complete,
-                            delta="",
-                            stop_reason=stop_reason,
-                        )
-                    )
-                elif "metadata" in chunk:
-                    # Ignored
-                    pass
-                else:
-                    # Ignored
-                    pass
+        return converse_api_params
 
     async def embeddings(
         self,