Merge pull request #4106 from BerriAI/litellm_anthropic_bedrock_tool_calling_fix

fix(bedrock_httpx.py): fix tool calling for anthropic bedrock calls w/ streaming
This commit is contained in:
Krish Dholakia 2024-06-10 20:21:16 -07:00 committed by GitHub
commit 8379d58318
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 125 additions and 43 deletions

View file

@ -63,6 +63,11 @@ claude_json_str = json.dumps(json_data)
import importlib.metadata
from ._logging import verbose_logger
from .types.router import LiteLLM_Params
from .types.llms.openai import (
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
ChatCompletionDeltaToolCallChunk,
)
from .integrations.traceloop import TraceloopLogger
from .integrations.athina import AthinaLogger
from .integrations.helicone import HeliconeLogger
@ -3250,7 +3255,7 @@ def client(original_function):
stream=kwargs.get("stream", False),
)
if kwargs.get("stream", False) == True:
if kwargs.get("stream", False) is True:
cached_result = CustomStreamWrapper(
completion_stream=cached_result,
model=model,
@ -11301,7 +11306,6 @@ class CustomStreamWrapper:
raise StopIteration
response_obj: GenericStreamingChunk = chunk
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
@ -11316,6 +11320,10 @@ class CustomStreamWrapper:
completion_tokens=response_obj["usage"]["outputTokens"],
total_tokens=response_obj["usage"]["totalTokens"],
)
if "tool_use" in response_obj and response_obj["tool_use"] is not None:
completion_obj["tool_calls"] = [response_obj["tool_use"]]
elif self.custom_llm_provider == "sagemaker":
print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
response_obj = self.handle_sagemaker_stream(chunk)
@ -11332,7 +11340,6 @@ class CustomStreamWrapper:
new_chunk = self.completion_stream[:chunk_size]
completion_obj["content"] = new_chunk
self.completion_stream = self.completion_stream[chunk_size:]
time.sleep(0.05)
elif self.custom_llm_provider == "palm":
# fake streaming
response_obj = {}
@ -11345,7 +11352,6 @@ class CustomStreamWrapper:
new_chunk = self.completion_stream[:chunk_size]
completion_obj["content"] = new_chunk
self.completion_stream = self.completion_stream[chunk_size:]
time.sleep(0.05)
elif self.custom_llm_provider == "ollama":
response_obj = self.handle_ollama_stream(chunk)
completion_obj["content"] = response_obj["text"]
@ -11432,7 +11438,7 @@ class CustomStreamWrapper:
# for azure, we need to pass the model from the orignal chunk
self.model = chunk.model
response_obj = self.handle_openai_chat_completion_chunk(chunk)
if response_obj == None:
if response_obj is None:
return
completion_obj["content"] = response_obj["text"]
print_verbose(f"completion obj content: {completion_obj['content']}")
@ -11565,7 +11571,7 @@ class CustomStreamWrapper:
else:
if (
self.stream_options is not None
and self.stream_options["include_usage"] == True
and self.stream_options["include_usage"] is True
):
return model_response
return
@ -11590,8 +11596,14 @@ class CustomStreamWrapper:
return model_response
elif (
"content" in completion_obj
and isinstance(completion_obj["content"], str)
and len(completion_obj["content"]) > 0
and (
isinstance(completion_obj["content"], str)
and len(completion_obj["content"]) > 0
)
or (
"tool_calls" in completion_obj
and len(completion_obj["tool_calls"]) > 0
)
): # cannot set content of an OpenAI Object to be an empty string
hold, model_response_str = self.check_special_tokens(
chunk=completion_obj["content"],
@ -11647,7 +11659,7 @@ class CustomStreamWrapper:
else:
## else
completion_obj["content"] = model_response_str
if self.sent_first_chunk == False:
if self.sent_first_chunk is False:
completion_obj["role"] = "assistant"
self.sent_first_chunk = True
model_response.choices[0].delta = Delta(**completion_obj)
@ -11656,7 +11668,7 @@ class CustomStreamWrapper:
else:
return
elif self.received_finish_reason is not None:
if self.sent_last_chunk == True:
if self.sent_last_chunk is True:
raise StopIteration
# flush any remaining holding chunk
if len(self.holding_chunk) > 0: