fix linter errors

This commit is contained in:
Dinesh Yeduguru 2024-11-02 11:04:02 -07:00
parent c629615396
commit 0e7e4bfb35

View file

@ -57,7 +57,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
raise NotImplementedError() raise NotImplementedError()
@staticmethod @staticmethod
def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason: def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason:
if bedrock_stop_reason == "max_tokens": if bedrock_stop_reason == "max_tokens":
@ -352,7 +352,9 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
delta=ToolCallDelta( delta=ToolCallDelta(
content=ToolCall( content=ToolCall(
tool_name=chunk["contentBlockStart"]["toolUse"]["name"], tool_name=chunk["contentBlockStart"]["toolUse"]["name"],
call_id=chunk["contentBlockStart"]["toolUse"]["toolUseId"], call_id=chunk["contentBlockStart"]["toolUse"][
"toolUseId"
],
), ),
parse_status=ToolCallParseStatus.started, parse_status=ToolCallParseStatus.started,
), ),
@ -364,7 +366,9 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
else: else:
delta = ToolCallDelta( delta = ToolCallDelta(
content=ToolCall( content=ToolCall(
arguments=chunk["contentBlockDelta"]["delta"]["toolUse"]["input"] arguments=chunk["contentBlockDelta"]["delta"]["toolUse"][
"input"
]
), ),
parse_status=ToolCallParseStatus.success, parse_status=ToolCallParseStatus.success,
) )
@ -379,8 +383,10 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
# Ignored # Ignored
pass pass
elif "messageStop" in chunk: elif "messageStop" in chunk:
stop_reason = BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason( stop_reason = (
chunk["messageStop"]["stopReason"] BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
chunk["messageStop"]["stopReason"]
)
) )
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(