fix bedrock impl (#359)

* fix bedrock impl

* fix linter errors

* fix return type and remove debug print

Author: Dinesh Yeduguru, 2024-11-03 07:32:30 -08:00 (committed by GitHub)
Parent: bf4f97a2e1
Commit: ac93dd89cf

@@ -55,7 +55,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
+    ) -> AsyncGenerator:
         raise NotImplementedError()

     @staticmethod
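
For context on the annotation change: once an `async def` body contains `yield` anywhere, Python treats the whole function as an async generator, so a narrow `Union[CompletionResponse, CompletionResponseStreamChunk]` return annotation would be misleading for a method that may stream. A tiny illustrative sketch of that behavior, using placeholder names rather than llama-stack types:

import asyncio
from typing import AsyncGenerator


async def completion(stream: bool = False) -> AsyncGenerator:
    # Because this body contains `yield`, calling completion() returns an
    # async generator object immediately; `await completion()` would raise
    # TypeError, so the broad AsyncGenerator annotation is the honest one.
    if not stream:
        yield "full response"
    else:
        for token in ("chunk-1", "chunk-2"):
            yield token


async def main() -> None:
    async for item in completion(stream=True):
        print(item)


asyncio.run(main())
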
@@ -290,51 +290,51 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
-        # zero-shot tool definitions as input to the model
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> (
-        AsyncGenerator
-    ):  # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
-        bedrock_model = self.map_to_provider_model(model)
-        inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
-            sampling_params
-        )
-        tool_config = BedrockInferenceAdapter._tools_to_tool_config(tools, tool_choice)
-        bedrock_messages, system_bedrock_messages = (
-            BedrockInferenceAdapter._messages_to_bedrock_messages(messages)
-        )
-        converse_api_params = {
-            "modelId": bedrock_model,
-            "messages": bedrock_messages,
-        }
-        if inference_config:
-            converse_api_params["inferenceConfig"] = inference_config
-        # Tool use is not supported in streaming mode
-        if tool_config and not stream:
-            converse_api_params["toolConfig"] = tool_config
-        if system_bedrock_messages:
-            converse_api_params["system"] = system_bedrock_messages
-        if not stream:
-            converse_api_res = self.client.converse(**converse_api_params)
-            output_message = BedrockInferenceAdapter._bedrock_message_to_message(
-                converse_api_res
-            )
-            yield ChatCompletionResponse(
-                completion_message=output_message,
-                logprobs=None,
-            )
-        else:
-            converse_stream_api_res = self.client.converse_stream(**converse_api_params)
-            event_stream = converse_stream_api_res["stream"]
-            for chunk in event_stream:
+    ) -> Union[
+        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
+    ]:
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools or [],
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            response_format=response_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+        if stream:
+            return self._stream_chat_completion(request)
+        else:
+            return await self._nonstream_chat_completion(request)
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
+        params = self._get_params_for_chat_completion(request)
+        converse_api_res = self.client.converse(**params)
+
+        output_message = BedrockInferenceAdapter._bedrock_message_to_message(
+            converse_api_res
+        )
+
+        return ChatCompletionResponse(
+            completion_message=output_message,
+            logprobs=None,
+        )
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
+        params = self._get_params_for_chat_completion(request)
+        converse_stream_api_res = self.client.converse_stream(**params)
+
+        event_stream = converse_stream_api_res["stream"]
+        for chunk in event_stream:
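
The refactor above is a dispatch pattern: the public coroutine packs its arguments into one request object and then either awaits the non-streaming helper or hands back the streaming async generator without awaiting it. A minimal, self-contained sketch of the same pattern with simplified stand-in types (not the llama-stack classes):

import asyncio
from dataclasses import dataclass
from typing import AsyncGenerator, Union


@dataclass
class Request:
    prompt: str
    stream: bool = False


@dataclass
class Response:
    text: str


class Adapter:
    async def chat_completion(
        self, prompt: str, stream: bool = False
    ) -> Union[Response, AsyncGenerator]:
        request = Request(prompt=prompt, stream=stream)
        if stream:
            # Calling an async generator function does not run it; the caller
            # drives it with `async for`, mirroring _stream_chat_completion.
            return self._stream(request)
        return await self._nonstream(request)

    async def _nonstream(self, request: Request) -> Response:
        return Response(text=f"echo: {request.prompt}")

    async def _stream(self, request: Request) -> AsyncGenerator:
        for token in request.prompt.split():
            yield token


async def main() -> None:
    print(await Adapter().chat_completion("hello world"))
    async for token in await Adapter().chat_completion("hello world", stream=True):
        print(token)


asyncio.run(main())
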
@@ -351,9 +351,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
                         event_type=ChatCompletionResponseEventType.progress,
                         delta=ToolCallDelta(
                             content=ToolCall(
-                                tool_name=chunk["contentBlockStart"]["toolUse"][
-                                    "name"
-                                ],
+                                tool_name=chunk["contentBlockStart"]["toolUse"]["name"],
                                 call_id=chunk["contentBlockStart"]["toolUse"][
                                     "toolUseId"
                                 ],
@@ -368,9 +366,9 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
                 else:
                     delta = ToolCallDelta(
                         content=ToolCall(
-                            arguments=chunk["contentBlockDelta"]["delta"][
-                                "toolUse"
-                            ]["input"]
+                            arguments=chunk["contentBlockDelta"]["delta"]["toolUse"][
+                                "input"
+                            ]
                         ),
                         parse_status=ToolCallParseStatus.success,
                     )
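
The two hunks above only reflow how tool-use fields are indexed out of the stream chunks. As a reference for the shapes being indexed, a small sketch that accumulates tool calls from chunks laid out the same way; the sample dictionaries are illustrative, mirroring the keys used in this diff rather than a captured Bedrock response:

from typing import Any, Dict, Iterable, List


def collect_tool_calls(chunks: Iterable[Dict[str, Any]]) -> List[Dict[str, str]]:
    # Gather tool name/id from contentBlockStart events and concatenate the
    # streamed argument fragments from contentBlockDelta events.
    calls: List[Dict[str, str]] = []
    for chunk in chunks:
        if "contentBlockStart" in chunk:
            tool_use = chunk["contentBlockStart"]["toolUse"]
            calls.append(
                {"name": tool_use["name"], "id": tool_use["toolUseId"], "arguments": ""}
            )
        elif "contentBlockDelta" in chunk:
            delta = chunk["contentBlockDelta"]["delta"]
            if "toolUse" in delta:
                calls[-1]["arguments"] += delta["toolUse"]["input"]
    return calls


# Illustrative chunk shapes only, mirroring the keys indexed above.
sample = [
    {"contentBlockStart": {"toolUse": {"name": "get_weather", "toolUseId": "call-1"}}},
    {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"city": "Paris"}'}}}},
]
print(collect_tool_calls(sample))
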
@@ -405,6 +403,34 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
                 # Ignored
                 pass

+    def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
+        bedrock_model = self.map_to_provider_model(request.model)
+        inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
+            request.sampling_params
+        )
+
+        tool_config = BedrockInferenceAdapter._tools_to_tool_config(
+            request.tools, request.tool_choice
+        )
+        bedrock_messages, system_bedrock_messages = (
+            BedrockInferenceAdapter._messages_to_bedrock_messages(request.messages)
+        )
+
+        converse_api_params = {
+            "modelId": bedrock_model,
+            "messages": bedrock_messages,
+        }
+        if inference_config:
+            converse_api_params["inferenceConfig"] = inference_config
+
+        # Tool use is not supported in streaming mode
+        if tool_config and not request.stream:
+            converse_api_params["toolConfig"] = tool_config
+        if system_bedrock_messages:
+            converse_api_params["system"] = system_bedrock_messages
+
+        return converse_api_params
+
     async def embeddings(
         self,
         model: str,
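
The new _get_params_for_chat_completion helper builds exactly the keyword arguments that boto3's Converse APIs take. A standalone sketch of the same kind of dict passed straight to boto3, for comparison; the model ID and message text are placeholders, the field names follow the public Bedrock Converse API, and valid AWS credentials are needed to actually run it:

import boto3

client = boto3.client("bedrock-runtime")

converse_api_params = {
    # Placeholder model ID; any Converse-capable Bedrock model works here.
    "modelId": "anthropic.claude-3-haiku-20240307-v1:0",
    "messages": [{"role": "user", "content": [{"text": "Say hello."}]}],
    "system": [{"text": "You are a terse assistant."}],
    "inferenceConfig": {"maxTokens": 256, "temperature": 0.2, "topP": 0.9},
}

# Non-streaming call, as in _nonstream_chat_completion above.
response = client.converse(**converse_api_params)
print(response["output"]["message"]["content"][0]["text"])

# Streaming call, as in _stream_chat_completion above.
stream_res = client.converse_stream(**converse_api_params)
for chunk in stream_res["stream"]:
    if "contentBlockDelta" in chunk:
        print(chunk["contentBlockDelta"]["delta"].get("text", ""), end="")
print()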