feat(anthropic_adapter.py): support streaming requests for /v1/messages endpoint

Fixes https://github.com/BerriAI/litellm/issues/5011
Krrish Dholakia 2024-08-03 20:16:19 -07:00
parent 39a98a2882
commit ac6c39c283
9 changed files with 425 additions and 35 deletions
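Orientation note (not part of the diff): the commit teaches LiteLLM's Anthropic adapter to replay an OpenAI-format completion stream as Anthropic `/v1/messages` streaming events. A sketch of the event order the new wrapper emits, derived from the `AnthropicStreamWrapper` code below:

# Anthropic Messages streaming event sequence produced by the adapter:
EXPECTED_EVENT_ORDER = [
    "message_start",        # synthetic first chunk
    "content_block_start",  # opens the text block at index 0
    "content_block_delta",  # one per translated OpenAI chunk
    "content_block_stop",   # emitted when the finish-reason chunk arrives
    "message_delta",        # carries stop_reason + usage
    "message_stop",         # synthetic last chunk
]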

View file

@@ -4,7 +4,7 @@ import json
 import os
 import traceback
 import uuid
-from typing import Literal, Optional
+from typing import Any, Literal, Optional

 import dotenv
 import httpx
@@ -13,7 +13,12 @@ from pydantic import BaseModel
 import litellm
 from litellm import ChatCompletionRequest, verbose_logger
 from litellm.integrations.custom_logger import CustomLogger
-from litellm.types.llms.anthropic import AnthropicMessagesRequest, AnthropicResponse
+from litellm.types.llms.anthropic import (
+    AnthropicMessagesRequest,
+    AnthropicResponse,
+    ContentBlockDelta,
+)
+from litellm.types.utils import AdapterCompletionStreamWrapper


 class AnthropicAdapter(CustomLogger):
@@ -43,8 +48,147 @@ class AnthropicAdapter(CustomLogger):
             response=response
         )

-    def translate_completion_output_params_streaming(self) -> Optional[BaseModel]:
-        return super().translate_completion_output_params_streaming()
+    def translate_completion_output_params_streaming(
+        self, completion_stream: Any
+    ) -> AdapterCompletionStreamWrapper | None:
+        return AnthropicStreamWrapper(completion_stream=completion_stream)


 anthropic_adapter = AnthropicAdapter()
+
+
+class AnthropicStreamWrapper(AdapterCompletionStreamWrapper):
+    """
+    - first chunk return 'message_start'
+    - content block must be started and stopped
+    - finish_reason must map exactly to anthropic reason, else anthropic client won't be able to parse it.
+    """
+
+    sent_first_chunk: bool = False
+    sent_content_block_start: bool = False
+    sent_content_block_finish: bool = False
+    sent_last_message: bool = False
+    holding_chunk: Optional[Any] = None
+
+    def __next__(self):
+        try:
+            if self.sent_first_chunk is False:
+                self.sent_first_chunk = True
+                return {
+                    "type": "message_start",
+                    "message": {
+                        "id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY",
+                        "type": "message",
+                        "role": "assistant",
+                        "content": [],
+                        "model": "claude-3-5-sonnet-20240620",
+                        "stop_reason": None,
+                        "stop_sequence": None,
+                        "usage": {"input_tokens": 25, "output_tokens": 1},
+                    },
+                }
+            if self.sent_content_block_start is False:
+                self.sent_content_block_start = True
+                return {
+                    "type": "content_block_start",
+                    "index": 0,
+                    "content_block": {"type": "text", "text": ""},
+                }
+
+            for chunk in self.completion_stream:
+                if chunk == "None" or chunk is None:
+                    raise Exception
+
+                processed_chunk = litellm.AnthropicConfig().translate_streaming_openai_response_to_anthropic(
+                    response=chunk
+                )
+                if (
+                    processed_chunk["type"] == "message_delta"
+                    and self.sent_content_block_finish is False
+                ):
+                    self.holding_chunk = processed_chunk
+                    self.sent_content_block_finish = True
+                    return {
+                        "type": "content_block_stop",
+                        "index": 0,
+                    }
+                elif self.holding_chunk is not None:
+                    return_chunk = self.holding_chunk
+                    self.holding_chunk = processed_chunk
+                    return return_chunk
+                else:
+                    return processed_chunk
+            if self.sent_last_message is False:
+                self.sent_last_message = True
+                return {"type": "message_stop"}
+            raise StopIteration
+        except StopIteration:
+            if self.sent_last_message is False:
+                self.sent_last_message = True
+                return {"type": "message_stop"}
+            raise StopIteration
+        except Exception as e:
+            verbose_logger.error(
+                "Anthropic Adapter - {}\n{}".format(e, traceback.format_exc())
+            )
+
+    async def __anext__(self):
+        try:
+            if self.sent_first_chunk is False:
+                self.sent_first_chunk = True
+                return {
+                    "type": "message_start",
+                    "message": {
+                        "id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY",
+                        "type": "message",
+                        "role": "assistant",
+                        "content": [],
+                        "model": "claude-3-5-sonnet-20240620",
+                        "stop_reason": None,
+                        "stop_sequence": None,
+                        "usage": {"input_tokens": 25, "output_tokens": 1},
+                    },
+                }
+            if self.sent_content_block_start is False:
+                self.sent_content_block_start = True
+                return {
+                    "type": "content_block_start",
+                    "index": 0,
+                    "content_block": {"type": "text", "text": ""},
+                }
+
+            async for chunk in self.completion_stream:
+                if chunk == "None" or chunk is None:
+                    raise Exception
+                processed_chunk = litellm.AnthropicConfig().translate_streaming_openai_response_to_anthropic(
+                    response=chunk
+                )
+                if (
+                    processed_chunk["type"] == "message_delta"
+                    and self.sent_content_block_finish is False
+                ):
+                    self.holding_chunk = processed_chunk
+                    self.sent_content_block_finish = True
+                    return {
+                        "type": "content_block_stop",
+                        "index": 0,
+                    }
+                elif self.holding_chunk is not None:
+                    return_chunk = self.holding_chunk
+                    self.holding_chunk = processed_chunk
+                    return return_chunk
+                else:
+                    return processed_chunk
+            if self.holding_chunk is not None:
+                return_chunk = self.holding_chunk
+                self.holding_chunk = None
+                return return_chunk
+            if self.sent_last_message is False:
+                self.sent_last_message = True
+                return {"type": "message_stop"}
+            raise StopIteration
+        except StopIteration:
+            if self.sent_last_message is False:
+                self.sent_last_message = True
+                return {"type": "message_stop"}
+            raise StopAsyncIteration
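Usage note (not part of the diff): with this change the adapter can be driven end to end with `stream=True`. A minimal consumption sketch, mirroring the test added later in this commit; `mock_response` keeps it offline, and the module path is the one the test file uses:

import litellm
from litellm.adapters.anthropic_adapter import anthropic_adapter

litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]

response = litellm.adapter_completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    adapter_id="anthropic",
    mock_response="This is a fake call",
    stream=True,
)
# response is an AnthropicStreamWrapper; each chunk is a plain dict
for chunk in response:
    print(chunk.get("type"))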

View file

@@ -10,7 +10,7 @@ from pydantic import BaseModel
 from litellm.caching import DualCache
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.types.llms.openai import ChatCompletionRequest
-from litellm.types.utils import ModelResponse
+from litellm.types.utils import AdapterCompletionStreamWrapper, ModelResponse


 class CustomLogger:  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
@@ -76,7 +76,9 @@ class CustomLogger:  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
         """
         pass

-    def translate_completion_output_params_streaming(self) -> Optional[BaseModel]:
+    def translate_completion_output_params_streaming(
+        self, completion_stream: Any
+    ) -> Optional[AdapterCompletionStreamWrapper]:
         """
         Translates the streaming chunk, from the OpenAI format to the custom format.
         """

View file

@@ -5,13 +5,16 @@ import time
 import types
 from enum import Enum
 from functools import partial
-from typing import Callable, List, Optional, Union
+from typing import Callable, List, Literal, Optional, Tuple, Union

 import httpx  # type: ignore
 import requests  # type: ignore
+from openai.types.chat.chat_completion_chunk import Choice as OpenAIStreamingChoice

 import litellm
 import litellm.litellm_core_utils
+import litellm.types
+import litellm.types.utils
 from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import (
@@ -33,8 +36,12 @@ from litellm.types.llms.anthropic import (
     AnthropicResponseUsageBlock,
     ContentBlockDelta,
     ContentBlockStart,
+    ContentJsonBlockDelta,
+    ContentTextBlockDelta,
     MessageBlockDelta,
+    MessageDelta,
     MessageStartBlock,
+    UsageDelta,
 )
 from litellm.types.llms.openai import (
     AllMessageValues,
@@ -480,6 +487,74 @@ class AnthropicConfig:

         return translated_obj

+    def _translate_streaming_openai_chunk_to_anthropic(
+        self, choices: List[OpenAIStreamingChoice]
+    ) -> Tuple[
+        Literal["text_delta", "input_json_delta"],
+        Union[ContentTextBlockDelta, ContentJsonBlockDelta],
+    ]:
+        text: str = ""
+        partial_json: Optional[str] = None
+        for choice in choices:
+            if choice.delta.content is not None:
+                text += choice.delta.content
+            elif choice.delta.tool_calls is not None:
+                partial_json = ""
+                for tool in choice.delta.tool_calls:
+                    if (
+                        tool.function is not None
+                        and tool.function.arguments is not None
+                    ):
+                        partial_json += tool.function.arguments
+
+        if partial_json is not None:
+            return "input_json_delta", ContentJsonBlockDelta(
+                type="input_json_delta", partial_json=partial_json
+            )
+        else:
+            return "text_delta", ContentTextBlockDelta(type="text_delta", text=text)
+
+    def translate_streaming_openai_response_to_anthropic(
+        self, response: litellm.ModelResponse
+    ) -> Union[ContentBlockDelta, MessageBlockDelta]:
+        ## base case - final chunk w/ finish reason
+        if response.choices[0].finish_reason is not None:
+            delta = MessageDelta(
+                stop_reason=self._translate_openai_finish_reason_to_anthropic(
+                    response.choices[0].finish_reason
+                ),
+            )
+            if getattr(response, "usage", None) is not None:
+                litellm_usage_chunk: Optional[litellm.Usage] = response.usage  # type: ignore
+            elif (
+                hasattr(response, "_hidden_params")
+                and "usage" in response._hidden_params
+            ):
+                litellm_usage_chunk = response._hidden_params["usage"]
+            else:
+                litellm_usage_chunk = None
+            if litellm_usage_chunk is not None:
+                usage_delta = UsageDelta(
+                    input_tokens=litellm_usage_chunk.prompt_tokens or 0,
+                    output_tokens=litellm_usage_chunk.completion_tokens or 0,
+                )
+            else:
+                usage_delta = UsageDelta(input_tokens=0, output_tokens=0)
+            return MessageBlockDelta(
+                type="message_delta", delta=delta, usage=usage_delta
+            )
+        (
+            type_of_content,
+            content_block_delta,
+        ) = self._translate_streaming_openai_chunk_to_anthropic(
+            choices=response.choices  # type: ignore
+        )
+        return ContentBlockDelta(
+            type="content_block_delta",
+            index=response.choices[0].index,
+            delta=content_block_delta,
+        )
+

 # makes headers for API call
 def validate_environment(api_key, user_headers, model):
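Shape note (not part of the diff): a hedged sketch of what the translation above produces, following the Anthropic Messages streaming format; the dict literals and token counts are illustrative only:

# An OpenAI-format content chunk...
openai_choice = {"index": 0, "delta": {"content": "Hello"}, "finish_reason": None}

# ...becomes a content_block_delta event:
anthropic_content_event = {
    "type": "content_block_delta",
    "index": 0,
    "delta": {"type": "text_delta", "text": "Hello"},
}

# The final chunk (finish_reason set) becomes a message_delta instead,
# carrying the mapped stop_reason plus token usage:
anthropic_final_event = {
    "type": "message_delta",
    "delta": {"stop_reason": "end_turn"},
    "usage": {"input_tokens": 25, "output_tokens": 10},
}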

View file

@@ -125,7 +125,7 @@ from .llms.vertex_ai_partner import VertexAIPartnerModels
 from .llms.vertex_httpx import VertexLLM
 from .llms.watsonx import IBMWatsonXAI
 from .types.llms.openai import HttpxBinaryResponseContent
-from .types.utils import ChatCompletionMessageToolCall
+from .types.utils import AdapterCompletionStreamWrapper, ChatCompletionMessageToolCall

 encoding = tiktoken.get_encoding("cl100k_base")
 from litellm.utils import (
@@ -515,7 +515,7 @@ def mock_completion(
         model_response = ModelResponse(stream=stream)
         if stream is True:
             # don't try to access stream object,
-            if kwargs.get("acompletion", False) == True:
+            if kwargs.get("acompletion", False) is True:
                 return CustomStreamWrapper(
                     completion_stream=async_mock_completion_streaming_obj(
                         model_response, mock_response=mock_response, model=model, n=n
@@ -524,13 +524,14 @@ def mock_completion(
                     custom_llm_provider="openai",
                     logging_obj=logging,
                 )
-            response = mock_completion_streaming_obj(
-                model_response,
-                mock_response=mock_response,
-                model=model,
-                n=n,
-            )
-            return response
+            return CustomStreamWrapper(
+                completion_stream=mock_completion_streaming_obj(
+                    model_response, mock_response=mock_response, model=model, n=n
+                ),
+                model=model,
+                custom_llm_provider="openai",
+                logging_obj=logging,
+            )
         if n is None:
             model_response.choices[0].message.content = mock_response  # type: ignore
         else:
@@ -4037,7 +4038,9 @@ def text_completion(

 ###### Adapter Completion ################


-async def aadapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseModel]:
+async def aadapter_completion(
+    *, adapter_id: str, **kwargs
+) -> Optional[Union[BaseModel, AdapterCompletionStreamWrapper]]:
     """
     Implemented to handle async calls for adapter_completion()
     """
@@ -4056,18 +4059,29 @@ async def aadapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseModel]:
         new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)

-        response: ModelResponse = await acompletion(**new_kwargs)  # type: ignore
-
-        translated_response = translation_obj.translate_completion_output_params(
-            response=response
-        )
+        response: Union[ModelResponse, CustomStreamWrapper] = await acompletion(**new_kwargs)  # type: ignore
+        translated_response: Optional[
+            Union[BaseModel, AdapterCompletionStreamWrapper]
+        ] = None
+        if isinstance(response, ModelResponse):
+            translated_response = translation_obj.translate_completion_output_params(
+                response=response
+            )
+        if isinstance(response, CustomStreamWrapper):
+            translated_response = (
+                translation_obj.translate_completion_output_params_streaming(
+                    completion_stream=response
+                )
+            )

         return translated_response
     except Exception as e:
         raise e


-def adapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseModel]:
+def adapter_completion(
+    *, adapter_id: str, **kwargs
+) -> Optional[Union[BaseModel, AdapterCompletionStreamWrapper]]:
     translation_obj: Optional[CustomLogger] = None
     for item in litellm.adapters:
         if item["id"] == adapter_id:
@@ -4082,11 +4096,20 @@ def adapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseModel]:
     new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)

-    response: ModelResponse = completion(**new_kwargs)  # type: ignore
-
-    translated_response = translation_obj.translate_completion_output_params(
-        response=response
-    )
+    response: Union[ModelResponse, CustomStreamWrapper] = completion(**new_kwargs)  # type: ignore
+    translated_response: Optional[Union[BaseModel, AdapterCompletionStreamWrapper]] = (
+        None
+    )
+    if isinstance(response, ModelResponse):
+        translated_response = translation_obj.translate_completion_output_params(
+            response=response
+        )
+    elif isinstance(response, CustomStreamWrapper) or inspect.isgenerator(response):
+        translated_response = (
+            translation_obj.translate_completion_output_params_streaming(
+                completion_stream=response
+            )
+        )

     return translated_response
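Usage note (not part of the diff): the async path mirrors this. A short sketch of awaiting `aadapter_completion` with streaming, assuming the function is exported at package level like its sync counterpart and the adapter registration shown in the test file; `mock_response` keeps it offline:

import asyncio

import litellm
from litellm.adapters.anthropic_adapter import anthropic_adapter

litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]

async def main():
    response = await litellm.aadapter_completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        adapter_id="anthropic",
        mock_response="This is a fake call",
        stream=True,
    )
    async for chunk in response:  # AdapterCompletionStreamWrapper is async-iterable
        print(chunk)

asyncio.run(main())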

View file

@@ -1,7 +1,7 @@
 model_list:
-  - model_name: "*"
+  - model_name: "claude-3-5-sonnet-20240620"
     litellm_params:
-      model: "*"
+      model: "claude-3-5-sonnet-20240620"

 # litellm_settings:
 #   failure_callback: ["langfuse"]

View file

@@ -2396,7 +2396,9 @@ async def async_data_generator(
                 user_api_key_dict=user_api_key_dict, response=chunk
             )

-            chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True)
+            if isinstance(chunk, BaseModel):
+                chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True)
+
             try:
                 yield f"data: {chunk}\n\n"
             except Exception as e:
@@ -2437,6 +2439,59 @@ async def async_data_generator(
         yield f"data: {error_returned}\n\n"


+async def async_data_generator_anthropic(
+    response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
+):
+    verbose_proxy_logger.debug("inside generator")
+    try:
+        start_time = time.time()
+        async for chunk in response:
+            verbose_proxy_logger.debug(
+                "async_data_generator: received streaming chunk - {}".format(chunk)
+            )
+            ### CALL HOOKS ### - modify outgoing data
+            chunk = await proxy_logging_obj.async_post_call_streaming_hook(
+                user_api_key_dict=user_api_key_dict, response=chunk
+            )
+
+            event_type = chunk.get("type")
+
+            try:
+                yield f"event: {event_type}\ndata:{json.dumps(chunk)}\n\n"
+            except Exception as e:
+                yield f"event: {event_type}\ndata:{str(e)}\n\n"
+    except Exception as e:
+        verbose_proxy_logger.error(
+            "litellm.proxy.proxy_server.async_data_generator(): Exception occured - {}\n{}".format(
+                str(e), traceback.format_exc()
+            )
+        )
+        await proxy_logging_obj.post_call_failure_hook(
+            user_api_key_dict=user_api_key_dict,
+            original_exception=e,
+            request_data=request_data,
+        )
+        verbose_proxy_logger.debug(
+            f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
+        )
+        router_model_names = llm_router.model_names if llm_router is not None else []
+
+        if isinstance(e, HTTPException):
+            raise e
+        else:
+            error_traceback = traceback.format_exc()
+            error_msg = f"{str(e)}\n\n{error_traceback}"
+
+        proxy_exception = ProxyException(
+            message=getattr(e, "message", error_msg),
+            type=getattr(e, "type", "None"),
+            param=getattr(e, "param", "None"),
+            code=getattr(e, "status_code", 500),
+        )
+        error_returned = json.dumps({"error": proxy_exception.to_dict()})
+        yield f"data: {error_returned}\n\n"
+
+
 def select_data_generator(
     response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
 ):
@@ -5379,6 +5434,19 @@ async def anthropic_response(
             )
         )

+        if (
+            "stream" in data and data["stream"] is True
+        ):  # use generate_responses to stream responses
+            selected_data_generator = async_data_generator_anthropic(
+                response=response,
+                user_api_key_dict=user_api_key_dict,
+                request_data=data,
+            )
+
+            return StreamingResponse(
+                selected_data_generator,
+                media_type="text/event-stream",
+            )
+
         verbose_proxy_logger.info("\nResponse from Litellm:\n{}".format(response))
         return response
     except RejectedRequestError as e:
@@ -5425,11 +5493,10 @@ async def anthropic_response(
             user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
         )
         verbose_proxy_logger.error(
-            "litellm.proxy.proxy_server.completion(): Exception occured - {}".format(
-                str(e)
+            "litellm.proxy.proxy_server.anthropic_response(): Exception occured - {}\n{}".format(
+                str(e), traceback.format_exc()
             )
         )
-        verbose_proxy_logger.debug(traceback.format_exc())
         error_msg = f"{str(e)}"
         raise ProxyException(
             message=getattr(e, "message", error_msg),
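Client note (not part of the diff): with the proxy changes above, `/v1/messages` can be consumed as server-sent events. A hedged sketch; the base URL, key, and auth header are assumptions for a locally running proxy:

import json

import httpx

with httpx.stream(
    "POST",
    "http://0.0.0.0:4000/v1/messages",
    headers={"Authorization": "Bearer sk-1234"},  # hypothetical proxy key
    json={
        "model": "claude-3-5-sonnet-20240620",
        "max_tokens": 256,
        "stream": True,
        "messages": [{"role": "user", "content": "Hey, how's it going?"}],
    },
    timeout=60,
) as resp:
    for line in resp.iter_lines():
        # the generator emits "event: <type>" / "data:<json>" line pairs
        if line.startswith("data:"):
            print(json.loads(line[len("data:"):]))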

View file

@@ -8,6 +8,9 @@ import traceback

 from dotenv import load_dotenv

+import litellm.types
+import litellm.types.utils
+
 load_dotenv()
 import io
 import os
@@ -15,6 +18,7 @@ import os
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
+from typing import Optional
 from unittest.mock import MagicMock, patch

 import pytest
@@ -84,7 +88,22 @@ def test_anthropic_completion_input_translation_with_metadata():
     assert translated_input["metadata"] == data["litellm_metadata"]


-def test_anthropic_completion_e2e():
+def streaming_format_tests(chunk: dict, idx: int):
+    """
+    1st chunk - chunk.get("type") == "message_start"
+    2nd chunk - chunk.get("type") == "content_block_start"
+    3rd chunk - chunk.get("type") == "content_block_delta"
+    """
+    if idx == 0:
+        assert chunk.get("type") == "message_start"
+    elif idx == 1:
+        assert chunk.get("type") == "content_block_start"
+    elif idx == 2:
+        assert chunk.get("type") == "content_block_delta"
+
+
+@pytest.mark.parametrize("stream", [True])  # False
+def test_anthropic_completion_e2e(stream):
     litellm.set_verbose = True

     litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]
@@ -95,13 +114,40 @@ def test_anthropic_completion_e2e():
         messages=messages,
         adapter_id="anthropic",
         mock_response="This is a fake call",
+        stream=stream,
     )

     print("Response: {}".format(response))

     assert response is not None

-    assert isinstance(response, AnthropicResponse)
+    if stream is False:
+        assert isinstance(response, AnthropicResponse)
+    else:
+        """
+        - ensure finish reason is returned
+        - assert content block is started and stopped
+        - ensure last chunk is 'message_stop'
+        """
+        assert isinstance(response, litellm.types.utils.AdapterCompletionStreamWrapper)
+        finish_reason: Optional[str] = None
+        message_stop_received = False
+        content_block_started = False
+        content_block_finished = False
+        for idx, chunk in enumerate(response):
+            print(chunk)
+            streaming_format_tests(chunk=chunk, idx=idx)
+            if chunk.get("delta", {}).get("stop_reason") is not None:
+                finish_reason = chunk.get("delta", {}).get("stop_reason")
+            if chunk.get("type") == "message_stop":
+                message_stop_received = True
+            if chunk.get("type") == "content_block_stop":
+                content_block_finished = True
+            if chunk.get("type") == "content_block_start":
+                content_block_started = True
+        assert content_block_started and content_block_finished
+        assert finish_reason is not None
+        assert message_stop_received is True


 @pytest.mark.asyncio

View file

@@ -136,7 +136,7 @@ class ContentJsonBlockDelta(TypedDict):


 class ContentBlockDelta(TypedDict):
-    type: str
+    type: Literal["content_block_delta"]
     index: int
     delta: Union[ContentTextBlockDelta, ContentJsonBlockDelta]

View file

@@ -6,7 +6,7 @@ from typing import Dict, List, Literal, Optional, Tuple, Union
 from openai._models import BaseModel as OpenAIObject
 from pydantic import ConfigDict, Field, PrivateAttr
-from typing_extensions import Dict, Required, TypedDict, override
+from typing_extensions import Callable, Dict, Required, TypedDict, override

 from ..litellm_core_utils.core_helpers import map_finish_reason
 from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock
@@ -1069,3 +1069,36 @@ class LoggedLiteLLMParams(TypedDict, total=False):
     output_cost_per_token: Optional[float]
     output_cost_per_second: Optional[float]
     cooldown_time: Optional[float]
+
+
+class AdapterCompletionStreamWrapper:
+    def __init__(self, completion_stream):
+        self.completion_stream = completion_stream
+
+    def __iter__(self):
+        return self
+
+    def __aiter__(self):
+        return self
+
+    def __next__(self):
+        try:
+            for chunk in self.completion_stream:
+                if chunk == "None" or chunk is None:
+                    raise Exception
+                return chunk
+            raise StopIteration
+        except StopIteration:
+            raise StopIteration
+        except Exception as e:
+            print(f"AdapterCompletionStreamWrapper - {e}")  # noqa
+
+    async def __anext__(self):
+        try:
+            async for chunk in self.completion_stream:
+                if chunk == "None" or chunk is None:
+                    raise Exception
+                return chunk
+            raise StopIteration
+        except StopIteration:
+            raise StopAsyncIteration
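Design note (not part of the diff): the base class is deliberately thin. `__iter__`/`__aiter__` return self, and `__next__`/`__anext__` pull one chunk per call (the `for`/`async for` advances the underlying stream once, then returns). A hypothetical subclass only needs to reshape chunks:

from litellm.types.utils import AdapterCompletionStreamWrapper

class EnvelopeStreamWrapper(AdapterCompletionStreamWrapper):
    """Hypothetical adapter wrapper: wraps each upstream chunk in an envelope."""

    def __next__(self):
        chunk = super().__next__()
        return {"type": "chunk", "data": chunk}

    async def __anext__(self):
        chunk = await super().__anext__()
        return {"type": "chunk", "data": chunk}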