fix(anthropic.py): support streaming with function calling

Krrish Dholakia 2024-03-12 09:52:11 -07:00
parent 10f5f342bd
commit 86ed0aaba8
4 changed files with 109 additions and 7 deletions
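
A minimal usage sketch of what this commit enables, mirroring the test added further down: a Claude 3 call with tools and stream=True now yields chunks whose delta.tool_calls carry the complete tool call. It assumes litellm is installed and an Anthropic API key is configured in the environment (key setup is not part of this diff); the tool schema is abridged from the new test.

import litellm

# Assumption: ANTHROPIC_API_KEY is set in the environment (not shown in this diff).
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "City and state"},
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location"],
            },
        },
    }
]

# With this fix, stream=True plus tools is "faked" for Claude 3: the complete
# tool-call response is replayed through CustomStreamWrapper, so the first
# chunk's delta.tool_calls already carries the full arguments string.
response = litellm.completion(
    model="claude-3-opus-20240229",
    messages=[{"role": "user", "content": "What's the weather like in Boston today?"}],
    tools=tools,
    tool_choice="auto",
    stream=True,
)
for chunk in response:
    if chunk.choices[0].delta.tool_calls:
        print(chunk.choices[0].delta.tool_calls[0].function.arguments)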

@@ -4,7 +4,7 @@ from enum import Enum
 import requests, copy
 import time, uuid
 from typing import Callable, Optional
-from litellm.utils import ModelResponse, Usage, map_finish_reason
+from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
 import litellm
 from .prompt_templates.factory import (
     prompt_factory,
@@ -118,6 +118,7 @@ def completion(
     headers = validate_environment(api_key, headers)
     _is_function_call = False
     messages = copy.deepcopy(messages)
+    optional_params = copy.deepcopy(optional_params)
     if model in custom_prompt_dict:
         # check if the model has a registered custom prompt
         model_prompt_details = custom_prompt_dict[model]
@@ -161,6 +162,8 @@ def completion(
         )  # add the anthropic tool calling prompt to the system prompt
         optional_params.pop("tools")
 
+    stream = optional_params.pop("stream", None)
+
     data = {
         "model": model,
         "messages": messages,
@@ -179,7 +182,10 @@ def completion(
     )
     ## COMPLETION CALL
-    if "stream" in optional_params and optional_params["stream"] == True:
+    if (
+        stream is not None and stream == True and _is_function_call == False
+    ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
+        data["stream"] = stream
         response = requests.post(
             api_base,
             headers=headers,
@@ -255,6 +261,39 @@ def completion(
                 completion_response["stop_reason"]
             )
 
+        if _is_function_call == True and stream is not None and stream == True:
+            # return an iterator
+            streaming_model_response = ModelResponse(stream=True)
+            streaming_model_response.choices[0].finish_reason = model_response.choices[
+                0
+            ].finish_reason
+            streaming_model_response.choices[0].index = model_response.choices[0].index
+            _tool_calls = []
+            if isinstance(model_response.choices[0], litellm.Choices):
+                if getattr(
+                    model_response.choices[0].message, "tool_calls", None
+                ) is not None and isinstance(
+                    model_response.choices[0].message.tool_calls, list
+                ):
+                    for tool_call in model_response.choices[0].message.tool_calls:
+                        _tool_call = {**tool_call.dict(), "index": 0}
+                        _tool_calls.append(_tool_call)
+            delta_obj = litellm.utils.Delta(
+                content=getattr(model_response.choices[0].message, "content", None),
+                role=model_response.choices[0].message.role,
+                tool_calls=_tool_calls,
+            )
+            streaming_model_response.choices[0].delta = delta_obj
+            completion_stream = model_response_iterator(
+                model_response=streaming_model_response
+            )
+            return CustomStreamWrapper(
+                completion_stream=completion_stream,
+                model=model,
+                custom_llm_provider="cached_response",
+                logging_obj=logging_obj,
+            )
+
     ## CALCULATING USAGE
     prompt_tokens = completion_response["usage"]["input_tokens"]
     completion_tokens = completion_response["usage"]["output_tokens"]
@@ -271,6 +310,10 @@ def completion(
         return model_response
 
 
+def model_response_iterator(model_response):
+    yield model_response
+
+
 def embedding():
     # logic for parsing in - calling - parsing out model embedding calls
     pass

@@ -1073,7 +1073,11 @@ def completion(
                 logging_obj=logging,
                 headers=headers,
             )
-            if "stream" in optional_params and optional_params["stream"] == True:
+            if (
+                "stream" in optional_params
+                and optional_params["stream"] == True
+                and not isinstance(response, CustomStreamWrapper)
+            ):
                 # don't try to access stream object,
                 response = CustomStreamWrapper(
                     response,

@@ -1749,7 +1749,7 @@ class Chunk(BaseModel):
     object: str
     created: int
     model: str
-    system_fingerprint: str
+    # system_fingerprint: str
     choices: List[Choices]
@@ -1869,7 +1869,7 @@ class Chunk3(BaseModel):
     object: str
     created: int
     model: str
-    system_fingerprint: str
+    # system_fingerprint: str
     choices: List[Choices3]
@@ -2032,3 +2032,56 @@ async def test_azure_astreaming_and_function_calling():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
         raise e
+
+
+def test_completion_claude_3_function_call_with_streaming():
+    # litellm.set_verbose = True
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+    messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
+    try:
+        # test without max tokens
+        response = completion(
+            model="claude-3-opus-20240229",
+            messages=messages,
+            tools=tools,
+            tool_choice="auto",
+            stream=True,
+        )
+        idx = 0
+        for chunk in response:
+            # print(f"chunk: {chunk}")
+            if idx == 0:
+                assert (
+                    chunk.choices[0].delta.tool_calls[0].function.arguments is not None
+                )
+                assert isinstance(
+                    chunk.choices[0].delta.tool_calls[0].function.arguments, str
+                )
+                validate_first_streaming_function_calling_chunk(chunk=chunk)
+            elif idx == 1:
+                validate_second_streaming_function_calling_chunk(chunk=chunk)
+            elif chunk.choices[0].finish_reason is not None:  # last chunk
+                validate_final_streaming_function_calling_chunk(chunk=chunk)
+            idx += 1
+        # raise Exception("it worked!")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")

@@ -9364,8 +9364,10 @@ class CustomStreamWrapper:
     def __next__(self):
         try:
             while True:
-                if isinstance(self.completion_stream, str) or isinstance(
-                    self.completion_stream, bytes
+                if (
+                    isinstance(self.completion_stream, str)
+                    or isinstance(self.completion_stream, bytes)
+                    or isinstance(self.completion_stream, ModelResponse)
                 ):
                     chunk = self.completion_stream
                 else: