Merge pull request #2472 from BerriAI/litellm_anthropic_streaming_tool_calling

fix(anthropic.py): support claude-3 streaming with function calling
This commit is contained in:
Krish Dholakia 2024-03-12 21:36:01 -07:00 committed by GitHub
commit ce3c865adb
5 changed files with 150 additions and 27 deletions

View file

@ -4,7 +4,7 @@ from enum import Enum
import requests, copy import requests, copy
import time, uuid import time, uuid
from typing import Callable, Optional from typing import Callable, Optional
from litellm.utils import ModelResponse, Usage, map_finish_reason from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
import litellm import litellm
from .prompt_templates.factory import ( from .prompt_templates.factory import (
prompt_factory, prompt_factory,
@ -118,6 +118,7 @@ def completion(
headers = validate_environment(api_key, headers) headers = validate_environment(api_key, headers)
_is_function_call = False _is_function_call = False
messages = copy.deepcopy(messages) messages = copy.deepcopy(messages)
optional_params = copy.deepcopy(optional_params)
if model in custom_prompt_dict: if model in custom_prompt_dict:
# check if the model has a registered custom prompt # check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model] model_prompt_details = custom_prompt_dict[model]
@ -161,6 +162,8 @@ def completion(
) # add the anthropic tool calling prompt to the system prompt ) # add the anthropic tool calling prompt to the system prompt
optional_params.pop("tools") optional_params.pop("tools")
stream = optional_params.pop("stream", None)
data = { data = {
"model": model, "model": model,
"messages": messages, "messages": messages,
@ -177,14 +180,18 @@ def completion(
"headers": headers, "headers": headers,
}, },
) )
print_verbose(f"_is_function_call: {_is_function_call}")
## COMPLETION CALL ## COMPLETION CALL
if "stream" in optional_params and optional_params["stream"] == True: if (
stream is not None and stream == True and _is_function_call == False
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
print_verbose(f"makes anthropic streaming POST request")
data["stream"] = stream
response = requests.post( response = requests.post(
api_base, api_base,
headers=headers, headers=headers,
data=json.dumps(data), data=json.dumps(data),
stream=optional_params["stream"], stream=stream,
) )
if response.status_code != 200: if response.status_code != 200:
@ -255,6 +262,51 @@ def completion(
completion_response["stop_reason"] completion_response["stop_reason"]
) )
print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
if _is_function_call == True and stream is not None and stream == True:
print_verbose(f"INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
# return an iterator
streaming_model_response = ModelResponse(stream=True)
streaming_model_response.choices[0].finish_reason = model_response.choices[
0
].finish_reason
# streaming_model_response.choices = [litellm.utils.StreamingChoices()]
streaming_choice = litellm.utils.StreamingChoices()
streaming_choice.index = model_response.choices[0].index
_tool_calls = []
print_verbose(
f"type of model_response.choices[0]: {type(model_response.choices[0])}"
)
print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
if isinstance(model_response.choices[0], litellm.Choices):
if getattr(
model_response.choices[0].message, "tool_calls", None
) is not None and isinstance(
model_response.choices[0].message.tool_calls, list
):
for tool_call in model_response.choices[0].message.tool_calls:
_tool_call = {**tool_call.dict(), "index": 0}
_tool_calls.append(_tool_call)
delta_obj = litellm.utils.Delta(
content=getattr(model_response.choices[0].message, "content", None),
role=model_response.choices[0].message.role,
tool_calls=_tool_calls,
)
streaming_choice.delta = delta_obj
streaming_model_response.choices = [streaming_choice]
completion_stream = model_response_iterator(
model_response=streaming_model_response
)
print_verbose(
f"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
)
return CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="cached_response",
logging_obj=logging_obj,
)
## CALCULATING USAGE ## CALCULATING USAGE
prompt_tokens = completion_response["usage"]["input_tokens"] prompt_tokens = completion_response["usage"]["input_tokens"]
completion_tokens = completion_response["usage"]["output_tokens"] completion_tokens = completion_response["usage"]["output_tokens"]
@ -271,6 +323,10 @@ def completion(
return model_response return model_response
def model_response_iterator(model_response):
    """Yield ``model_response`` exactly once, as a single-item stream.

    Wraps an already-complete response object in a generator so that it can
    be handed to code that expects an iterator of chunks.

    Args:
        model_response: The object to emit; it is yielded unchanged.

    Yields:
        The ``model_response`` argument, one time only.
    """
    yield model_response
def embedding(): def embedding():
# logic for parsing in - calling - parsing out model embedding calls # logic for parsing in - calling - parsing out model embedding calls
pass pass

View file

@ -1144,7 +1144,11 @@ def completion(
logging_obj=logging, logging_obj=logging,
headers=headers, headers=headers,
) )
if "stream" in optional_params and optional_params["stream"] == True: if (
"stream" in optional_params
and optional_params["stream"] == True
and not isinstance(response, CustomStreamWrapper)
):
# don't try to access stream object, # don't try to access stream object,
response = CustomStreamWrapper( response = CustomStreamWrapper(
response, response,

View file

@ -36,32 +36,32 @@ test_completion.py . [100%]
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:180: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/ /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:180: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True) @root_validator(pre=True)
../proxy/_types.py:235 ../proxy/_types.py:241
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:235: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/ /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:241: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True) @root_validator(pre=True)
../proxy/_types.py:247 ../proxy/_types.py:253
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:247: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/ /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:253: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True) @root_validator(pre=True)
../proxy/_types.py:282 ../proxy/_types.py:292
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:282: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/ /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:292: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True) @root_validator(pre=True)
../proxy/_types.py:308 ../proxy/_types.py:319
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:308: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/ /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:319: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True) @root_validator(pre=True)
../proxy/_types.py:557 ../proxy/_types.py:570
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:557: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/ /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:570: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True) @root_validator(pre=True)
../proxy/_types.py:578 ../proxy/_types.py:591
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:578: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/ /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:591: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
@root_validator(pre=True) @root_validator(pre=True)
../utils.py:36 ../utils.py:35
/Users/krrishdholakia/Documents/litellm/litellm/utils.py:36: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html /Users/krrishdholakia/Documents/litellm/litellm/utils.py:35: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
import pkg_resources import pkg_resources
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871: 10 warnings ../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871: 10 warnings
@ -109,5 +109,11 @@ test_completion.py . [100%]
/Users/krrishdholakia/Documents/litellm/litellm/llms/prompt_templates/factory.py:6: DeprecationWarning: 'imghdr' is deprecated and slated for removal in Python 3.13 /Users/krrishdholakia/Documents/litellm/litellm/llms/prompt_templates/factory.py:6: DeprecationWarning: 'imghdr' is deprecated and slated for removal in Python 3.13
import imghdr, base64 import imghdr, base64
test_completion.py::test_completion_claude_3_stream
../utils.py:3249
../utils.py:3249
/Users/krrishdholakia/Documents/litellm/litellm/utils.py:3249: DeprecationWarning: open_text is deprecated. Use files() instead. Refer to https://importlib-resources.readthedocs.io/en/latest/using.html#migrating-from-legacy for migration advice.
with resources.open_text(
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html -- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================== 1 passed, 43 warnings in 4.47s ======================== ======================== 1 passed, 46 warnings in 3.14s ========================

View file

@ -1749,7 +1749,7 @@ class Chunk(BaseModel):
object: str object: str
created: int created: int
model: str model: str
system_fingerprint: str # system_fingerprint: str
choices: List[Choices] choices: List[Choices]
@ -1869,7 +1869,7 @@ class Chunk3(BaseModel):
object: str object: str
created: int created: int
model: str model: str
system_fingerprint: str # system_fingerprint: str
choices: List[Choices3] choices: List[Choices3]
@ -2032,3 +2032,56 @@ async def test_azure_astreaming_and_function_calling():
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
raise e raise e
def test_completion_claude_3_function_call_with_streaming():
    """Check claude-3 tool calling when ``stream=True`` is requested.

    Calls ``completion`` for ``claude-3-opus-20240229`` with a single
    ``get_current_weather`` tool and iterates the returned stream. The first
    chunk must carry tool-call arguments as a string; subsequent chunks are
    checked with the shared streaming function-calling validators. The test
    also fails if the stream produces no chunks at all, so that an empty
    stream cannot pass silently.
    """
    litellm.set_verbose = True
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            },
        }
    ]
    messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
    try:
        # test without max tokens
        response = completion(
            model="claude-3-opus-20240229",
            messages=messages,
            tools=tools,
            tool_choice="auto",
            stream=True,
        )
        chunk_count = 0
        for idx, chunk in enumerate(response):
            chunk_count += 1
            if idx == 0:
                # The first chunk must already contain the tool-call arguments.
                first_arguments = chunk.choices[0].delta.tool_calls[0].function.arguments
                assert first_arguments is not None
                assert isinstance(first_arguments, str)
                validate_first_streaming_function_calling_chunk(chunk=chunk)
            elif idx == 1:
                validate_second_streaming_function_calling_chunk(chunk=chunk)
            elif chunk.choices[0].finish_reason is not None:  # last chunk
                validate_final_streaming_function_calling_chunk(chunk=chunk)
        # Every check above lives inside the loop, so guard against an empty stream.
        assert chunk_count > 0, "stream returned no chunks"
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

