Merge pull request #4106 from BerriAI/litellm_anthropic_bedrock_tool_calling_fix

fix(bedrock_httpx.py): fix tool calling for anthropic bedrock calls w/ streaming
Krish Dholakia 2024-06-10 20:21:16 -07:00 committed by GitHub
commit 8379d58318
6 changed files with 125 additions and 43 deletions

litellm/llms/bedrock_httpx.py

@@ -51,6 +51,7 @@ from litellm.types.llms.openai import (
ChatCompletionResponseMessage,
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
ChatCompletionDeltaChunk,
)
@@ -1859,29 +1860,59 @@ class AWSEventStreamDecoder:
self.parser = EventStreamJSONParser()
def converse_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
text = ""
tool_str = ""
is_finished = False
finish_reason = ""
usage: Optional[ConverseTokenUsageBlock] = None
if "delta" in chunk_data:
delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"])
if "text" in delta_obj:
text = delta_obj["text"]
elif "toolUse" in delta_obj:
tool_str = delta_obj["toolUse"]["input"]
elif "stopReason" in chunk_data:
finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
elif "usage" in chunk_data:
usage = ConverseTokenUsageBlock(**chunk_data["usage"]) # type: ignore
response = GenericStreamingChunk(
text=text,
tool_str=tool_str,
is_finished=is_finished,
finish_reason=finish_reason,
usage=usage,
)
return response
try:
text = ""
tool_use: Optional[ChatCompletionToolCallChunk] = None
is_finished = False
finish_reason = ""
usage: Optional[ConverseTokenUsageBlock] = None
index = int(chunk_data.get("contentBlockIndex", 0))
if "start" in chunk_data:
start_obj = ContentBlockStartEvent(**chunk_data["start"])
if (
start_obj is not None
and "toolUse" in start_obj
and start_obj["toolUse"] is not None
):
tool_use = {
"id": start_obj["toolUse"]["toolUseId"],
"type": "function",
"function": {
"name": start_obj["toolUse"]["name"],
"arguments": "",
},
}
elif "delta" in chunk_data:
delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"])
if "text" in delta_obj:
text = delta_obj["text"]
elif "toolUse" in delta_obj:
tool_use = {
"id": None,
"type": "function",
"function": {
"name": None,
"arguments": delta_obj["toolUse"]["input"],
},
}
elif "stopReason" in chunk_data:
finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
is_finished = True
elif "usage" in chunk_data:
usage = ConverseTokenUsageBlock(**chunk_data["usage"]) # type: ignore
response = GenericStreamingChunk(
text=text,
tool_use=tool_use,
is_finished=is_finished,
finish_reason=finish_reason,
usage=usage,
index=index,
)
return response
except Exception as e:
raise Exception("Received streaming error - {}".format(str(e)))
def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
text = ""
@@ -1890,12 +1921,16 @@ class AWSEventStreamDecoder:
if "outputText" in chunk_data:
text = chunk_data["outputText"]
# ai21 mapping
if "ai21" in self.model: # fake ai21 streaming
elif "ai21" in self.model: # fake ai21 streaming
text = chunk_data.get("completions")[0].get("data").get("text") # type: ignore
is_finished = True
finish_reason = "stop"
######## bedrock.anthropic mappings ###############
elif "delta" in chunk_data:
elif (
"contentBlockIndex" in chunk_data
or "stopReason" in chunk_data
or "metrics" in chunk_data
):
return self.converse_chunk_parser(chunk_data=chunk_data)
######## bedrock.mistral mappings ###############
elif "outputs" in chunk_data:
@@ -1905,7 +1940,7 @@ class AWSEventStreamDecoder:
):
text = chunk_data["outputs"][0]["text"]
stop_reason = chunk_data.get("stop_reason", None)
if stop_reason != None:
if stop_reason is not None:
is_finished = True
finish_reason = stop_reason
######## bedrock.cohere mappings ###############
@@ -1926,8 +1961,9 @@ class AWSEventStreamDecoder:
text=text,
is_finished=is_finished,
finish_reason=finish_reason,
tool_str="",
usage=None,
index=0,
tool_use=None,
)
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
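
For reference, the Converse event flow the new parser handles: a tool call is announced by a contentBlockStart event carrying toolUseId and name, its JSON arguments stream in through contentBlockDelta events, and a stopReason closes the message. A minimal sketch of that mapping, assuming the decoder is constructed with just the model name (as the hunk's use of self.model suggests); the event dicts are hand-written illustrations, not captured Bedrock output:

```python
from litellm.llms.bedrock_httpx import AWSEventStreamDecoder

decoder = AWSEventStreamDecoder(model="anthropic.claude-3-sonnet-20240229-v1:0")

events = [
    # contentBlockStart -> chunk carries the tool-call id and function name
    {"contentBlockIndex": 1,
     "start": {"toolUse": {"toolUseId": "tooluse_abc123", "name": "get_current_weather"}}},
    # contentBlockDelta -> chunk carries only an arguments fragment (id/name are None)
    {"contentBlockIndex": 1,
     "delta": {"toolUse": {"input": '{"location": "Boston, MA"}'}}},
    # stopReason -> finish_reason is mapped and is_finished is set
    {"stopReason": "tool_use"},
]

for event in events:
    chunk = decoder.converse_chunk_parser(chunk_data=event)
    print(chunk["index"], chunk["is_finished"], chunk["tool_use"])
```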

litellm/tests/test_streaming.py

@@ -2535,7 +2535,10 @@ def streaming_and_function_calling_format_tests(idx, chunk):
return extracted_chunk, finished
def test_openai_streaming_and_function_calling():
@pytest.mark.parametrize(
"model", ["gpt-3.5-turbo", "anthropic.claude-3-sonnet-20240229-v1:0"]
)
def test_streaming_and_function_calling(model):
tools = [
{
"type": "function",
@@ -2556,16 +2559,21 @@ def test_openai_streaming_and_function_calling():
},
}
]
messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
try:
response = completion(
model="gpt-3.5-turbo",
litellm.set_verbose = True
response: litellm.CustomStreamWrapper = completion(
model=model,
tools=tools,
messages=messages,
stream=True,
)
tool_choice="required",
) # type: ignore
# Add any assertions here to check the response
for idx, chunk in enumerate(response):
# continue
print("\n{}\n".format(chunk))
if idx == 0:
assert (
chunk.choices[0].delta.tool_calls[0].function.arguments is not None
@@ -2573,6 +2581,7 @@ def test_openai_streaming_and_function_calling():
assert isinstance(
chunk.choices[0].delta.tool_calls[0].function.arguments, str
)
# assert False
except Exception as e:
pytest.fail(f"Error occurred: {e}")
raise e
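
Consumer-side, the fix means bedrock anthropic streams now surface OpenAI-style delta.tool_calls fragments. A hedged sketch of reassembling one tool call from such a stream; the tool definition is a minimal stand-in for the test's get_current_weather tool, and the printed output is illustrative:

```python
import litellm

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    },
}]

response = litellm.completion(
    model="anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[{"role": "user", "content": "What is the weather like in Boston?"}],
    tools=tools,
    tool_choice="required",
    stream=True,
)

fn_name, fn_args = None, ""
for chunk in response:
    tool_calls = chunk.choices[0].delta.tool_calls
    if tool_calls:
        # the first fragment carries the function name; later ones append argument text
        fn_name = fn_name or tool_calls[0].function.name
        fn_args += tool_calls[0].function.arguments or ""

print(fn_name, fn_args)  # e.g. get_current_weather {"location": "Boston, MA"}
```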

litellm/tests/test_text_completion.py

@@ -3990,6 +3990,7 @@ def test_async_text_completion():
asyncio.run(test_get_response())
@pytest.mark.skip(reason="Tgai endpoints are unstable")
def test_async_text_completion_together_ai():
litellm.set_verbose = True
print("test_async_text_completion")

litellm/types/llms/bedrock.py

@@ -1,5 +1,6 @@
from typing import TypedDict, Any, Union, Optional, Literal, List
import json
from .openai import ChatCompletionToolCallChunk
from typing_extensions import (
Self,
Protocol,
@@ -118,6 +119,15 @@ class ToolBlockDeltaEvent(TypedDict):
input: str
class ToolUseBlockStartEvent(TypedDict):
name: str
toolUseId: str
class ContentBlockStartEvent(TypedDict, total=False):
toolUse: Optional[ToolUseBlockStartEvent]
class ContentBlockDeltaEvent(TypedDict, total=False):
"""
Either 'text' or 'toolUse' will be specified for Converse API streaming response.
@@ -138,10 +148,11 @@ class RequestObject(TypedDict, total=False):
class GenericStreamingChunk(TypedDict):
text: Required[str]
tool_str: Required[str]
tool_use: Optional[ChatCompletionToolCallChunk]
is_finished: Required[bool]
finish_reason: Required[str]
usage: Optional[ConverseTokenUsageBlock]
index: int
class Document(TypedDict):
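
To make the schema change concrete: tool_str (a bare string of argument text) is replaced by a structured tool_use plus the content-block index. A chunk for the start of a tool call, written out with illustrative values and assuming the module path litellm.types.llms.bedrock for this file (argument-fragment chunks would instead carry id=None and name=None):

```python
from litellm.types.llms.bedrock import GenericStreamingChunk

first_chunk: GenericStreamingChunk = {
    "text": "",
    "tool_use": {
        "id": "tooluse_abc123",
        "type": "function",
        "function": {"name": "get_current_weather", "arguments": ""},
    },
    "is_finished": False,
    "finish_reason": "",
    "usage": None,
    "index": 1,  # contentBlockIndex of the tool-use block
}
```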

litellm/types/llms/openai.py

@@ -296,14 +296,27 @@ class ListBatchRequest(TypedDict, total=False):
class ChatCompletionToolCallFunctionChunk(TypedDict):
name: str
name: Optional[str]
arguments: str
class ChatCompletionToolCallChunk(TypedDict):
id: Optional[str]
type: Literal["function"]
function: ChatCompletionToolCallFunctionChunk
class ChatCompletionDeltaToolCallChunk(TypedDict):
id: str
type: Literal["function"]
function: ChatCompletionToolCallFunctionChunk
index: int
class ChatCompletionDeltaChunk(TypedDict, total=False):
content: Optional[str]
tool_calls: List[ChatCompletionDeltaToolCallChunk]
role: str
class ChatCompletionResponseMessage(TypedDict, total=False):
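
On the OpenAI-delta side, the new TypedDicts model a streamed tool-call fragment; ChatCompletionDeltaToolCallChunk adds the per-call index that the flat ChatCompletionToolCallChunk lacks. An illustrative fragment (all values made up):

```python
from litellm.types.llms.openai import ChatCompletionDeltaChunk

delta: ChatCompletionDeltaChunk = {
    "role": "assistant",
    "content": None,
    "tool_calls": [
        {
            "id": "call_abc123",  # present on the first fragment of a call
            "type": "function",
            "function": {"name": "get_current_weather", "arguments": '{"loc'},
            "index": 0,  # which tool call this fragment extends
        }
    ],
}
```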

litellm/utils.py

@@ -63,6 +63,11 @@ claude_json_str = json.dumps(json_data)
import importlib.metadata
from ._logging import verbose_logger
from .types.router import LiteLLM_Params
from .types.llms.openai import (
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
ChatCompletionDeltaToolCallChunk,
)
from .integrations.traceloop import TraceloopLogger
from .integrations.athina import AthinaLogger
from .integrations.helicone import HeliconeLogger
@@ -3250,7 +3255,7 @@ def client(original_function):
stream=kwargs.get("stream", False),
)
if kwargs.get("stream", False) == True:
if kwargs.get("stream", False) is True:
cached_result = CustomStreamWrapper(
completion_stream=cached_result,
model=model,
@@ -11301,7 +11306,6 @@ class CustomStreamWrapper:
raise StopIteration
response_obj: GenericStreamingChunk = chunk
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
@@ -11316,6 +11320,10 @@
completion_tokens=response_obj["usage"]["outputTokens"],
total_tokens=response_obj["usage"]["totalTokens"],
)
if "tool_use" in response_obj and response_obj["tool_use"] is not None:
completion_obj["tool_calls"] = [response_obj["tool_use"]]
elif self.custom_llm_provider == "sagemaker":
print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
response_obj = self.handle_sagemaker_stream(chunk)
@@ -11332,7 +11340,6 @@
new_chunk = self.completion_stream[:chunk_size]
completion_obj["content"] = new_chunk
self.completion_stream = self.completion_stream[chunk_size:]
time.sleep(0.05)
elif self.custom_llm_provider == "palm":
# fake streaming
response_obj = {}
@@ -11345,7 +11352,6 @@
new_chunk = self.completion_stream[:chunk_size]
completion_obj["content"] = new_chunk
self.completion_stream = self.completion_stream[chunk_size:]
time.sleep(0.05)
elif self.custom_llm_provider == "ollama":
response_obj = self.handle_ollama_stream(chunk)
completion_obj["content"] = response_obj["text"]
@@ -11432,7 +11438,7 @@
# for azure, we need to pass the model from the original chunk
self.model = chunk.model
response_obj = self.handle_openai_chat_completion_chunk(chunk)
if response_obj == None:
if response_obj is None:
return
completion_obj["content"] = response_obj["text"]
print_verbose(f"completion obj content: {completion_obj['content']}")
@@ -11565,7 +11571,7 @@
else:
if (
self.stream_options is not None
and self.stream_options["include_usage"] == True
and self.stream_options["include_usage"] is True
):
return model_response
return
@@ -11590,8 +11596,14 @@
return model_response
elif (
"content" in completion_obj
and isinstance(completion_obj["content"], str)
and len(completion_obj["content"]) > 0
and (
isinstance(completion_obj["content"], str)
and len(completion_obj["content"]) > 0
)
or (
"tool_calls" in completion_obj
and len(completion_obj["tool_calls"]) > 0
)
): # cannot set content of an OpenAI Object to be an empty string
hold, model_response_str = self.check_special_tokens(
chunk=completion_obj["content"],
@@ -11647,7 +11659,7 @@
else:
## else
completion_obj["content"] = model_response_str
if self.sent_first_chunk == False:
if self.sent_first_chunk is False:
completion_obj["role"] = "assistant"
self.sent_first_chunk = True
model_response.choices[0].delta = Delta(**completion_obj)
@@ -11656,7 +11668,7 @@
else:
return
elif self.received_finish_reason is not None:
if self.sent_last_chunk == True:
if self.sent_last_chunk is True:
raise StopIteration
# flush any remaining holding chunk
if len(self.holding_chunk) > 0:
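
The widened emit condition in the hunk above is the crux of the fix: tool-call chunks carry empty content, so the old length-only check silently dropped them. Restated as a standalone predicate (the function name is mine, not litellm's):

```python
from typing import List, Optional


def should_emit_chunk(content: Optional[str], tool_calls: Optional[List[dict]]) -> bool:
    """Mirror of the widened gate in CustomStreamWrapper: emit when the chunk
    has non-empty text OR at least one tool-call fragment."""
    has_text = isinstance(content, str) and len(content) > 0
    has_tool_calls = tool_calls is not None and len(tool_calls) > 0
    return has_text or has_tool_calls


# tool-call fragments used to be dropped because their content is ""
assert should_emit_chunk("", [{"id": None, "type": "function",
                               "function": {"name": None, "arguments": '{"loc'}}])
assert not should_emit_chunk("", None)
```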