fix bedrock impl (#359)

* fix bedrock impl

* fix linter errors

* fix return type and remove debug print
Dinesh Yeduguru, 2024-11-03 07:32:30 -08:00, committed by GitHub
parent bf4f97a2e1
commit ac93dd89cf

@@ -55,7 +55,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
+    ) -> AsyncGenerator:
         raise NotImplementedError()

     @staticmethod
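A note on this first hunk: `completion` can be asked to stream, in which case the caller receives an async iterator of chunks rather than a single `CompletionResponseStreamChunk`, so the old `Union` annotation was inaccurate. The commit opts for the loose `AsyncGenerator` spelling; if stricter typing were wanted, `typing.overload` could express the split. A sketch with stand-in types, not the real llama-stack models:

    # Sketch: typing the stream/non-stream split precisely with overloads.
    # The response/chunk classes are stand-ins, not llama-stack's real models.
    from typing import AsyncIterator, Literal, Union, overload


    class CompletionResponse: ...


    class CompletionResponseStreamChunk: ...


    class Adapter:
        @overload
        async def completion(
            self, content: str, stream: Literal[False] = ...
        ) -> CompletionResponse: ...

        @overload
        async def completion(
            self, content: str, stream: Literal[True]
        ) -> AsyncIterator[CompletionResponseStreamChunk]: ...

        async def completion(
            self, content: str, stream: bool = False
        ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
            raise NotImplementedError()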
@@ -290,51 +290,51 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
-        # zero-shot tool definitions as input to the model
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> (
-        AsyncGenerator
-    ):  # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
-        bedrock_model = self.map_to_provider_model(model)
-        inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
-            sampling_params
+    ) -> Union[
+        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
+    ]:
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools or [],
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            response_format=response_format,
+            stream=stream,
+            logprobs=logprobs,
         )
-        tool_config = BedrockInferenceAdapter._tools_to_tool_config(tools, tool_choice)
-        bedrock_messages, system_bedrock_messages = (
-            BedrockInferenceAdapter._messages_to_bedrock_messages(messages)
-        )
-        converse_api_params = {
-            "modelId": bedrock_model,
-            "messages": bedrock_messages,
-        }
-        if inference_config:
-            converse_api_params["inferenceConfig"] = inference_config
-        # Tool use is not supported in streaming mode
-        if tool_config and not stream:
-            converse_api_params["toolConfig"] = tool_config
-        if system_bedrock_messages:
-            converse_api_params["system"] = system_bedrock_messages
-        if not stream:
-            converse_api_res = self.client.converse(**converse_api_params)
-            output_message = BedrockInferenceAdapter._bedrock_message_to_message(
-                converse_api_res
-            )
-            yield ChatCompletionResponse(
-                completion_message=output_message,
-                logprobs=None,
-            )
-        else:
-            converse_stream_api_res = self.client.converse_stream(**converse_api_params)
+        if stream:
+            return self._stream_chat_completion(request)
+        else:
+            return await self._nonstream_chat_completion(request)
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
+        params = self._get_params_for_chat_completion(request)
+        converse_api_res = self.client.converse(**params)
+        output_message = BedrockInferenceAdapter._bedrock_message_to_message(
+            converse_api_res
+        )
+        return ChatCompletionResponse(
+            completion_message=output_message,
+            logprobs=None,
+        )
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
+        params = self._get_params_for_chat_completion(request)
+        converse_stream_api_res = self.client.converse_stream(**params)
         event_stream = converse_stream_api_res["stream"]
         for chunk in event_stream:
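This hunk is the heart of the fix. The old `chat_completion` contained a `yield` in its non-streaming branch, which made the whole method an async generator function: every call returned a generator, even with `stream=False`, so a caller that simply awaited it got a `TypeError` instead of a `ChatCompletionResponse`. That appears to be the bug the commit title refers to. The rewrite builds a `ChatCompletionRequest` and dispatches to two helpers, so `await` works for the non-streaming path and `async for` for the streaming one. A hedged sketch of the pattern, with a simplified request type rather than the real llama-stack classes:

    # Sketch of the dispatch pattern this commit adopts. Request is a
    # stand-in for ChatCompletionRequest; values are illustrative.
    import asyncio
    from dataclasses import dataclass
    from typing import AsyncGenerator


    @dataclass
    class Request:
        prompt: str
        stream: bool = False


    class Adapter:
        async def chat_completion(self, prompt: str, stream: bool = False):
            request = Request(prompt=prompt, stream=stream)
            if stream:
                # Calling an async generator function returns the generator
                # without running it, so it must NOT be awaited here.
                return self._stream_chat_completion(request)
            return await self._nonstream_chat_completion(request)

        async def _nonstream_chat_completion(self, request: Request) -> dict:
            return {"completion": request.prompt.upper()}

        async def _stream_chat_completion(self, request: Request) -> AsyncGenerator:
            for token in request.prompt.split():
                yield {"delta": token}


    async def main() -> None:
        adapter = Adapter()
        print(await adapter.chat_completion("hello world"))
        async for chunk in await adapter.chat_completion("hello world", stream=True):
            print(chunk)


    asyncio.run(main())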
@@ -351,9 +351,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
                                 event_type=ChatCompletionResponseEventType.progress,
                                 delta=ToolCallDelta(
                                     content=ToolCall(
-                                        tool_name=chunk["contentBlockStart"]["toolUse"][
-                                            "name"
-                                        ],
+                                        tool_name=chunk["contentBlockStart"]["toolUse"]["name"],
                                         call_id=chunk["contentBlockStart"]["toolUse"][
                                             "toolUseId"
                                         ],
@@ -368,9 +366,9 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
                         else:
                             delta = ToolCallDelta(
                                 content=ToolCall(
-                                    arguments=chunk["contentBlockDelta"]["delta"][
-                                        "toolUse"
-                                    ]["input"]
+                                    arguments=chunk["contentBlockDelta"]["delta"]["toolUse"][
+                                        "input"
+                                    ]
                                 ),
                                 parse_status=ToolCallParseStatus.success,
                             )
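These two hunks only re-wrap long dictionary lookups; behavior is unchanged. For orientation, here are the ConverseStream event shapes the lookups imply, with invented values (the authoritative schema is the Bedrock ConverseStream API documentation):

    # Event shapes as the adapter's indexing implies; the literal values are
    # made up for illustration, not taken from the commit.
    start_event = {
        "contentBlockStart": {
            "toolUse": {"name": "get_weather", "toolUseId": "tooluse_abc123"}
        }
    }
    delta_event = {
        "contentBlockDelta": {
            "delta": {"toolUse": {"input": '{"city": "San Francisco"}'}}
        }
    }

    # The re-wrapped lines from the two hunks, flattened:
    tool_name = start_event["contentBlockStart"]["toolUse"]["name"]
    call_id = start_event["contentBlockStart"]["toolUse"]["toolUseId"]
    arguments = delta_event["contentBlockDelta"]["delta"]["toolUse"]["input"]
    print(tool_name, call_id, arguments)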
@@ -405,6 +403,34 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
                 # Ignored
                 pass

+    def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
+        bedrock_model = self.map_to_provider_model(request.model)
+        inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
+            request.sampling_params
+        )
+        tool_config = BedrockInferenceAdapter._tools_to_tool_config(
+            request.tools, request.tool_choice
+        )
+        bedrock_messages, system_bedrock_messages = (
+            BedrockInferenceAdapter._messages_to_bedrock_messages(request.messages)
+        )
+        converse_api_params = {
+            "modelId": bedrock_model,
+            "messages": bedrock_messages,
+        }
+        if inference_config:
+            converse_api_params["inferenceConfig"] = inference_config
+        # Tool use is not supported in streaming mode
+        if tool_config and not request.stream:
+            converse_api_params["toolConfig"] = tool_config
+        if system_bedrock_messages:
+            converse_api_params["system"] = system_bedrock_messages
+        return converse_api_params
+
     async def embeddings(
         self,
         model: str,
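The extracted `_get_params_for_chat_completion` helper centralizes what was previously built inline: both `_nonstream_chat_completion` and `_stream_chat_completion` now construct the same dict and unpack it into boto3's Converse API. A sketch of what a resulting call could look like; every concrete value below is a placeholder, not something taken from the commit:

    # Sketch: the params dict this helper returns, fed to boto3's
    # bedrock-runtime Converse API. All concrete values are placeholders.
    import boto3

    client = boto3.client("bedrock-runtime", region_name="us-east-1")

    converse_api_params = {
        "modelId": "meta.llama3-1-8b-instruct-v1:0",  # mapped provider model
        "messages": [
            {"role": "user", "content": [{"text": "Hello!"}]}
        ],
        "inferenceConfig": {"temperature": 0.7, "maxTokens": 512},
        # "toolConfig" is attached only when tools are passed and
        # request.stream is False, per the comment in the helper above.
        "system": [{"text": "You are a helpful assistant."}],
    }

    response = client.converse(**converse_api_params)                # non-streaming
    stream_response = client.converse_stream(**converse_api_params)  # streaming
    for event in stream_response["stream"]:
        print(event)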