View file

@ -480,12 +480,12 @@ class ModelResponse(OpenAIObject):
object=None, object=None,
system_fingerprint=None, system_fingerprint=None,
usage=None, usage=None,
stream=False, stream=None,
response_ms=None, response_ms=None,
hidden_params=None, hidden_params=None,
**params, **params,
): ):
if stream: if stream is not None and stream == True:
object = "chat.completion.chunk" object = "chat.completion.chunk"
choices = [StreamingChoices()] choices = [StreamingChoices()]
else: else:
@ -9471,14 +9471,18 @@ class CustomStreamWrapper:
def __next__(self): def __next__(self):
try: try:
while True: while True:
if isinstance(self.completion_stream, str) or isinstance( if (
self.completion_stream, bytes isinstance(self.completion_stream, str)
or isinstance(self.completion_stream, bytes)
or isinstance(self.completion_stream, ModelResponse)
): ):
chunk = self.completion_stream chunk = self.completion_stream
else: else:
chunk = next(self.completion_stream) chunk = next(self.completion_stream)
if chunk is not None and chunk != b"": if chunk is not None and chunk != b"":
print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}") print_verbose(
f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}; custom_llm_provider: {self.custom_llm_provider}"
)
response: Optional[ModelResponse] = self.chunk_creator(chunk=chunk) response: Optional[ModelResponse] = self.chunk_creator(chunk=chunk)
print_verbose(f"PROCESSED CHUNK POST CHUNK CREATOR: {response}") print_verbose(f"PROCESSED CHUNK POST CHUNK CREATOR: {response}")