forked from phoenix/litellm-mirror
fix(anthropic.py): support streaming with function calling
parent 10f5f342bd · commit 86ed0aaba8
4 changed files with 109 additions and 7 deletions
@@ -4,7 +4,7 @@ from enum import Enum
 import requests, copy
 import time, uuid
 from typing import Callable, Optional
-from litellm.utils import ModelResponse, Usage, map_finish_reason
+from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
 import litellm
 from .prompt_templates.factory import (
     prompt_factory,
@@ -118,6 +118,7 @@ def completion(
     headers = validate_environment(api_key, headers)
+    _is_function_call = False
     messages = copy.deepcopy(messages)
     optional_params = copy.deepcopy(optional_params)
     if model in custom_prompt_dict:
         # check if the model has a registered custom prompt
         model_prompt_details = custom_prompt_dict[model]
@@ -161,6 +162,8 @@ def completion(
         )  # add the anthropic tool calling prompt to the system prompt
         optional_params.pop("tools")

+    stream = optional_params.pop("stream", None)
+
     data = {
         "model": model,
         "messages": messages,
@@ -179,7 +182,10 @@ def completion(
     )

     ## COMPLETION CALL
-    if "stream" in optional_params and optional_params["stream"] == True:
+    if (
+        stream is not None and stream == True and _is_function_call == False
+    ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
+        data["stream"] = stream
         response = requests.post(
             api_base,
             headers=headers,
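The reworked condition only streams the raw HTTP response when the caller asked for streaming and no tools are in play; with tools, the request is sent without streaming and a stream is simulated afterwards, since parsing tool calls into the OpenAI format needs the complete content blocks. A minimal sketch of that decision, with illustrative names that are not part of litellm:

def choose_transport(stream_requested, is_function_call):
    # Real HTTP streaming is only safe when no tool-call parsing is needed;
    # otherwise send a blocking request and fake the stream afterwards.
    if stream_requested and not is_function_call:
        return "http-streaming"
    return "blocking-then-fake-stream"

assert choose_transport(True, False) == "http-streaming"
assert choose_transport(True, True) == "blocking-then-fake-stream"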
@@ -255,6 +261,39 @@ def completion(
                 completion_response["stop_reason"]
             )

+        if _is_function_call == True and stream is not None and stream == True:
+            # return an iterator
+            streaming_model_response = ModelResponse(stream=True)
+            streaming_model_response.choices[0].finish_reason = model_response.choices[
+                0
+            ].finish_reason
+            streaming_model_response.choices[0].index = model_response.choices[0].index
+            _tool_calls = []
+            if isinstance(model_response.choices[0], litellm.Choices):
+                if getattr(
+                    model_response.choices[0].message, "tool_calls", None
+                ) is not None and isinstance(
+                    model_response.choices[0].message.tool_calls, list
+                ):
+                    for tool_call in model_response.choices[0].message.tool_calls:
+                        _tool_call = {**tool_call.dict(), "index": 0}
+                        _tool_calls.append(_tool_call)
+            delta_obj = litellm.utils.Delta(
+                content=getattr(model_response.choices[0].message, "content", None),
+                role=model_response.choices[0].message.role,
+                tool_calls=_tool_calls,
+            )
+            streaming_model_response.choices[0].delta = delta_obj
+            completion_stream = model_response_iterator(
+                model_response=streaming_model_response
+            )
+            return CustomStreamWrapper(
+                completion_stream=completion_stream,
+                model=model,
+                custom_llm_provider="cached_response",
+                logging_obj=logging_obj,
+            )
+
         ## CALCULATING USAGE
         prompt_tokens = completion_response["usage"]["input_tokens"]
         completion_tokens = completion_response["usage"]["output_tokens"]
@@ -271,6 +310,10 @@ def completion(
         return model_response


+def model_response_iterator(model_response):
+    yield model_response
+
+
 def embedding():
     # logic for parsing in - calling - parsing out model embedding calls
     pass
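Taken together, these two hunks are the "fake streaming" path: the complete Anthropic response is packed into a single streaming-style chunk, model_response_iterator yields it exactly once, and CustomStreamWrapper makes that one-item generator look like a normal stream to the caller. A self-contained sketch of the same pattern, using illustrative names (make_single_chunk, fake_stream) rather than litellm's internals:

from typing import Any, Dict, Iterator

def make_single_chunk(full_response: Dict[str, Any]) -> Dict[str, Any]:
    # Pack a complete response into one stream-style chunk: the whole
    # message becomes the delta of the first (and only) chunk.
    return {
        "choices": [
            {
                "index": 0,
                "delta": {
                    "role": full_response.get("role", "assistant"),
                    "content": full_response.get("content"),
                    "tool_calls": full_response.get("tool_calls", []),
                },
                "finish_reason": full_response.get("finish_reason"),
            }
        ]
    }

def fake_stream(full_response: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
    # Yield exactly once, so consumers written against an iterator of
    # chunks work unchanged.
    yield make_single_chunk(full_response)

# Usage: iterate it like any other stream; there is only one chunk.
for chunk in fake_stream(
    {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "id": "call_1",
                "type": "function",
                "function": {"name": "get_current_weather", "arguments": '{"location": "Boston"}'},
            }
        ],
        "finish_reason": "tool_calls",
    }
):
    print(chunk["choices"][0]["delta"]["tool_calls"])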
@@ -1073,7 +1073,11 @@ def completion(
                 logging_obj=logging,
                 headers=headers,
             )
-            if "stream" in optional_params and optional_params["stream"] == True:
+            if (
+                "stream" in optional_params
+                and optional_params["stream"] == True
+                and not isinstance(response, CustomStreamWrapper)
+            ):
                 # don't try to access stream object,
                 response = CustomStreamWrapper(
                     response,
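This hunk in the calling code makes the wrapping idempotent: when the Anthropic handler has already returned a CustomStreamWrapper (the fake-streaming case above), the generic streaming path must not wrap it a second time. A small sketch of the guard, with an illustrative StreamWrapper class standing in for litellm's:

class StreamWrapper:
    # Illustrative stand-in for a stream wrapper class.
    def __init__(self, stream):
        self._stream = iter(stream)

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._stream)

def maybe_wrap(response, wants_stream):
    # Wrap only if the caller asked for streaming AND the provider
    # handler has not already wrapped the response itself.
    if wants_stream and not isinstance(response, StreamWrapper):
        return StreamWrapper(response)
    return response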
@@ -1749,7 +1749,7 @@ class Chunk(BaseModel):
     object: str
     created: int
     model: str
-    system_fingerprint: str
+    # system_fingerprint: str
     choices: List[Choices]


@@ -1869,7 +1869,7 @@ class Chunk3(BaseModel):
     object: str
     created: int
     model: str
-    system_fingerprint: str
+    # system_fingerprint: str
    choices: List[Choices3]


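Commenting out system_fingerprint in these test-side pydantic models relaxes validation, since Anthropic streaming chunks do not carry that field. An alternative (not what this commit does) would be to keep the field but make it optional; a minimal sketch:

from typing import List, Optional
from pydantic import BaseModel

class ChoiceStub(BaseModel):
    # Illustrative stand-in for the Choices model; delta/finish_reason fields elided.
    index: int

class ChunkStub(BaseModel):
    object: str
    created: int
    model: str
    # Optional with a default, so providers that omit the field still validate.
    system_fingerprint: Optional[str] = None
    choices: List[ChoiceStub]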
@@ -2032,3 +2032,56 @@ async def test_azure_astreaming_and_function_calling():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
         raise e
+
+
+def test_completion_claude_3_function_call_with_streaming():
+    # litellm.set_verbose = True
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+    messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
+    try:
+        # test without max tokens
+        response = completion(
+            model="claude-3-opus-20240229",
+            messages=messages,
+            tools=tools,
+            tool_choice="auto",
+            stream=True,
+        )
+        idx = 0
+        for chunk in response:
+            # print(f"chunk: {chunk}")
+            if idx == 0:
+                assert (
+                    chunk.choices[0].delta.tool_calls[0].function.arguments is not None
+                )
+                assert isinstance(
+                    chunk.choices[0].delta.tool_calls[0].function.arguments, str
+                )
+                validate_first_streaming_function_calling_chunk(chunk=chunk)
+            elif idx == 1:
+                validate_second_streaming_function_calling_chunk(chunk=chunk)
+            elif chunk.choices[0].finish_reason is not None:  # last chunk
+                validate_final_streaming_function_calling_chunk(chunk=chunk)
+            idx += 1
+        # raise Exception("it worked!")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
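For context, a hedged sketch of how a caller might reassemble the streamed tool call into a complete function call; collect_tool_call is illustrative, not a litellm helper. With the fake-streaming path there is only one content chunk, but the same loop also works for providers that spread argument fragments over many chunks (the attribute access mirrors the chunk shape the test asserts on):

def collect_tool_call(chunks):
    # Accumulate the function name and argument fragments from the
    # OpenAI-style delta.tool_calls entries across all chunks.
    name = None
    arguments = ""
    for chunk in chunks:
        delta = chunk.choices[0].delta
        tool_calls = getattr(delta, "tool_calls", None)
        if not tool_calls:
            continue
        call = tool_calls[0]
        if call.function.name:
            name = call.function.name
        if call.function.arguments:
            arguments += call.function.arguments
    return name, arguments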
@@ -9364,8 +9364,10 @@ class CustomStreamWrapper:
     def __next__(self):
         try:
             while True:
-                if isinstance(self.completion_stream, str) or isinstance(
-                    self.completion_stream, bytes
+                if (
+                    isinstance(self.completion_stream, str)
+                    or isinstance(self.completion_stream, bytes)
+                    or isinstance(self.completion_stream, ModelResponse)
                 ):
                     chunk = self.completion_stream
                 else:
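With this change, CustomStreamWrapper.__next__ also accepts a complete ModelResponse as its completion_stream and treats it as the single chunk to emit, which is exactly what the fake-streaming path hands it. A minimal sketch of the same idea with an illustrative class (not litellm's actual wrapper):

class SingleChunkStream:
    # Treat one pre-built response object as a one-item stream.
    def __init__(self, completion_stream):
        self.completion_stream = completion_stream
        self._done = False

    def __iter__(self):
        return self

    def __next__(self):
        # If the underlying "stream" is already a complete response
        # (here modeled as str, bytes, or dict), emit it once and stop.
        if isinstance(self.completion_stream, (str, bytes, dict)):
            if self._done:
                raise StopIteration
            self._done = True
            return self.completion_stream
        return next(self.completion_stream)

# Usage: one complete response comes back as the only chunk.
stream = SingleChunkStream({"choices": [{"delta": {"content": "hi"}}]})
print(list(stream))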