mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 19:54:13 +00:00
Merge pull request #2472 from BerriAI/litellm_anthropic_streaming_tool_calling
fix(anthropic.py): support claude-3 streaming with function calling
This commit is contained in:
commit
ce3c865adb
5 changed files with 150 additions and 27 deletions
|
@ -4,7 +4,7 @@ from enum import Enum
|
||||||
import requests, copy
|
import requests, copy
|
||||||
import time, uuid
|
import time, uuid
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
from litellm.utils import ModelResponse, Usage, map_finish_reason
|
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
|
||||||
import litellm
|
import litellm
|
||||||
from .prompt_templates.factory import (
|
from .prompt_templates.factory import (
|
||||||
prompt_factory,
|
prompt_factory,
|
||||||
|
@ -118,6 +118,7 @@ def completion(
|
||||||
headers = validate_environment(api_key, headers)
|
headers = validate_environment(api_key, headers)
|
||||||
_is_function_call = False
|
_is_function_call = False
|
||||||
messages = copy.deepcopy(messages)
|
messages = copy.deepcopy(messages)
|
||||||
|
optional_params = copy.deepcopy(optional_params)
|
||||||
if model in custom_prompt_dict:
|
if model in custom_prompt_dict:
|
||||||
# check if the model has a registered custom prompt
|
# check if the model has a registered custom prompt
|
||||||
model_prompt_details = custom_prompt_dict[model]
|
model_prompt_details = custom_prompt_dict[model]
|
||||||
|
@ -161,6 +162,8 @@ def completion(
|
||||||
) # add the anthropic tool calling prompt to the system prompt
|
) # add the anthropic tool calling prompt to the system prompt
|
||||||
optional_params.pop("tools")
|
optional_params.pop("tools")
|
||||||
|
|
||||||
|
stream = optional_params.pop("stream", None)
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
|
@ -177,14 +180,18 @@ def completion(
|
||||||
"headers": headers,
|
"headers": headers,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
print_verbose(f"_is_function_call: {_is_function_call}")
|
||||||
## COMPLETION CALL
|
## COMPLETION CALL
|
||||||
if "stream" in optional_params and optional_params["stream"] == True:
|
if (
|
||||||
|
stream is not None and stream == True and _is_function_call == False
|
||||||
|
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
|
||||||
|
print_verbose(f"makes anthropic streaming POST request")
|
||||||
|
data["stream"] = stream
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
api_base,
|
api_base,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
data=json.dumps(data),
|
data=json.dumps(data),
|
||||||
stream=optional_params["stream"],
|
stream=stream,
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
|
@ -255,6 +262,51 @@ def completion(
|
||||||
completion_response["stop_reason"]
|
completion_response["stop_reason"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
print_verbose(f"_is_function_call: {_is_function_call}; stream: {stream}")
|
||||||
|
if _is_function_call == True and stream is not None and stream == True:
|
||||||
|
print_verbose(f"INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
|
||||||
|
# return an iterator
|
||||||
|
streaming_model_response = ModelResponse(stream=True)
|
||||||
|
streaming_model_response.choices[0].finish_reason = model_response.choices[
|
||||||
|
0
|
||||||
|
].finish_reason
|
||||||
|
# streaming_model_response.choices = [litellm.utils.StreamingChoices()]
|
||||||
|
streaming_choice = litellm.utils.StreamingChoices()
|
||||||
|
streaming_choice.index = model_response.choices[0].index
|
||||||
|
_tool_calls = []
|
||||||
|
print_verbose(
|
||||||
|
f"type of model_response.choices[0]: {type(model_response.choices[0])}"
|
||||||
|
)
|
||||||
|
print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
|
||||||
|
if isinstance(model_response.choices[0], litellm.Choices):
|
||||||
|
if getattr(
|
||||||
|
model_response.choices[0].message, "tool_calls", None
|
||||||
|
) is not None and isinstance(
|
||||||
|
model_response.choices[0].message.tool_calls, list
|
||||||
|
):
|
||||||
|
for tool_call in model_response.choices[0].message.tool_calls:
|
||||||
|
_tool_call = {**tool_call.dict(), "index": 0}
|
||||||
|
_tool_calls.append(_tool_call)
|
||||||
|
delta_obj = litellm.utils.Delta(
|
||||||
|
content=getattr(model_response.choices[0].message, "content", None),
|
||||||
|
role=model_response.choices[0].message.role,
|
||||||
|
tool_calls=_tool_calls,
|
||||||
|
)
|
||||||
|
streaming_choice.delta = delta_obj
|
||||||
|
streaming_model_response.choices = [streaming_choice]
|
||||||
|
completion_stream = model_response_iterator(
|
||||||
|
model_response=streaming_model_response
|
||||||
|
)
|
||||||
|
print_verbose(
|
||||||
|
f"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
|
||||||
|
)
|
||||||
|
return CustomStreamWrapper(
|
||||||
|
completion_stream=completion_stream,
|
||||||
|
model=model,
|
||||||
|
custom_llm_provider="cached_response",
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
)
|
||||||
|
|
||||||
## CALCULATING USAGE
|
## CALCULATING USAGE
|
||||||
prompt_tokens = completion_response["usage"]["input_tokens"]
|
prompt_tokens = completion_response["usage"]["input_tokens"]
|
||||||
completion_tokens = completion_response["usage"]["output_tokens"]
|
completion_tokens = completion_response["usage"]["output_tokens"]
|
||||||
|
@ -271,6 +323,10 @@ def completion(
|
||||||
return model_response
|
return model_response
|
||||||
|
|
||||||
|
|
||||||
|
def model_response_iterator(model_response):
|
||||||
|
yield model_response
|
||||||
|
|
||||||
|
|
||||||
def embedding():
|
def embedding():
|
||||||
# logic for parsing in - calling - parsing out model embedding calls
|
# logic for parsing in - calling - parsing out model embedding calls
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -1144,7 +1144,11 @@ def completion(
|
||||||
logging_obj=logging,
|
logging_obj=logging,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
)
|
)
|
||||||
if "stream" in optional_params and optional_params["stream"] == True:
|
if (
|
||||||
|
"stream" in optional_params
|
||||||
|
and optional_params["stream"] == True
|
||||||
|
and not isinstance(response, CustomStreamWrapper)
|
||||||
|
):
|
||||||
# don't try to access stream object,
|
# don't try to access stream object,
|
||||||
response = CustomStreamWrapper(
|
response = CustomStreamWrapper(
|
||||||
response,
|
response,
|
||||||
|
|
|
@ -36,32 +36,32 @@ test_completion.py . [100%]
|
||||||
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:180: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
|
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:180: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
|
|
||||||
../proxy/_types.py:235
|
../proxy/_types.py:241
|
||||||
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:235: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
|
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:241: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
|
|
||||||
../proxy/_types.py:247
|
../proxy/_types.py:253
|
||||||
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:247: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
|
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:253: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
|
|
||||||
../proxy/_types.py:282
|
../proxy/_types.py:292
|
||||||
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:282: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
|
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:292: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
|
|
||||||
../proxy/_types.py:308
|
../proxy/_types.py:319
|
||||||
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:308: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
|
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:319: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
|
|
||||||
../proxy/_types.py:557
|
../proxy/_types.py:570
|
||||||
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:557: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
|
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:570: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
|
|
||||||
../proxy/_types.py:578
|
../proxy/_types.py:591
|
||||||
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:578: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
|
/Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:591: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
|
|
||||||
../utils.py:36
|
../utils.py:35
|
||||||
/Users/krrishdholakia/Documents/litellm/litellm/utils.py:36: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
|
/Users/krrishdholakia/Documents/litellm/litellm/utils.py:35: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
|
|
||||||
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871: 10 warnings
|
../../../../../../opt/homebrew/lib/python3.11/site-packages/pkg_resources/__init__.py:2871: 10 warnings
|
||||||
|
@ -109,5 +109,11 @@ test_completion.py . [100%]
|
||||||
/Users/krrishdholakia/Documents/litellm/litellm/llms/prompt_templates/factory.py:6: DeprecationWarning: 'imghdr' is deprecated and slated for removal in Python 3.13
|
/Users/krrishdholakia/Documents/litellm/litellm/llms/prompt_templates/factory.py:6: DeprecationWarning: 'imghdr' is deprecated and slated for removal in Python 3.13
|
||||||
import imghdr, base64
|
import imghdr, base64
|
||||||
|
|
||||||
|
test_completion.py::test_completion_claude_3_stream
|
||||||
|
../utils.py:3249
|
||||||
|
../utils.py:3249
|
||||||
|
/Users/krrishdholakia/Documents/litellm/litellm/utils.py:3249: DeprecationWarning: open_text is deprecated. Use files() instead. Refer to https://importlib-resources.readthedocs.io/en/latest/using.html#migrating-from-legacy for migration advice.
|
||||||
|
with resources.open_text(
|
||||||
|
|
||||||
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
|
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
|
||||||
======================== 1 passed, 43 warnings in 4.47s ========================
|
======================== 1 passed, 46 warnings in 3.14s ========================
|
||||||
|
|
|
@ -1749,7 +1749,7 @@ class Chunk(BaseModel):
|
||||||
object: str
|
object: str
|
||||||
created: int
|
created: int
|
||||||
model: str
|
model: str
|
||||||
system_fingerprint: str
|
# system_fingerprint: str
|
||||||
choices: List[Choices]
|
choices: List[Choices]
|
||||||
|
|
||||||
|
|
||||||
|
@ -1869,7 +1869,7 @@ class Chunk3(BaseModel):
|
||||||
object: str
|
object: str
|
||||||
created: int
|
created: int
|
||||||
model: str
|
model: str
|
||||||
system_fingerprint: str
|
# system_fingerprint: str
|
||||||
choices: List[Choices3]
|
choices: List[Choices3]
|
||||||
|
|
||||||
|
|
||||||
|
@ -2032,3 +2032,56 @@ async def test_azure_astreaming_and_function_calling():
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
def test_completion_claude_3_function_call_with_streaming():
|
||||||
|
litellm.set_verbose = True
|
||||||
|
tools = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"description": "Get the current weather in a given location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA",
|
||||||
|
},
|
||||||
|
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
|
||||||
|
try:
|
||||||
|
# test without max tokens
|
||||||
|
response = completion(
|
||||||
|
model="claude-3-opus-20240229",
|
||||||
|
messages=messages,
|
||||||
|
tools=tools,
|
||||||
|
tool_choice="auto",
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
idx = 0
|
||||||
|
for chunk in response:
|
||||||
|
# print(f"chunk: {chunk}")
|
||||||
|
if idx == 0:
|
||||||
|
assert (
|
||||||
|
chunk.choices[0].delta.tool_calls[0].function.arguments is not None
|
||||||
|
)
|
||||||
|
assert isinstance(
|
||||||
|
chunk.choices[0].delta.tool_calls[0].function.arguments, str
|
||||||
|
)
|
||||||
|
validate_first_streaming_function_calling_chunk(chunk=chunk)
|
||||||
|
elif idx == 1:
|
||||||
|
validate_second_streaming_function_calling_chunk(chunk=chunk)
|
||||||
|
elif chunk.choices[0].finish_reason is not None: # last chunk
|
||||||
|
validate_final_streaming_function_calling_chunk(chunk=chunk)
|
||||||
|
idx += 1
|
||||||
|
# raise Exception("it worked!")
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
|
@ -480,12 +480,12 @@ class ModelResponse(OpenAIObject):
|
||||||
object=None,
|
object=None,
|
||||||
system_fingerprint=None,
|
system_fingerprint=None,
|
||||||
usage=None,
|
usage=None,
|
||||||
stream=False,
|
stream=None,
|
||||||
response_ms=None,
|
response_ms=None,
|
||||||
hidden_params=None,
|
hidden_params=None,
|
||||||
**params,
|
**params,
|
||||||
):
|
):
|
||||||
if stream:
|
if stream is not None and stream == True:
|
||||||
object = "chat.completion.chunk"
|
object = "chat.completion.chunk"
|
||||||
choices = [StreamingChoices()]
|
choices = [StreamingChoices()]
|
||||||
else:
|
else:
|
||||||
|
@ -9471,14 +9471,18 @@ class CustomStreamWrapper:
|
||||||
def __next__(self):
|
def __next__(self):
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
if isinstance(self.completion_stream, str) or isinstance(
|
if (
|
||||||
self.completion_stream, bytes
|
isinstance(self.completion_stream, str)
|
||||||
|
or isinstance(self.completion_stream, bytes)
|
||||||
|
or isinstance(self.completion_stream, ModelResponse)
|
||||||
):
|
):
|
||||||
chunk = self.completion_stream
|
chunk = self.completion_stream
|
||||||
else:
|
else:
|
||||||
chunk = next(self.completion_stream)
|
chunk = next(self.completion_stream)
|
||||||
if chunk is not None and chunk != b"":
|
if chunk is not None and chunk != b"":
|
||||||
print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}")
|
print_verbose(
|
||||||
|
f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}; custom_llm_provider: {self.custom_llm_provider}"
|
||||||
|
)
|
||||||
response: Optional[ModelResponse] = self.chunk_creator(chunk=chunk)
|
response: Optional[ModelResponse] = self.chunk_creator(chunk=chunk)
|
||||||
print_verbose(f"PROCESSED CHUNK POST CHUNK CREATOR: {response}")
|
print_verbose(f"PROCESSED CHUNK POST CHUNK CREATOR: {response}")
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue