forked from phoenix/litellm-mirror
Merge pull request #4106 from BerriAI/litellm_anthropic_bedrock_tool_calling_fix
fix(bedrock_httpx.py): fix tool calling for anthropic bedrock calls w/ streaming
This commit is contained in:
commit
8379d58318
6 changed files with 125 additions and 43 deletions
|
@ -51,6 +51,7 @@ from litellm.types.llms.openai import (
|
|||
ChatCompletionResponseMessage,
|
||||
ChatCompletionToolCallChunk,
|
||||
ChatCompletionToolCallFunctionChunk,
|
||||
ChatCompletionDeltaChunk,
|
||||
)
|
||||
|
||||
|
||||
|
@ -1859,29 +1860,59 @@ class AWSEventStreamDecoder:
|
|||
self.parser = EventStreamJSONParser()
|
||||
|
||||
def converse_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
|
||||
text = ""
|
||||
tool_str = ""
|
||||
is_finished = False
|
||||
finish_reason = ""
|
||||
usage: Optional[ConverseTokenUsageBlock] = None
|
||||
if "delta" in chunk_data:
|
||||
delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"])
|
||||
if "text" in delta_obj:
|
||||
text = delta_obj["text"]
|
||||
elif "toolUse" in delta_obj:
|
||||
tool_str = delta_obj["toolUse"]["input"]
|
||||
elif "stopReason" in chunk_data:
|
||||
finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
|
||||
elif "usage" in chunk_data:
|
||||
usage = ConverseTokenUsageBlock(**chunk_data["usage"]) # type: ignore
|
||||
response = GenericStreamingChunk(
|
||||
text=text,
|
||||
tool_str=tool_str,
|
||||
is_finished=is_finished,
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
)
|
||||
return response
|
||||
try:
|
||||
text = ""
|
||||
tool_use: Optional[ChatCompletionToolCallChunk] = None
|
||||
is_finished = False
|
||||
finish_reason = ""
|
||||
usage: Optional[ConverseTokenUsageBlock] = None
|
||||
|
||||
index = int(chunk_data.get("contentBlockIndex", 0))
|
||||
if "start" in chunk_data:
|
||||
start_obj = ContentBlockStartEvent(**chunk_data["start"])
|
||||
if (
|
||||
start_obj is not None
|
||||
and "toolUse" in start_obj
|
||||
and start_obj["toolUse"] is not None
|
||||
):
|
||||
tool_use = {
|
||||
"id": start_obj["toolUse"]["toolUseId"],
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": start_obj["toolUse"]["name"],
|
||||
"arguments": "",
|
||||
},
|
||||
}
|
||||
elif "delta" in chunk_data:
|
||||
delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"])
|
||||
if "text" in delta_obj:
|
||||
text = delta_obj["text"]
|
||||
elif "toolUse" in delta_obj:
|
||||
tool_use = {
|
||||
"id": None,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": None,
|
||||
"arguments": delta_obj["toolUse"]["input"],
|
||||
},
|
||||
}
|
||||
elif "stopReason" in chunk_data:
|
||||
finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
|
||||
is_finished = True
|
||||
elif "usage" in chunk_data:
|
||||
usage = ConverseTokenUsageBlock(**chunk_data["usage"]) # type: ignore
|
||||
|
||||
response = GenericStreamingChunk(
|
||||
text=text,
|
||||
tool_use=tool_use,
|
||||
is_finished=is_finished,
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
index=index,
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
raise Exception("Received streaming error - {}".format(str(e)))
|
||||
|
||||
def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
|
||||
text = ""
|
||||
|
@ -1890,12 +1921,16 @@ class AWSEventStreamDecoder:
|
|||
if "outputText" in chunk_data:
|
||||
text = chunk_data["outputText"]
|
||||
# ai21 mapping
|
||||
if "ai21" in self.model: # fake ai21 streaming
|
||||
elif "ai21" in self.model: # fake ai21 streaming
|
||||
text = chunk_data.get("completions")[0].get("data").get("text") # type: ignore
|
||||
is_finished = True
|
||||
finish_reason = "stop"
|
||||
######## bedrock.anthropic mappings ###############
|
||||
elif "delta" in chunk_data:
|
||||
elif (
|
||||
"contentBlockIndex" in chunk_data
|
||||
or "stopReason" in chunk_data
|
||||
or "metrics" in chunk_data
|
||||
):
|
||||
return self.converse_chunk_parser(chunk_data=chunk_data)
|
||||
######## bedrock.mistral mappings ###############
|
||||
elif "outputs" in chunk_data:
|
||||
|
@ -1905,7 +1940,7 @@ class AWSEventStreamDecoder:
|
|||
):
|
||||
text = chunk_data["outputs"][0]["text"]
|
||||
stop_reason = chunk_data.get("stop_reason", None)
|
||||
if stop_reason != None:
|
||||
if stop_reason is not None:
|
||||
is_finished = True
|
||||
finish_reason = stop_reason
|
||||
######## bedrock.cohere mappings ###############
|
||||
|
@ -1926,8 +1961,9 @@ class AWSEventStreamDecoder:
|
|||
text=text,
|
||||
is_finished=is_finished,
|
||||
finish_reason=finish_reason,
|
||||
tool_str="",
|
||||
usage=None,
|
||||
index=0,
|
||||
tool_use=None,
|
||||
)
|
||||
|
||||
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
|
||||
|
|
|
@ -2535,7 +2535,10 @@ def streaming_and_function_calling_format_tests(idx, chunk):
|
|||
return extracted_chunk, finished
|
||||
|
||||
|
||||
def test_openai_streaming_and_function_calling():
|
||||
@pytest.mark.parametrize(
|
||||
"model", ["gpt-3.5-turbo", "anthropic.claude-3-sonnet-20240229-v1:0"]
|
||||
)
|
||||
def test_streaming_and_function_calling(model):
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
|
@ -2556,16 +2559,21 @@ def test_openai_streaming_and_function_calling():
|
|||
},
|
||||
}
|
||||
]
|
||||
|
||||
messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
|
||||
try:
|
||||
response = completion(
|
||||
model="gpt-3.5-turbo",
|
||||
litellm.set_verbose = True
|
||||
response: litellm.CustomStreamWrapper = completion(
|
||||
model=model,
|
||||
tools=tools,
|
||||
messages=messages,
|
||||
stream=True,
|
||||
)
|
||||
tool_choice="required",
|
||||
) # type: ignore
|
||||
# Add any assertions here to check the response
|
||||
for idx, chunk in enumerate(response):
|
||||
# continue
|
||||
print("\n{}\n".format(chunk))
|
||||
if idx == 0:
|
||||
assert (
|
||||
chunk.choices[0].delta.tool_calls[0].function.arguments is not None
|
||||
|
@ -2573,6 +2581,7 @@ def test_openai_streaming_and_function_calling():
|
|||
assert isinstance(
|
||||
chunk.choices[0].delta.tool_calls[0].function.arguments, str
|
||||
)
|
||||
# assert False
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
raise e
|
||||
|
|
|
@ -3990,6 +3990,7 @@ def test_async_text_completion():
|
|||
asyncio.run(test_get_response())
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Tgai endpoints are unstable")
|
||||
def test_async_text_completion_together_ai():
|
||||
litellm.set_verbose = True
|
||||
print("test_async_text_completion")
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from typing import TypedDict, Any, Union, Optional, Literal, List
|
||||
import json
|
||||
from .openai import ChatCompletionToolCallChunk
|
||||
from typing_extensions import (
|
||||
Self,
|
||||
Protocol,
|
||||
|
@ -118,6 +119,15 @@ class ToolBlockDeltaEvent(TypedDict):
|
|||
input: str
|
||||
|
||||
|
||||
class ToolUseBlockStartEvent(TypedDict):
|
||||
name: str
|
||||
toolUseId: str
|
||||
|
||||
|
||||
class ContentBlockStartEvent(TypedDict, total=False):
|
||||
toolUse: Optional[ToolUseBlockStartEvent]
|
||||
|
||||
|
||||
class ContentBlockDeltaEvent(TypedDict, total=False):
|
||||
"""
|
||||
Either 'text' or 'toolUse' will be specified for Converse API streaming response.
|
||||
|
@ -138,10 +148,11 @@ class RequestObject(TypedDict, total=False):
|
|||
|
||||
class GenericStreamingChunk(TypedDict):
|
||||
text: Required[str]
|
||||
tool_str: Required[str]
|
||||
tool_use: Optional[ChatCompletionToolCallChunk]
|
||||
is_finished: Required[bool]
|
||||
finish_reason: Required[str]
|
||||
usage: Optional[ConverseTokenUsageBlock]
|
||||
index: int
|
||||
|
||||
|
||||
class Document(TypedDict):
|
||||
|
|
|
@ -296,14 +296,27 @@ class ListBatchRequest(TypedDict, total=False):
|
|||
|
||||
|
||||
class ChatCompletionToolCallFunctionChunk(TypedDict):
|
||||
name: str
|
||||
name: Optional[str]
|
||||
arguments: str
|
||||
|
||||
|
||||
class ChatCompletionToolCallChunk(TypedDict):
|
||||
id: Optional[str]
|
||||
type: Literal["function"]
|
||||
function: ChatCompletionToolCallFunctionChunk
|
||||
|
||||
|
||||
class ChatCompletionDeltaToolCallChunk(TypedDict):
|
||||
id: str
|
||||
type: Literal["function"]
|
||||
function: ChatCompletionToolCallFunctionChunk
|
||||
index: int
|
||||
|
||||
|
||||
class ChatCompletionDeltaChunk(TypedDict, total=False):
|
||||
content: Optional[str]
|
||||
tool_calls: List[ChatCompletionDeltaToolCallChunk]
|
||||
role: str
|
||||
|
||||
|
||||
class ChatCompletionResponseMessage(TypedDict, total=False):
|
||||
|
|
|
@ -63,6 +63,11 @@ claude_json_str = json.dumps(json_data)
|
|||
import importlib.metadata
|
||||
from ._logging import verbose_logger
|
||||
from .types.router import LiteLLM_Params
|
||||
from .types.llms.openai import (
|
||||
ChatCompletionToolCallChunk,
|
||||
ChatCompletionToolCallFunctionChunk,
|
||||
ChatCompletionDeltaToolCallChunk,
|
||||
)
|
||||
from .integrations.traceloop import TraceloopLogger
|
||||
from .integrations.athina import AthinaLogger
|
||||
from .integrations.helicone import HeliconeLogger
|
||||
|
@ -3250,7 +3255,7 @@ def client(original_function):
|
|||
stream=kwargs.get("stream", False),
|
||||
)
|
||||
|
||||
if kwargs.get("stream", False) == True:
|
||||
if kwargs.get("stream", False) is True:
|
||||
cached_result = CustomStreamWrapper(
|
||||
completion_stream=cached_result,
|
||||
model=model,
|
||||
|
@ -11301,7 +11306,6 @@ class CustomStreamWrapper:
|
|||
raise StopIteration
|
||||
response_obj: GenericStreamingChunk = chunk
|
||||
completion_obj["content"] = response_obj["text"]
|
||||
|
||||
if response_obj["is_finished"]:
|
||||
self.received_finish_reason = response_obj["finish_reason"]
|
||||
|
||||
|
@ -11316,6 +11320,10 @@ class CustomStreamWrapper:
|
|||
completion_tokens=response_obj["usage"]["outputTokens"],
|
||||
total_tokens=response_obj["usage"]["totalTokens"],
|
||||
)
|
||||
|
||||
if "tool_use" in response_obj and response_obj["tool_use"] is not None:
|
||||
completion_obj["tool_calls"] = [response_obj["tool_use"]]
|
||||
|
||||
elif self.custom_llm_provider == "sagemaker":
|
||||
print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
|
||||
response_obj = self.handle_sagemaker_stream(chunk)
|
||||
|
@ -11332,7 +11340,6 @@ class CustomStreamWrapper:
|
|||
new_chunk = self.completion_stream[:chunk_size]
|
||||
completion_obj["content"] = new_chunk
|
||||
self.completion_stream = self.completion_stream[chunk_size:]
|
||||
time.sleep(0.05)
|
||||
elif self.custom_llm_provider == "palm":
|
||||
# fake streaming
|
||||
response_obj = {}
|
||||
|
@ -11345,7 +11352,6 @@ class CustomStreamWrapper:
|
|||
new_chunk = self.completion_stream[:chunk_size]
|
||||
completion_obj["content"] = new_chunk
|
||||
self.completion_stream = self.completion_stream[chunk_size:]
|
||||
time.sleep(0.05)
|
||||
elif self.custom_llm_provider == "ollama":
|
||||
response_obj = self.handle_ollama_stream(chunk)
|
||||
completion_obj["content"] = response_obj["text"]
|
||||
|
@ -11432,7 +11438,7 @@ class CustomStreamWrapper:
|
|||
# for azure, we need to pass the model from the orignal chunk
|
||||
self.model = chunk.model
|
||||
response_obj = self.handle_openai_chat_completion_chunk(chunk)
|
||||
if response_obj == None:
|
||||
if response_obj is None:
|
||||
return
|
||||
completion_obj["content"] = response_obj["text"]
|
||||
print_verbose(f"completion obj content: {completion_obj['content']}")
|
||||
|
@ -11565,7 +11571,7 @@ class CustomStreamWrapper:
|
|||
else:
|
||||
if (
|
||||
self.stream_options is not None
|
||||
and self.stream_options["include_usage"] == True
|
||||
and self.stream_options["include_usage"] is True
|
||||
):
|
||||
return model_response
|
||||
return
|
||||
|
@ -11590,8 +11596,14 @@ class CustomStreamWrapper:
|
|||
return model_response
|
||||
elif (
|
||||
"content" in completion_obj
|
||||
and isinstance(completion_obj["content"], str)
|
||||
and len(completion_obj["content"]) > 0
|
||||
and (
|
||||
isinstance(completion_obj["content"], str)
|
||||
and len(completion_obj["content"]) > 0
|
||||
)
|
||||
or (
|
||||
"tool_calls" in completion_obj
|
||||
and len(completion_obj["tool_calls"]) > 0
|
||||
)
|
||||
): # cannot set content of an OpenAI Object to be an empty string
|
||||
hold, model_response_str = self.check_special_tokens(
|
||||
chunk=completion_obj["content"],
|
||||
|
@ -11647,7 +11659,7 @@ class CustomStreamWrapper:
|
|||
else:
|
||||
## else
|
||||
completion_obj["content"] = model_response_str
|
||||
if self.sent_first_chunk == False:
|
||||
if self.sent_first_chunk is False:
|
||||
completion_obj["role"] = "assistant"
|
||||
self.sent_first_chunk = True
|
||||
model_response.choices[0].delta = Delta(**completion_obj)
|
||||
|
@ -11656,7 +11668,7 @@ class CustomStreamWrapper:
|
|||
else:
|
||||
return
|
||||
elif self.received_finish_reason is not None:
|
||||
if self.sent_last_chunk == True:
|
||||
if self.sent_last_chunk is True:
|
||||
raise StopIteration
|
||||
# flush any remaining holding chunk
|
||||
if len(self.holding_chunk) > 0:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue