Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)

Merge pull request #4106 from BerriAI/litellm_anthropic_bedrock_tool_calling_fix

fix(bedrock_httpx.py): fix tool calling for anthropic bedrock calls w/ streaming

Commit 8379d58318
6 changed files with 125 additions and 43 deletions
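What this enables: streaming tool/function calling now works for Anthropic models on Bedrock the same way it already does for OpenAI models. Below is a minimal usage sketch based on the parametrized test updated in this diff; the tool schema is illustrative (the test's full schema is truncated in this view), and AWS credentials for Bedrock are assumed to be configured.

from litellm import completion

# Illustrative tool schema -- the test's full schema is truncated in this diff view.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "e.g. Boston, MA"},
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location"],
            },
        },
    }
]
messages = [{"role": "user", "content": "What is the weather like in Boston?"}]

# With this fix, Bedrock Anthropic models stream tool-call deltas the same way
# OpenAI models do: arguments arrive incrementally on chunk.choices[0].delta.tool_calls.
response = completion(
    model="anthropic.claude-3-sonnet-20240229-v1:0",
    messages=messages,
    tools=tools,
    tool_choice="required",
    stream=True,
)
for chunk in response:
    tool_calls = chunk.choices[0].delta.tool_calls
    if tool_calls:
        print(tool_calls[0].function.arguments)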
@@ -51,6 +51,7 @@ from litellm.types.llms.openai import (
     ChatCompletionResponseMessage,
     ChatCompletionToolCallChunk,
     ChatCompletionToolCallFunctionChunk,
+    ChatCompletionDeltaChunk,
 )
@@ -1859,29 +1860,59 @@ class AWSEventStreamDecoder:
         self.parser = EventStreamJSONParser()

     def converse_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
-        text = ""
-        tool_str = ""
-        is_finished = False
-        finish_reason = ""
-        usage: Optional[ConverseTokenUsageBlock] = None
-        if "delta" in chunk_data:
-            delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"])
-            if "text" in delta_obj:
-                text = delta_obj["text"]
-            elif "toolUse" in delta_obj:
-                tool_str = delta_obj["toolUse"]["input"]
-        elif "stopReason" in chunk_data:
-            finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
-        elif "usage" in chunk_data:
-            usage = ConverseTokenUsageBlock(**chunk_data["usage"]) # type: ignore
-        response = GenericStreamingChunk(
-            text=text,
-            tool_str=tool_str,
-            is_finished=is_finished,
-            finish_reason=finish_reason,
-            usage=usage,
-        )
-        return response
+        try:
+            text = ""
+            tool_use: Optional[ChatCompletionToolCallChunk] = None
+            is_finished = False
+            finish_reason = ""
+            usage: Optional[ConverseTokenUsageBlock] = None
+
+            index = int(chunk_data.get("contentBlockIndex", 0))
+            if "start" in chunk_data:
+                start_obj = ContentBlockStartEvent(**chunk_data["start"])
+                if (
+                    start_obj is not None
+                    and "toolUse" in start_obj
+                    and start_obj["toolUse"] is not None
+                ):
+                    tool_use = {
+                        "id": start_obj["toolUse"]["toolUseId"],
+                        "type": "function",
+                        "function": {
+                            "name": start_obj["toolUse"]["name"],
+                            "arguments": "",
+                        },
+                    }
+            elif "delta" in chunk_data:
+                delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"])
+                if "text" in delta_obj:
+                    text = delta_obj["text"]
+                elif "toolUse" in delta_obj:
+                    tool_use = {
+                        "id": None,
+                        "type": "function",
+                        "function": {
+                            "name": None,
+                            "arguments": delta_obj["toolUse"]["input"],
+                        },
+                    }
+            elif "stopReason" in chunk_data:
+                finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
+                is_finished = True
+            elif "usage" in chunk_data:
+                usage = ConverseTokenUsageBlock(**chunk_data["usage"]) # type: ignore
+
+            response = GenericStreamingChunk(
+                text=text,
+                tool_use=tool_use,
+                is_finished=is_finished,
+                finish_reason=finish_reason,
+                usage=usage,
+                index=index,
+            )
+            return response
+        except Exception as e:
+            raise Exception("Received streaming error - {}".format(str(e)))

     def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
         text = ""
@@ -1890,12 +1921,16 @@ class AWSEventStreamDecoder:
         if "outputText" in chunk_data:
             text = chunk_data["outputText"]
         # ai21 mapping
-        if "ai21" in self.model: # fake ai21 streaming
+        elif "ai21" in self.model: # fake ai21 streaming
             text = chunk_data.get("completions")[0].get("data").get("text") # type: ignore
             is_finished = True
             finish_reason = "stop"
         ######## bedrock.anthropic mappings ###############
-        elif "delta" in chunk_data:
+        elif (
+            "contentBlockIndex" in chunk_data
+            or "stopReason" in chunk_data
+            or "metrics" in chunk_data
+        ):
             return self.converse_chunk_parser(chunk_data=chunk_data)
         ######## bedrock.mistral mappings ###############
         elif "outputs" in chunk_data:
@@ -1905,7 +1940,7 @@ class AWSEventStreamDecoder:
             ):
                 text = chunk_data["outputs"][0]["text"]
             stop_reason = chunk_data.get("stop_reason", None)
-            if stop_reason != None:
+            if stop_reason is not None:
                 is_finished = True
                 finish_reason = stop_reason
         ######## bedrock.cohere mappings ###############
@@ -1926,8 +1961,9 @@ class AWSEventStreamDecoder:
             text=text,
             is_finished=is_finished,
             finish_reason=finish_reason,
-            tool_str="",
             usage=None,
+            index=0,
+            tool_use=None,
         )

     def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
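For context on the converse_chunk_parser rework above: a Bedrock Converse stream announces a tool call in a content block start event and then streams the JSON arguments through content block deltas. The sketch below is illustrative, not taken from the PR; the event payloads are hypothetical and `decoder` is assumed to be an already-constructed AWSEventStreamDecoder instance.

# Hypothetical Converse stream events, showing what the reworked parser returns.
start_event = {
    "contentBlockIndex": 1,
    "start": {"toolUse": {"toolUseId": "tooluse_abc123", "name": "get_current_weather"}},
}
first = decoder.converse_chunk_parser(chunk_data=start_event)
# first["tool_use"] == {"id": "tooluse_abc123", "type": "function",
#                       "function": {"name": "get_current_weather", "arguments": ""}}

delta_event = {
    "contentBlockIndex": 1,
    "delta": {"toolUse": {"input": '{"location": "Boston'}},
}
second = decoder.converse_chunk_parser(chunk_data=delta_event)
# second["tool_use"]["function"]["arguments"] == '{"location": "Boston'
# second["index"] == 1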
@@ -2535,7 +2535,10 @@ def streaming_and_function_calling_format_tests(idx, chunk):
     return extracted_chunk, finished


-def test_openai_streaming_and_function_calling():
+@pytest.mark.parametrize(
+    "model", ["gpt-3.5-turbo", "anthropic.claude-3-sonnet-20240229-v1:0"]
+)
+def test_streaming_and_function_calling(model):
     tools = [
         {
             "type": "function",
@@ -2556,16 +2559,21 @@ def test_openai_streaming_and_function_calling():
             },
         }
     ]
+
     messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
     try:
-        response = completion(
-            model="gpt-3.5-turbo",
+        litellm.set_verbose = True
+        response: litellm.CustomStreamWrapper = completion(
+            model=model,
             tools=tools,
             messages=messages,
             stream=True,
-        )
+            tool_choice="required",
+        ) # type: ignore
         # Add any assertions here to check the response
         for idx, chunk in enumerate(response):
+            # continue
+            print("\n{}\n".format(chunk))
             if idx == 0:
                 assert (
                     chunk.choices[0].delta.tool_calls[0].function.arguments is not None
@@ -2573,6 +2581,7 @@ def test_openai_streaming_and_function_calling():
                 assert isinstance(
                     chunk.choices[0].delta.tool_calls[0].function.arguments, str
                 )
+        # assert False
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
         raise e
@@ -3990,6 +3990,7 @@ def test_async_text_completion():
     asyncio.run(test_get_response())


+@pytest.mark.skip(reason="Tgai endpoints are unstable")
 def test_async_text_completion_together_ai():
     litellm.set_verbose = True
     print("test_async_text_completion")
@@ -1,5 +1,6 @@
 from typing import TypedDict, Any, Union, Optional, Literal, List
 import json
+from .openai import ChatCompletionToolCallChunk
 from typing_extensions import (
     Self,
     Protocol,
@@ -118,6 +119,15 @@ class ToolBlockDeltaEvent(TypedDict):
     input: str


+class ToolUseBlockStartEvent(TypedDict):
+    name: str
+    toolUseId: str
+
+
+class ContentBlockStartEvent(TypedDict, total=False):
+    toolUse: Optional[ToolUseBlockStartEvent]
+
+
 class ContentBlockDeltaEvent(TypedDict, total=False):
     """
     Either 'text' or 'toolUse' will be specified for Converse API streaming response.
@@ -138,10 +148,11 @@ class RequestObject(TypedDict, total=False):

 class GenericStreamingChunk(TypedDict):
     text: Required[str]
-    tool_str: Required[str]
+    tool_use: Optional[ChatCompletionToolCallChunk]
     is_finished: Required[bool]
     finish_reason: Required[str]
     usage: Optional[ConverseTokenUsageBlock]
+    index: int


 class Document(TypedDict):
@@ -296,14 +296,27 @@ class ListBatchRequest(TypedDict, total=False):


 class ChatCompletionToolCallFunctionChunk(TypedDict):
-    name: str
+    name: Optional[str]
     arguments: str


 class ChatCompletionToolCallChunk(TypedDict):
+    id: Optional[str]
+    type: Literal["function"]
+    function: ChatCompletionToolCallFunctionChunk
+
+
+class ChatCompletionDeltaToolCallChunk(TypedDict):
     id: str
     type: Literal["function"]
     function: ChatCompletionToolCallFunctionChunk
+    index: int


+class ChatCompletionDeltaChunk(TypedDict, total=False):
+    content: Optional[str]
+    tool_calls: List[ChatCompletionDeltaToolCallChunk]
+    role: str
+
+
 class ChatCompletionResponseMessage(TypedDict, total=False):
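The types above split complete tool calls (ChatCompletionToolCallChunk, whose id and name are nullable for mid-stream fragments) from streaming deltas (ChatCompletionDeltaToolCallChunk, which also carries an index). A small illustrative sketch of a partial chunk under the updated TypedDicts; the values are made up:

from litellm.types.llms.openai import (
    ChatCompletionToolCallChunk,
    ChatCompletionToolCallFunctionChunk,
)

# Mid-stream, only an arguments fragment is known, so both the call id and the
# function name may legitimately be None under the updated types.
partial_call: ChatCompletionToolCallChunk = {
    "id": None,
    "type": "function",
    "function": ChatCompletionToolCallFunctionChunk(
        name=None, arguments='{"location": "Bos'
    ),
}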
@@ -63,6 +63,11 @@ claude_json_str = json.dumps(json_data)
 import importlib.metadata
 from ._logging import verbose_logger
 from .types.router import LiteLLM_Params
+from .types.llms.openai import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionToolCallFunctionChunk,
+    ChatCompletionDeltaToolCallChunk,
+)
 from .integrations.traceloop import TraceloopLogger
 from .integrations.athina import AthinaLogger
 from .integrations.helicone import HeliconeLogger
@@ -3250,7 +3255,7 @@ def client(original_function):
                         stream=kwargs.get("stream", False),
                     )

-                    if kwargs.get("stream", False) == True:
+                    if kwargs.get("stream", False) is True:
                         cached_result = CustomStreamWrapper(
                             completion_stream=cached_result,
                             model=model,
@@ -11301,7 +11306,6 @@ class CustomStreamWrapper:
                     raise StopIteration
                 response_obj: GenericStreamingChunk = chunk
                 completion_obj["content"] = response_obj["text"]
-
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]

@@ -11316,6 +11320,10 @@ class CustomStreamWrapper:
                         completion_tokens=response_obj["usage"]["outputTokens"],
                         total_tokens=response_obj["usage"]["totalTokens"],
                     )
+
+                if "tool_use" in response_obj and response_obj["tool_use"] is not None:
+                    completion_obj["tool_calls"] = [response_obj["tool_use"]]
+
             elif self.custom_llm_provider == "sagemaker":
                 print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
                 response_obj = self.handle_sagemaker_stream(chunk)
@@ -11332,7 +11340,6 @@ class CustomStreamWrapper:
                 new_chunk = self.completion_stream[:chunk_size]
                 completion_obj["content"] = new_chunk
                 self.completion_stream = self.completion_stream[chunk_size:]
-                time.sleep(0.05)
             elif self.custom_llm_provider == "palm":
                 # fake streaming
                 response_obj = {}
@@ -11345,7 +11352,6 @@ class CustomStreamWrapper:
                 new_chunk = self.completion_stream[:chunk_size]
                 completion_obj["content"] = new_chunk
                 self.completion_stream = self.completion_stream[chunk_size:]
-                time.sleep(0.05)
             elif self.custom_llm_provider == "ollama":
                 response_obj = self.handle_ollama_stream(chunk)
                 completion_obj["content"] = response_obj["text"]
@@ -11432,7 +11438,7 @@ class CustomStreamWrapper:
                     # for azure, we need to pass the model from the orignal chunk
                     self.model = chunk.model
                 response_obj = self.handle_openai_chat_completion_chunk(chunk)
-                if response_obj == None:
+                if response_obj is None:
                     return
                 completion_obj["content"] = response_obj["text"]
                 print_verbose(f"completion obj content: {completion_obj['content']}")
@@ -11565,7 +11571,7 @@ class CustomStreamWrapper:
             else:
                 if (
                     self.stream_options is not None
-                    and self.stream_options["include_usage"] == True
+                    and self.stream_options["include_usage"] is True
                 ):
                     return model_response
                 return
@@ -11590,8 +11596,14 @@ class CustomStreamWrapper:
                 return model_response
             elif (
                 "content" in completion_obj
-                and isinstance(completion_obj["content"], str)
-                and len(completion_obj["content"]) > 0
+                and (
+                    isinstance(completion_obj["content"], str)
+                    and len(completion_obj["content"]) > 0
+                )
+                or (
+                    "tool_calls" in completion_obj
+                    and len(completion_obj["tool_calls"]) > 0
+                )
             ): # cannot set content of an OpenAI Object to be an empty string
                 hold, model_response_str = self.check_special_tokens(
                     chunk=completion_obj["content"],
|
@ -11647,7 +11659,7 @@ class CustomStreamWrapper:
|
||||||
else:
|
else:
|
||||||
## else
|
## else
|
||||||
completion_obj["content"] = model_response_str
|
completion_obj["content"] = model_response_str
|
||||||
if self.sent_first_chunk == False:
|
if self.sent_first_chunk is False:
|
||||||
completion_obj["role"] = "assistant"
|
completion_obj["role"] = "assistant"
|
||||||
self.sent_first_chunk = True
|
self.sent_first_chunk = True
|
||||||
model_response.choices[0].delta = Delta(**completion_obj)
|
model_response.choices[0].delta = Delta(**completion_obj)
|
||||||
|
@ -11656,7 +11668,7 @@ class CustomStreamWrapper:
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
elif self.received_finish_reason is not None:
|
elif self.received_finish_reason is not None:
|
||||||
if self.sent_last_chunk == True:
|
if self.sent_last_chunk is True:
|
||||||
raise StopIteration
|
raise StopIteration
|
||||||
# flush any remaining holding chunk
|
# flush any remaining holding chunk
|
||||||
if len(self.holding_chunk) > 0:
|
if len(self.holding_chunk) > 0:
|
||||||
|
|