fix(anthropic.py): support streaming with function calling

Krrish Dholakia 2024-03-12 09:52:11 -07:00
parent 10f5f342bd
commit 86ed0aaba8
4 changed files with 109 additions and 7 deletions
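
In short: when tools are passed together with stream=True, the Anthropic handler now makes a regular (non-streaming) request, parses the tool call out of the complete response, and then wraps that finished ModelResponse in a one-chunk iterator handed to CustomStreamWrapper, so callers still receive an OpenAI-style stream. A minimal sketch of that pattern (simplified; the helper names here are illustrative, the real code is in the anthropic.py hunks below):

# Sketch only -- assumed, simplified helper names, not the commit's actual functions.
import litellm
from litellm.utils import CustomStreamWrapper, ModelResponse


def one_chunk_iterator(model_response):
    # the complete, already-parsed response is the only "chunk" in the stream
    yield model_response


def fake_stream_for_tool_call(complete_response: ModelResponse, model: str, logging_obj):
    # copy the finished response into a single streaming chunk (delta format)
    streaming_response = ModelResponse(stream=True)
    streaming_response.choices[0].finish_reason = complete_response.choices[0].finish_reason
    streaming_response.choices[0].delta = litellm.utils.Delta(
        content=complete_response.choices[0].message.content,
        role=complete_response.choices[0].message.role,
        tool_calls=[
            {**tc.dict(), "index": 0}
            for tc in (complete_response.choices[0].message.tool_calls or [])
        ],
    )
    # "cached_response" tells CustomStreamWrapper the chunk is already a ModelResponse,
    # not raw provider bytes (see the utils.py-style hunk at the bottom of this commit)
    return CustomStreamWrapper(
        completion_stream=one_chunk_iterator(streaming_response),
        model=model,
        custom_llm_provider="cached_response",
        logging_obj=logging_obj,
    )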

View file

@@ -4,7 +4,7 @@ from enum import Enum
 import requests, copy
 import time, uuid
 from typing import Callable, Optional
-from litellm.utils import ModelResponse, Usage, map_finish_reason
+from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
 import litellm
 from .prompt_templates.factory import (
     prompt_factory,
@@ -118,6 +118,7 @@ def completion(
     headers = validate_environment(api_key, headers)
     _is_function_call = False
     messages = copy.deepcopy(messages)
+    optional_params = copy.deepcopy(optional_params)
     if model in custom_prompt_dict:
         # check if the model has a registered custom prompt
         model_prompt_details = custom_prompt_dict[model]
@@ -161,6 +162,8 @@ def completion(
         )  # add the anthropic tool calling prompt to the system prompt
         optional_params.pop("tools")

+    stream = optional_params.pop("stream", None)
+
     data = {
         "model": model,
         "messages": messages,
@@ -179,7 +182,10 @@ def completion(
     )

     ## COMPLETION CALL
-    if "stream" in optional_params and optional_params["stream"] == True:
+    if (
+        stream is not None and stream == True and _is_function_call == False
+    ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
+        data["stream"] = stream
         response = requests.post(
             api_base,
             headers=headers,
@@ -255,6 +261,39 @@ def completion(
                 completion_response["stop_reason"]
             )

+        if _is_function_call == True and stream is not None and stream == True:
+            # return an iterator
+            streaming_model_response = ModelResponse(stream=True)
+            streaming_model_response.choices[0].finish_reason = model_response.choices[
+                0
+            ].finish_reason
+            streaming_model_response.choices[0].index = model_response.choices[0].index
+            _tool_calls = []
+            if isinstance(model_response.choices[0], litellm.Choices):
+                if getattr(
+                    model_response.choices[0].message, "tool_calls", None
+                ) is not None and isinstance(
+                    model_response.choices[0].message.tool_calls, list
+                ):
+                    for tool_call in model_response.choices[0].message.tool_calls:
+                        _tool_call = {**tool_call.dict(), "index": 0}
+                        _tool_calls.append(_tool_call)
+                delta_obj = litellm.utils.Delta(
+                    content=getattr(model_response.choices[0].message, "content", None),
+                    role=model_response.choices[0].message.role,
+                    tool_calls=_tool_calls,
+                )
+                streaming_model_response.choices[0].delta = delta_obj
+                completion_stream = model_response_iterator(
+                    model_response=streaming_model_response
+                )
+                return CustomStreamWrapper(
+                    completion_stream=completion_stream,
+                    model=model,
+                    custom_llm_provider="cached_response",
+                    logging_obj=logging_obj,
+                )
+
         ## CALCULATING USAGE
         prompt_tokens = completion_response["usage"]["input_tokens"]
         completion_tokens = completion_response["usage"]["output_tokens"]
@@ -271,6 +310,10 @@ def completion(
         return model_response


+def model_response_iterator(model_response):
+    yield model_response
+
+
 def embedding():
     # logic for parsing in - calling - parsing out model embedding calls
     pass

View file

@@ -1073,7 +1073,11 @@ def completion(
                 logging_obj=logging,
                 headers=headers,
             )
-            if "stream" in optional_params and optional_params["stream"] == True:
+            if (
+                "stream" in optional_params
+                and optional_params["stream"] == True
+                and not isinstance(response, CustomStreamWrapper)
+            ):
                 # don't try to access stream object,
                 response = CustomStreamWrapper(
                     response,

View file

@@ -1749,7 +1749,7 @@ class Chunk(BaseModel):
     object: str
     created: int
     model: str
-    system_fingerprint: str
+    # system_fingerprint: str
     choices: List[Choices]
@@ -1869,7 +1869,7 @@ class Chunk3(BaseModel):
    object: str
     created: int
     model: str
-    system_fingerprint: str
+    # system_fingerprint: str
     choices: List[Choices3]
@@ -2032,3 +2032,56 @@ async def test_azure_astreaming_and_function_calling():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
         raise e
+
+
+def test_completion_claude_3_function_call_with_streaming():
+    # litellm.set_verbose = True
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+    messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
+    try:
+        # test without max tokens
+        response = completion(
+            model="claude-3-opus-20240229",
+            messages=messages,
+            tools=tools,
+            tool_choice="auto",
+            stream=True,
+        )
+        idx = 0
+        for chunk in response:
+            # print(f"chunk: {chunk}")
+            if idx == 0:
+                assert (
+                    chunk.choices[0].delta.tool_calls[0].function.arguments is not None
+                )
+                assert isinstance(
+                    chunk.choices[0].delta.tool_calls[0].function.arguments, str
+                )
+                validate_first_streaming_function_calling_chunk(chunk=chunk)
+            elif idx == 1:
+                validate_second_streaming_function_calling_chunk(chunk=chunk)
+            elif chunk.choices[0].finish_reason is not None:  # last chunk
+                validate_final_streaming_function_calling_chunk(chunk=chunk)
+            idx += 1
+        # raise Exception("it worked!")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
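
For reference, a consumer-side sketch of what the test above exercises: the first streamed chunk carries delta.tool_calls with the arguments as a JSON string, and the final chunk carries the finish_reason. Variable names below are illustrative and not part of the commit; tools and messages are assumed to be the same definitions used in the test.

# Illustrative usage only (assumes the tools/messages defined in the test above).
from litellm import completion

response = completion(
    model="claude-3-opus-20240229",
    messages=messages,
    tools=tools,
    tool_choice="auto",
    stream=True,
)

function_name, arguments = None, ""
for chunk in response:
    delta = chunk.choices[0].delta
    if delta is not None and delta.tool_calls:
        tool_call = delta.tool_calls[0]
        if tool_call.function.name is not None:
            function_name = tool_call.function.name
        if tool_call.function.arguments is not None:
            arguments += tool_call.function.arguments  # accumulate the JSON argument string

print(function_name, arguments)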

View file

@@ -9364,8 +9364,10 @@ class CustomStreamWrapper:
     def __next__(self):
         try:
             while True:
-                if isinstance(self.completion_stream, str) or isinstance(
-                    self.completion_stream, bytes
+                if (
+                    isinstance(self.completion_stream, str)
+                    or isinstance(self.completion_stream, bytes)
+                    or isinstance(self.completion_stream, ModelResponse)
                 ):
                     chunk = self.completion_stream
                 else: