Merge pull request #2472 from BerriAI/litellm_anthropic_streaming_tool_calling

fix(anthropic.py): support claude-3 streaming with function calling
This commit is contained in:
Krish Dholakia 2024-03-12 21:36:01 -07:00 committed by GitHub
commit ce3c865adb
5 changed files with 150 additions and 27 deletions

View file

@ -4,7 +4,7 @@ from enum import Enum
import requests, copy
import time, uuid
from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage, map_finish_reason
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
import litellm
from .prompt_templates.factory import (
prompt_factory,
@ -118,6 +118,7 @@ def completion(
headers = validate_environment(api_key, headers)
_is_function_call = False
messages = copy.deepcopy(messages)
optional_params = copy.deepcopy(optional_params)
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
@ -161,6 +162,8 @@ def completion(
) # add the anthropic tool calling prompt to the system prompt
optional_params.pop("tools")
stream = optional_params.pop("stream", None)
data = {
"model": model,
"messages": messages,
@ -177,14 +180,18 @@ def completion(
"headers": headers,
},
)
print_verbose(f"_is_function_call: {_is_function_call}")
## COMPLETION CALL
if "stream" in optional_params and optional_params["stream"] == True:
if (
stream is not None and stream == True and _is_function_call == False
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
print_verbose(f"makes anthropic streaming POST request")
data["stream"] = stream
response = requests.post(
api_base,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"],
stream=stream,
)
if response.status_code != 200:
@ -255,6 +262,51 @@ def completion(
completion_response["stop_reason"]
)
print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
if _is_function_call == True and stream is not None and stream == True:
print_verbose(f"INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
# return an iterator
streaming_model_response = ModelResponse(stream=True)
streaming_model_response.choices[0].finish_reason = model_response.choices[
0
].finish_reason
# streaming_model_response.choices = [litellm.utils.StreamingChoices()]
streaming_choice = litellm.utils.StreamingChoices()
streaming_choice.index = model_response.choices[0].index
_tool_calls = []
print_verbose(
f"type of model_response.choices[0]: {type(model_response.choices[0])}"
)
print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
if isinstance(model_response.choices[0], litellm.Choices):
if getattr(
model_response.choices[0].message, "tool_calls", None
) is not None and isinstance(
model_response.choices[0].message.tool_calls, list
):
for tool_call in model_response.choices[0].message.tool_calls:
_tool_call = {**tool_call.dict(), "index": 0}
_tool_calls.append(_tool_call)
delta_obj = litellm.utils.Delta(
content=getattr(model_response.choices[0].message, "content", None),
role=model_response.choices[0].message.role,
tool_calls=_tool_calls,
)
streaming_choice.delta = delta_obj
streaming_model_response.choices = [streaming_choice]
completion_stream = model_response_iterator(
model_response=streaming_model_response
)
print_verbose(
f"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
)
return CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="cached_response",
logging_obj=logging_obj,
)
## CALCULATING USAGE
prompt_tokens = completion_response["usage"]["input_tokens"]
completion_tokens = completion_response["usage"]["output_tokens"]
@ -271,6 +323,10 @@ def completion(
return model_response
def model_response_iterator(model_response):
yield model_response
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass

View file

@ -1144,7 +1144,11 @@ def completion(
logging_obj=logging,
headers=headers,
)
if "stream" in optional_params and optional_params["stream"] == True:
if (
"stream" in optional_params
and optional_params["stream"] == True
and not isinstance(response, CustomStreamWrapper)
):
# don't try to access stream object,
response = CustomStreamWrapper(
response,

View file

@ -36,32 +36,32 @@ test_completion.py . [100%]
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:180: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True)
../proxy/_types.py:235
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:235: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
../proxy/_types.py:241
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:241: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True)
../proxy/_types.py:247
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:247: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
../proxy/_types.py:253
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:253: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True)
../proxy/_types.py:282
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:282: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
../proxy/_types.py:292
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:292: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True)
../proxy/_types.py:308
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:308: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
../proxy/_types.py:319
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:319: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True)
../proxy/_types.py:557
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:557: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
../proxy/_types.py:570
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:570: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True)
../proxy/_types.py:578
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:578: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
../proxy/_types.py:591
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:591: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True)
../utils.py:36
/Users/krrishdholakia/Documents/litellm/litellm/utils.py:36: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
../utils.py:35
/Users/krrishdholakia/Documents/litellm/litellm/utils.py:35: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
import pkg_resources
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871: 10 warnings
@ -109,5 +109,11 @@ test_completion.py . [100%]
/Users/krrishdholakia/Documents/litellm/litellm/llms/prompt_templates/factory.py:6: DeprecationWarning: 'imghdr' is deprecated and slated for removal in Python 3.13
import imghdr, base64
test_completion.py::test_completion_claude_3_stream
../utils.py:3249
../utils.py:3249
/Users/krrishdholakia/Documents/litellm/litellm/utils.py:3249: DeprecationWarning: open_text is deprecated. Use files() instead. Refer to https://importlib-resources.readthedocs.io/en/latest/using.html#migrating-from-legacy for migration advice.
with resources.open_text(
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================== 1 passed, 43 warnings in 4.47s ========================
======================== 1 passed, 46 warnings in 3.14s ========================

View file

@ -1749,7 +1749,7 @@ class Chunk(BaseModel):
object: str
created: int
model: str
system_fingerprint: str
# system_fingerprint: str
choices: List[Choices]
@ -1869,7 +1869,7 @@ class Chunk3(BaseModel):
object: str
created: int
model: str
system_fingerprint: str
# system_fingerprint: str
choices: List[Choices3]
@ -2032,3 +2032,56 @@ async def test_azure_astreaming_and_function_calling():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
raise e
def test_completion_claude_3_function_call_with_streaming():
litellm.set_verbose = True
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
try:
# test without max tokens
response = completion(
model="claude-3-opus-20240229",
messages=messages,
tools=tools,
tool_choice="auto",
stream=True,
)
idx = 0
for chunk in response:
# print(f"chunk: {chunk}")
if idx == 0:
assert (
chunk.choices[0].delta.tool_calls[0].function.arguments is not None
)
assert isinstance(
chunk.choices[0].delta.tool_calls[0].function.arguments, str
)
validate_first_streaming_function_calling_chunk(chunk=chunk)
elif idx == 1:
validate_second_streaming_function_calling_chunk(chunk=chunk)
elif chunk.choices[0].finish_reason is not None: # last chunk
validate_final_streaming_function_calling_chunk(chunk=chunk)
idx += 1
# raise Exception("it worked!")
except Exception as e:
pytest.fail(f"Error occurred: {e}")

View file

@ -480,12 +480,12 @@ class ModelResponse(OpenAIObject):
object=None,
system_fingerprint=None,
usage=None,
stream=False,
stream=None,
response_ms=None,
hidden_params=None,
**params,
):
if stream:
if stream is not None and stream == True:
object = "chat.completion.chunk"
choices = [StreamingChoices()]
else:
@ -9471,14 +9471,18 @@ class CustomStreamWrapper:
def __next__(self):
try:
while True:
if isinstance(self.completion_stream, str) or isinstance(
self.completion_stream, bytes
if (
isinstance(self.completion_stream, str)
or isinstance(self.completion_stream, bytes)
or isinstance(self.completion_stream, ModelResponse)
):
chunk = self.completion_stream
else:
chunk = next(self.completion_stream)
if chunk is not None and chunk != b"":
print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}")
print_verbose(
f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}; custom_llm_provider: {self.custom_llm_provider}"
)
response: Optional[ModelResponse] = self.chunk_creator(chunk=chunk)
print_verbose(f"PROCESSED CHUNK POST CHUNK CREATOR: {response}")