forked from phoenix-oss/llama-stack-mirror
fix bedrock impl (#359)
* fix bedrock impl
* fix linter errors
* fix return type and remove debug print
This commit is contained in:
parent bf4f97a2e1
commit ac93dd89cf
1 changed file with 119 additions and 93 deletions
@@ -55,7 +55,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
+    ) -> AsyncGenerator:
         raise NotImplementedError()

     @staticmethod
@@ -290,51 +290,51 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
-        # zero-shot tool definitions as input to the model
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> (
-        AsyncGenerator
-    ):  # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
-        bedrock_model = self.map_to_provider_model(model)
-        inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
-            sampling_params
-        )
-
-        tool_config = BedrockInferenceAdapter._tools_to_tool_config(tools, tool_choice)
-        bedrock_messages, system_bedrock_messages = (
-            BedrockInferenceAdapter._messages_to_bedrock_messages(messages)
-        )
-
-        converse_api_params = {
-            "modelId": bedrock_model,
-            "messages": bedrock_messages,
-        }
-        if inference_config:
-            converse_api_params["inferenceConfig"] = inference_config
-
-        # Tool use is not supported in streaming mode
-        if tool_config and not stream:
-            converse_api_params["toolConfig"] = tool_config
-        if system_bedrock_messages:
-            converse_api_params["system"] = system_bedrock_messages
-
-        if not stream:
-            converse_api_res = self.client.converse(**converse_api_params)
-
-            output_message = BedrockInferenceAdapter._bedrock_message_to_message(
-                converse_api_res
-            )
-
-            yield ChatCompletionResponse(
-                completion_message=output_message,
-                logprobs=None,
-            )
-        else:
-            converse_stream_api_res = self.client.converse_stream(**converse_api_params)
+    ) -> Union[
+        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
+    ]:
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools or [],
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            response_format=response_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+
+        if stream:
+            return self._stream_chat_completion(request)
+        else:
+            return await self._nonstream_chat_completion(request)
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
+        params = self._get_params_for_chat_completion(request)
+        converse_api_res = self.client.converse(**params)
+
+        output_message = BedrockInferenceAdapter._bedrock_message_to_message(
+            converse_api_res
+        )
+
+        return ChatCompletionResponse(
+            completion_message=output_message,
+            logprobs=None,
+        )
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
+        params = self._get_params_for_chat_completion(request)
+        converse_stream_api_res = self.client.converse_stream(**params)
         event_stream = converse_stream_api_res["stream"]

         for chunk in event_stream:
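Note on the hunk above: with stream=True, the refactored chat_completion now returns the _stream_chat_completion async generator without awaiting it, while the non-streaming path awaits _nonstream_chat_completion and returns a ChatCompletionResponse. A minimal caller-side sketch follows; the adapter instance, model identifier, and message construction are illustrative assumptions, not part of this diff:

# Illustrative usage sketch only -- `adapter` is assumed to be an already
# configured BedrockInferenceAdapter; model name and messages are placeholders.
async def demo(adapter) -> None:
    # Non-streaming: awaiting the call yields a ChatCompletionResponse.
    response = await adapter.chat_completion(
        model="Llama3.1-8B-Instruct",          # placeholder model identifier
        messages=[UserMessage(content="Hi")],  # assumes llama-stack's UserMessage type
        stream=False,
    )
    print(response.completion_message.content)

    # Streaming: awaiting the call yields an async generator of
    # ChatCompletionResponseStreamChunk objects, consumed with `async for`.
    async for chunk in await adapter.chat_completion(
        model="Llama3.1-8B-Instruct",
        messages=[UserMessage(content="Hi")],
        stream=True,
    ):
        print(chunk.event.delta, end="")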
@@ -351,9 +351,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
                         event_type=ChatCompletionResponseEventType.progress,
                         delta=ToolCallDelta(
                             content=ToolCall(
-                                tool_name=chunk["contentBlockStart"]["toolUse"][
-                                    "name"
-                                ],
+                                tool_name=chunk["contentBlockStart"]["toolUse"]["name"],
                                 call_id=chunk["contentBlockStart"]["toolUse"][
                                     "toolUseId"
                                 ],
@@ -368,9 +366,9 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
                 else:
                     delta = ToolCallDelta(
                         content=ToolCall(
-                            arguments=chunk["contentBlockDelta"]["delta"][
-                                "toolUse"
-                            ]["input"]
+                            arguments=chunk["contentBlockDelta"]["delta"]["toolUse"][
+                                "input"
+                            ]
                         ),
                         parse_status=ToolCallParseStatus.success,
                     )
@@ -405,6 +403,34 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
                 # Ignored
                 pass

+    def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
+        bedrock_model = self.map_to_provider_model(request.model)
+        inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
+            request.sampling_params
+        )
+
+        tool_config = BedrockInferenceAdapter._tools_to_tool_config(
+            request.tools, request.tool_choice
+        )
+        bedrock_messages, system_bedrock_messages = (
+            BedrockInferenceAdapter._messages_to_bedrock_messages(request.messages)
+        )
+
+        converse_api_params = {
+            "modelId": bedrock_model,
+            "messages": bedrock_messages,
+        }
+        if inference_config:
+            converse_api_params["inferenceConfig"] = inference_config
+
+        # Tool use is not supported in streaming mode
+        if tool_config and not request.stream:
+            converse_api_params["toolConfig"] = tool_config
+        if system_bedrock_messages:
+            converse_api_params["system"] = system_bedrock_messages
+
+        return converse_api_params
+
     async def embeddings(
         self,
         model: str,
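For reference, the dict assembled by the new _get_params_for_chat_completion helper is what gets splatted into boto3's converse() / converse_stream() calls. A rough sketch of its shape, with made-up values (the concrete message and tool payloads come from the _messages_to_bedrock_messages and _tools_to_tool_config helpers, which this diff does not show):

# Shape sketch only -- keys mirror the helper above; all values are illustrative.
converse_api_params = {
    "modelId": "meta.llama3-1-8b-instruct-v1:0",  # result of map_to_provider_model
    "messages": [
        {"role": "user", "content": [{"text": "Hello"}]},
    ],
    # Present only when get_bedrock_inference_config returns a non-empty config:
    "inferenceConfig": {"temperature": 0.7, "topP": 0.9, "maxTokens": 512},
    # Present only for non-streaming requests that carry tool definitions:
    "toolConfig": {"tools": []},
    # Present only when the request included system messages:
    "system": [{"text": "You are a helpful assistant."}],
}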