fix(bedrock_httpx.py): fix tool calling for anthropic bedrock calls w/ streaming

Fixes https://github.com/BerriAI/litellm/issues/4091
Krrish Dholakia 2024-06-10 14:20:25 -07:00
parent 6306914e56
commit 84652dd946
5 changed files with 117 additions and 40 deletions
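
For context, the failure mode in the linked issue shows up on a call shaped roughly like the sketch below (model id, tool schema, and prompt are illustrative, not taken from the issue):

    import litellm

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ]

    # Streaming tool call against Anthropic on Bedrock via litellm.
    response = litellm.completion(
        model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
        messages=[{"role": "user", "content": "What's the weather in Boston?"}],
        tools=tools,
        stream=True,
    )

    # Per the commit message, tool-call deltas from the Bedrock stream were
    # previously dropped; after this fix they surface like any OpenAI stream.
    for chunk in response:
        print(chunk.choices[0].delta.tool_calls)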

@@ -63,6 +63,11 @@ claude_json_str = json.dumps(json_data)
 import importlib.metadata
 from ._logging import verbose_logger
 from .types.router import LiteLLM_Params
+from .types.llms.openai import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionToolCallFunctionChunk,
+    ChatCompletionDeltaToolCallChunk,
+)
 from .integrations.traceloop import TraceloopLogger
 from .integrations.athina import AthinaLogger
 from .integrations.helicone import HeliconeLogger
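
The newly imported types model the OpenAI streaming tool-call shape. A minimal sketch of the delta payload they describe (field names follow the OpenAI streaming spec; the exact TypedDict definitions live in litellm/types/llms/openai.py):

    # One tool-call entry as it appears in a streamed delta.
    tool_call_chunk = {
        "index": 0,
        "id": "call_abc123",
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "arguments": '{"location": "Boston"}',
        },
    }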
@@ -3218,7 +3223,7 @@ def client(original_function):
                 stream=kwargs.get("stream", False),
             )
-            if kwargs.get("stream", False) == True:
+            if kwargs.get("stream", False) is True:
                 cached_result = CustomStreamWrapper(
                     completion_stream=cached_result,
                     model=model,
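
The "== True" to "is True" changes in this commit swap equality for identity, which tightens the check: equality lets any value that compares equal to True through, while identity only accepts the bool singleton.

    # Quick illustration of the difference:
    print(1 == True)   # True  -- a truthy int passes an equality check
    print(1 is True)   # False -- identity only matches the literal True
    # So stream=1 entered the streaming branch before; now only stream=True does.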
@@ -11468,6 +11473,10 @@ class CustomStreamWrapper:
                         completion_tokens=response_obj["usage"]["outputTokens"],
                         total_tokens=response_obj["usage"]["totalTokens"],
                     )
+                if "tool_use" in response_obj and response_obj["tool_use"] is not None:
+                    completion_obj["tool_calls"] = [response_obj["tool_use"]]
             elif self.custom_llm_provider == "sagemaker":
                 print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
                 response_obj = self.handle_sagemaker_stream(chunk)
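
The added guard is the heart of the fix: when the Bedrock stream parser attaches a tool_use entry to the parsed chunk, it is forwarded on the delta as an OpenAI-style tool_calls list. A sketch with a hypothetical parsed chunk (field values are illustrative):

    response_obj = {
        "text": "",
        "is_finished": False,
        "tool_use": {
            "index": 0,
            "id": "tooluse_abc123",
            "type": "function",
            "function": {"name": "get_current_weather", "arguments": '{"location": "Boston"}'},
        },
    }
    completion_obj = {"content": response_obj["text"]}
    # Mirrors the added lines: surface the parsed tool_use as tool_calls.
    if "tool_use" in response_obj and response_obj["tool_use"] is not None:
        completion_obj["tool_calls"] = [response_obj["tool_use"]]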
@@ -11484,7 +11493,6 @@
                     new_chunk = self.completion_stream[:chunk_size]
                     completion_obj["content"] = new_chunk
                     self.completion_stream = self.completion_stream[chunk_size:]
-                    time.sleep(0.05)
             elif self.custom_llm_provider == "palm":
                 # fake streaming
                 response_obj = {}
@@ -11497,7 +11505,6 @@
                     new_chunk = self.completion_stream[:chunk_size]
                     completion_obj["content"] = new_chunk
                     self.completion_stream = self.completion_stream[chunk_size:]
-                    time.sleep(0.05)
             elif self.custom_llm_provider == "ollama":
                 response_obj = self.handle_ollama_stream(chunk)
                 completion_obj["content"] = response_obj["text"]
@@ -11717,7 +11724,7 @@
             else:
                 if (
                     self.stream_options is not None
-                    and self.stream_options["include_usage"] == True
+                    and self.stream_options["include_usage"] is True
                 ):
                     return model_response
                 return
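
This branch serves OpenAI's stream_options contract: when include_usage is set, one extra final chunk carrying token counts is emitted after the content chunks. Roughly how a caller opts in (model name is illustrative):

    import litellm

    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        stream=True,
        stream_options={"include_usage": True},
    )
    for chunk in response:
        # The final chunk has empty choices and a populated usage field.
        if getattr(chunk, "usage", None) is not None:
            print(chunk.usage)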
@@ -11742,8 +11749,14 @@
                 return model_response
             elif (
                 "content" in completion_obj
-                and isinstance(completion_obj["content"], str)
-                and len(completion_obj["content"]) > 0
+                and (
+                    isinstance(completion_obj["content"], str)
+                    and len(completion_obj["content"]) > 0
+                )
+                or (
+                    "tool_calls" in completion_obj
+                    and len(completion_obj["tool_calls"]) > 0
+                )
             ):  # cannot set content of an OpenAI Object to be an empty string
                 hold, model_response_str = self.check_special_tokens(
                     chunk=completion_obj["content"],
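
One subtlety in the rewritten condition: since "and" binds tighter than "or", it evaluates as (content is a non-empty string) OR (tool_calls is present and non-empty), so a chunk that carries only a tool call is no longer swallowed. Pulled out as a standalone predicate:

    def should_emit(completion_obj: dict) -> bool:
        # Same boolean structure as the elif above.
        return (
            "content" in completion_obj
            and (
                isinstance(completion_obj["content"], str)
                and len(completion_obj["content"]) > 0
            )
            or (
                "tool_calls" in completion_obj
                and len(completion_obj["tool_calls"]) > 0
            )
        )

    print(should_emit({"content": ""}))                      # False: empty text only
    print(should_emit({"content": "", "tool_calls": [{}]}))  # True: tool call present
    print(should_emit({"content": "hi"}))                    # True: non-empty text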
@@ -11799,7 +11812,7 @@
                 else:
                     ## else
                     completion_obj["content"] = model_response_str
-                    if self.sent_first_chunk == False:
+                    if self.sent_first_chunk is False:
                         completion_obj["role"] = "assistant"
                         self.sent_first_chunk = True
                     model_response.choices[0].delta = Delta(**completion_obj)
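
The same identity-check cleanup applies to the first-chunk bookkeeping: only the first streamed delta carries role="assistant", matching OpenAI's streaming format. A minimal standalone sketch:

    sent_first_chunk = False
    for text in ["Hel", "lo!"]:
        completion_obj = {"content": text}
        if sent_first_chunk is False:
            completion_obj["role"] = "assistant"  # only on the first delta
            sent_first_chunk = True
        print(completion_obj)
    # {'content': 'Hel', 'role': 'assistant'}
    # {'content': 'lo!'}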