fix bedrock impl

Dinesh Yeduguru 2024-11-02 10:59:30 -07:00
parent bf4f97a2e1
commit c629615396


@@ -55,7 +55,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
+    ) -> AsyncGenerator:
         raise NotImplementedError()

     @staticmethod
@@ -290,51 +290,51 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
-        # zero-shot tool definitions as input to the model
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> (
-        AsyncGenerator
-    ):  # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
-        bedrock_model = self.map_to_provider_model(model)
-        inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
-            sampling_params
-        )
-        tool_config = BedrockInferenceAdapter._tools_to_tool_config(tools, tool_choice)
-        bedrock_messages, system_bedrock_messages = (
-            BedrockInferenceAdapter._messages_to_bedrock_messages(messages)
-        )
-        converse_api_params = {
-            "modelId": bedrock_model,
-            "messages": bedrock_messages,
-        }
-        if inference_config:
-            converse_api_params["inferenceConfig"] = inference_config
-        # Tool use is not supported in streaming mode
-        if tool_config and not stream:
-            converse_api_params["toolConfig"] = tool_config
-        if system_bedrock_messages:
-            converse_api_params["system"] = system_bedrock_messages
-        if not stream:
-            converse_api_res = self.client.converse(**converse_api_params)
+    ) -> AsyncGenerator:
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools or [],
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            response_format=response_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+        if stream:
+            return self._stream_chat_completion(request)
+        else:
+            return await self._nonstream_chat_completion(request)
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
+        print("non-streaming chat completion")
+        params = self._get_params_for_chat_completion(request)
+        converse_api_res = self.client.converse(**params)
         output_message = BedrockInferenceAdapter._bedrock_message_to_message(
             converse_api_res
         )
-        yield ChatCompletionResponse(
+        return ChatCompletionResponse(
             completion_message=output_message,
             logprobs=None,
         )
-        else:
-            converse_stream_api_res = self.client.converse_stream(**converse_api_params)
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
+        print("streaming chat completion")
+        params = self._get_params_for_chat_completion(request)
+        converse_stream_api_res = self.client.converse_stream(**params)
         event_stream = converse_stream_api_res["stream"]
         for chunk in event_stream:
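
The core of the fix: chat_completion now packs its arguments into a ChatCompletionRequest and dispatches to two helpers, awaiting _nonstream_chat_completion for a finished response and returning the _stream_chat_completion async generator un-awaited so the caller can iterate it. A minimal, self-contained sketch of that dispatch pattern (FakeAdapter and its helper names are hypothetical, not the adapter's real API):

import asyncio
from typing import Any, AsyncGenerator


class FakeAdapter:
    """Illustrative stand-in for the Bedrock adapter (hypothetical names)."""

    async def chat_completion(self, prompt: str, stream: bool = False) -> Any:
        # Streaming: hand back the async generator un-awaited; the caller iterates it.
        # Non-streaming: await the helper and return the finished response.
        if stream:
            return self._stream(prompt)
        return await self._nonstream(prompt)

    async def _nonstream(self, prompt: str) -> str:
        return f"echo: {prompt}"

    async def _stream(self, prompt: str) -> AsyncGenerator:
        for token in prompt.split():
            yield token


async def main() -> None:
    adapter = FakeAdapter()
    print(await adapter.chat_completion("hello world"))  # full response at once
    async for token in await adapter.chat_completion("hello world", stream=True):
        print(token)  # one chunk at a time

asyncio.run(main())

Because the two paths hand back different kinds of objects, a single -> AsyncGenerator annotation replaces the old Union return type in the signature above.
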
@@ -351,12 +351,8 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
                             event_type=ChatCompletionResponseEventType.progress,
                             delta=ToolCallDelta(
                                 content=ToolCall(
-                                    tool_name=chunk["contentBlockStart"]["toolUse"][
-                                        "name"
-                                    ],
-                                    call_id=chunk["contentBlockStart"]["toolUse"][
-                                        "toolUseId"
-                                    ],
+                                    tool_name=chunk["contentBlockStart"]["toolUse"]["name"],
+                                    call_id=chunk["contentBlockStart"]["toolUse"]["toolUseId"],
                                 ),
                                 parse_status=ToolCallParseStatus.started,
                             ),
@@ -368,9 +364,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
                     else:
                         delta = ToolCallDelta(
                             content=ToolCall(
-                                arguments=chunk["contentBlockDelta"]["delta"][
-                                    "toolUse"
-                                ]["input"]
+                                arguments=chunk["contentBlockDelta"]["delta"]["toolUse"]["input"]
                             ),
                             parse_status=ToolCallParseStatus.success,
                         )
@@ -385,11 +379,9 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
                     # Ignored
                     pass
                 elif "messageStop" in chunk:
-                    stop_reason = (
-                        BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
-                            chunk["messageStop"]["stopReason"]
-                        )
-                    )
+                    stop_reason = BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
+                        chunk["messageStop"]["stopReason"]
+                    )

                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
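
The streaming hunks above all consume events from Bedrock's converse_stream response, keyed by event type: contentBlockStart carries the tool name and toolUseId, contentBlockDelta carries text or tool-argument fragments, and messageStop carries the stop reason. A rough sketch of that dispatch over hand-written sample events (the dict shapes mirror the keys read in the diff; no live Bedrock call is made, and the text-delta branch is an assumption about the surrounding code not shown here):

# Sample events shaped like the keys the adapter reads from the Converse stream.
sample_events = [
    {"contentBlockStart": {"toolUse": {"name": "get_weather", "toolUseId": "t-1"}}},
    {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"city": "SF"}'}}}},
    {"contentBlockDelta": {"delta": {"text": "Sunny"}}},
    {"messageStop": {"stopReason": "end_turn"}},
]

for chunk in sample_events:
    if "contentBlockStart" in chunk:
        tool = chunk["contentBlockStart"]["toolUse"]
        print(f"tool call started: {tool['name']} (id={tool['toolUseId']})")
    elif "contentBlockDelta" in chunk:
        delta = chunk["contentBlockDelta"]["delta"]
        if "text" in delta:
            print(f"text delta: {delta['text']}")
        else:
            print(f"tool arguments delta: {delta['toolUse']['input']}")
    elif "messageStop" in chunk:
        print(f"stopped: {chunk['messageStop']['stopReason']}")
    else:
        pass  # metadata and other event types ignored, as in the adapter
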
@@ -405,6 +397,34 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
                     # Ignored
                     pass

+    def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
+        bedrock_model = self.map_to_provider_model(request.model)
+        inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
+            request.sampling_params
+        )
+        tool_config = BedrockInferenceAdapter._tools_to_tool_config(
+            request.tools, request.tool_choice
+        )
+        bedrock_messages, system_bedrock_messages = (
+            BedrockInferenceAdapter._messages_to_bedrock_messages(request.messages)
+        )
+        converse_api_params = {
+            "modelId": bedrock_model,
+            "messages": bedrock_messages,
+        }
+        if inference_config:
+            converse_api_params["inferenceConfig"] = inference_config
+        # Tool use is not supported in streaming mode
+        if tool_config and not request.stream:
+            converse_api_params["toolConfig"] = tool_config
+        if system_bedrock_messages:
+            converse_api_params["system"] = system_bedrock_messages
+        return converse_api_params
+
     async def embeddings(
         self,
         model: str,
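
The new _get_params_for_chat_completion centralizes the kwargs assembly for client.converse() / client.converse_stream() that was previously inlined in chat_completion: modelId and messages always, inferenceConfig, toolConfig, and system only when present, with toolConfig skipped for streaming requests since tool use is not supported there. A simplified standalone sketch of that shape (build_converse_params and the example values are illustrative, not the adapter's code):

from typing import Any, Dict, List, Optional


def build_converse_params(
    model_id: str,
    messages: List[Dict[str, Any]],
    inference_config: Optional[Dict[str, Any]] = None,
    tool_config: Optional[Dict[str, Any]] = None,
    system_messages: Optional[List[Dict[str, Any]]] = None,
    stream: bool = False,
) -> Dict[str, Any]:
    """Assemble kwargs for the Bedrock Converse APIs (simplified sketch)."""
    params: Dict[str, Any] = {"modelId": model_id, "messages": messages}
    if inference_config:
        params["inferenceConfig"] = inference_config
    # Tool use is not supported in streaming mode, so only attach it otherwise.
    if tool_config and not stream:
        params["toolConfig"] = tool_config
    if system_messages:
        params["system"] = system_messages
    return params


params = build_converse_params(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",  # example Bedrock model id
    messages=[{"role": "user", "content": [{"text": "Hello"}]}],
    inference_config={"maxTokens": 256, "temperature": 0.7},
)
# A boto3 bedrock-runtime client would then be called as: client.converse(**params)
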