feat(anthropic_adapter.py): support streaming requests for /v1/messages endpoint

Fixes https://github.com/BerriAI/litellm/issues/5011
This commit is contained in:
Krrish Dholakia 2024-08-03 20:16:19 -07:00
parent 39a98a2882
commit ac6c39c283
9 changed files with 425 additions and 35 deletions

View file

@ -4,7 +4,7 @@ import json
import os
import traceback
import uuid
from typing import Literal, Optional
from typing import Any, Literal, Optional
import dotenv
import httpx
@ -13,7 +13,12 @@ from pydantic import BaseModel
import litellm
from litellm import ChatCompletionRequest, verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.llms.anthropic import AnthropicMessagesRequest, AnthropicResponse
from litellm.types.llms.anthropic import (
AnthropicMessagesRequest,
AnthropicResponse,
ContentBlockDelta,
)
from litellm.types.utils import AdapterCompletionStreamWrapper
class AnthropicAdapter(CustomLogger):
@ -43,8 +48,147 @@ class AnthropicAdapter(CustomLogger):
response=response
)
def translate_completion_output_params_streaming(self) -> Optional[BaseModel]:
return super().translate_completion_output_params_streaming()
def translate_completion_output_params_streaming(
self, completion_stream: Any
) -> AdapterCompletionStreamWrapper | None:
return AnthropicStreamWrapper(completion_stream=completion_stream)
anthropic_adapter = AnthropicAdapter()
class AnthropicStreamWrapper(AdapterCompletionStreamWrapper):
"""
- first chunk return 'message_start'
- content block must be started and stopped
- finish_reason must map exactly to anthropic reason, else anthropic client won't be able to parse it.
"""
sent_first_chunk: bool = False
sent_content_block_start: bool = False
sent_content_block_finish: bool = False
sent_last_message: bool = False
holding_chunk: Optional[Any] = None
def __next__(self):
try:
if self.sent_first_chunk is False:
self.sent_first_chunk = True
return {
"type": "message_start",
"message": {
"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY",
"type": "message",
"role": "assistant",
"content": [],
"model": "claude-3-5-sonnet-20240620",
"stop_reason": None,
"stop_sequence": None,
"usage": {"input_tokens": 25, "output_tokens": 1},
},
}
if self.sent_content_block_start is False:
self.sent_content_block_start = True
return {
"type": "content_block_start",
"index": 0,
"content_block": {"type": "text", "text": ""},
}
for chunk in self.completion_stream:
if chunk == "None" or chunk is None:
raise Exception
processed_chunk = litellm.AnthropicConfig().translate_streaming_openai_response_to_anthropic(
response=chunk
)
if (
processed_chunk["type"] == "message_delta"
and self.sent_content_block_finish is False
):
self.holding_chunk = processed_chunk
self.sent_content_block_finish = True
return {
"type": "content_block_stop",
"index": 0,
}
elif self.holding_chunk is not None:
return_chunk = self.holding_chunk
self.holding_chunk = processed_chunk
return return_chunk
else:
return processed_chunk
if self.sent_last_message is False:
self.sent_last_message = True
return {"type": "message_stop"}
raise StopIteration
except StopIteration:
if self.sent_last_message is False:
self.sent_last_message = True
return {"type": "message_stop"}
raise StopIteration
except Exception as e:
verbose_logger.error(
"Anthropic Adapter - {}\n{}".format(e, traceback.format_exc())
)
async def __anext__(self):
try:
if self.sent_first_chunk is False:
self.sent_first_chunk = True
return {
"type": "message_start",
"message": {
"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY",
"type": "message",
"role": "assistant",
"content": [],
"model": "claude-3-5-sonnet-20240620",
"stop_reason": None,
"stop_sequence": None,
"usage": {"input_tokens": 25, "output_tokens": 1},
},
}
if self.sent_content_block_start is False:
self.sent_content_block_start = True
return {
"type": "content_block_start",
"index": 0,
"content_block": {"type": "text", "text": ""},
}
async for chunk in self.completion_stream:
if chunk == "None" or chunk is None:
raise Exception
processed_chunk = litellm.AnthropicConfig().translate_streaming_openai_response_to_anthropic(
response=chunk
)
if (
processed_chunk["type"] == "message_delta"
and self.sent_content_block_finish is False
):
self.holding_chunk = processed_chunk
self.sent_content_block_finish = True
return {
"type": "content_block_stop",
"index": 0,
}
elif self.holding_chunk is not None:
return_chunk = self.holding_chunk
self.holding_chunk = processed_chunk
return return_chunk
else:
return processed_chunk
if self.holding_chunk is not None:
return_chunk = self.holding_chunk
self.holding_chunk = None
return return_chunk
if self.sent_last_message is False:
self.sent_last_message = True
return {"type": "message_stop"}
raise StopIteration
except StopIteration:
if self.sent_last_message is False:
self.sent_last_message = True
return {"type": "message_stop"}
raise StopAsyncIteration

View file

@ -10,7 +10,7 @@ from pydantic import BaseModel
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.llms.openai import ChatCompletionRequest
from litellm.types.utils import ModelResponse
from litellm.types.utils import AdapterCompletionStreamWrapper, ModelResponse
class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
@ -76,7 +76,9 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
"""
pass
def translate_completion_output_params_streaming(self) -> Optional[BaseModel]:
def translate_completion_output_params_streaming(
self, completion_stream: Any
) -> Optional[AdapterCompletionStreamWrapper]:
"""
Translates the streaming chunk, from the OpenAI format to the custom format.
"""

View file

@ -5,13 +5,16 @@ import time
import types
from enum import Enum
from functools import partial
from typing import Callable, List, Optional, Union
from typing import Callable, List, Literal, Optional, Tuple, Union
import httpx # type: ignore
import requests # type: ignore
from openai.types.chat.chat_completion_chunk import Choice as OpenAIStreamingChoice
import litellm
import litellm.litellm_core_utils
import litellm.types
import litellm.types.utils
from litellm import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import (
@ -33,8 +36,12 @@ from litellm.types.llms.anthropic import (
AnthropicResponseUsageBlock,
ContentBlockDelta,
ContentBlockStart,
ContentJsonBlockDelta,
ContentTextBlockDelta,
MessageBlockDelta,
MessageDelta,
MessageStartBlock,
UsageDelta,
)
from litellm.types.llms.openai import (
AllMessageValues,
@ -480,6 +487,74 @@ class AnthropicConfig:
return translated_obj
def _translate_streaming_openai_chunk_to_anthropic(
self, choices: List[OpenAIStreamingChoice]
) -> Tuple[
Literal["text_delta", "input_json_delta"],
Union[ContentTextBlockDelta, ContentJsonBlockDelta],
]:
text: str = ""
partial_json: Optional[str] = None
for choice in choices:
if choice.delta.content is not None:
text += choice.delta.content
elif choice.delta.tool_calls is not None:
partial_json = ""
for tool in choice.delta.tool_calls:
if (
tool.function is not None
and tool.function.arguments is not None
):
partial_json += tool.function.arguments
if partial_json is not None:
return "input_json_delta", ContentJsonBlockDelta(
type="input_json_delta", partial_json=partial_json
)
else:
return "text_delta", ContentTextBlockDelta(type="text_delta", text=text)
def translate_streaming_openai_response_to_anthropic(
self, response: litellm.ModelResponse
) -> Union[ContentBlockDelta, MessageBlockDelta]:
## base case - final chunk w/ finish reason
if response.choices[0].finish_reason is not None:
delta = MessageDelta(
stop_reason=self._translate_openai_finish_reason_to_anthropic(
response.choices[0].finish_reason
),
)
if getattr(response, "usage", None) is not None:
litellm_usage_chunk: Optional[litellm.Usage] = response.usage # type: ignore
elif (
hasattr(response, "_hidden_params")
and "usage" in response._hidden_params
):
litellm_usage_chunk = response._hidden_params["usage"]
else:
litellm_usage_chunk = None
if litellm_usage_chunk is not None:
usage_delta = UsageDelta(
input_tokens=litellm_usage_chunk.prompt_tokens or 0,
output_tokens=litellm_usage_chunk.completion_tokens or 0,
)
else:
usage_delta = UsageDelta(input_tokens=0, output_tokens=0)
return MessageBlockDelta(
type="message_delta", delta=delta, usage=usage_delta
)
(
type_of_content,
content_block_delta,
) = self._translate_streaming_openai_chunk_to_anthropic(
choices=response.choices # type: ignore
)
return ContentBlockDelta(
type="content_block_delta",
index=response.choices[0].index,
delta=content_block_delta,
)
# makes headers for API call
def validate_environment(api_key, user_headers, model):

View file

@ -125,7 +125,7 @@ from .llms.vertex_ai_partner import VertexAIPartnerModels
from .llms.vertex_httpx import VertexLLM
from .llms.watsonx import IBMWatsonXAI
from .types.llms.openai import HttpxBinaryResponseContent
from .types.utils import ChatCompletionMessageToolCall
from .types.utils import AdapterCompletionStreamWrapper, ChatCompletionMessageToolCall
encoding = tiktoken.get_encoding("cl100k_base")
from litellm.utils import (
@ -515,7 +515,7 @@ def mock_completion(
model_response = ModelResponse(stream=stream)
if stream is True:
# don't try to access stream object,
if kwargs.get("acompletion", False) == True:
if kwargs.get("acompletion", False) is True:
return CustomStreamWrapper(
completion_stream=async_mock_completion_streaming_obj(
model_response, mock_response=mock_response, model=model, n=n
@ -524,13 +524,14 @@ def mock_completion(
custom_llm_provider="openai",
logging_obj=logging,
)
response = mock_completion_streaming_obj(
model_response,
mock_response=mock_response,
return CustomStreamWrapper(
completion_stream=mock_completion_streaming_obj(
model_response, mock_response=mock_response, model=model, n=n
),
model=model,
n=n,
custom_llm_provider="openai",
logging_obj=logging,
)
return response
if n is None:
model_response.choices[0].message.content = mock_response # type: ignore
else:
@ -4037,7 +4038,9 @@ def text_completion(
###### Adapter Completion ################
async def aadapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseModel]:
async def aadapter_completion(
*, adapter_id: str, **kwargs
) -> Optional[Union[BaseModel, AdapterCompletionStreamWrapper]]:
"""
Implemented to handle async calls for adapter_completion()
"""
@ -4056,18 +4059,29 @@ async def aadapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseMode
new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)
response: ModelResponse = await acompletion(**new_kwargs) # type: ignore
response: Union[ModelResponse, CustomStreamWrapper] = await acompletion(**new_kwargs) # type: ignore
translated_response: Optional[
Union[BaseModel, AdapterCompletionStreamWrapper]
] = None
if isinstance(response, ModelResponse):
translated_response = translation_obj.translate_completion_output_params(
response=response
)
if isinstance(response, CustomStreamWrapper):
translated_response = (
translation_obj.translate_completion_output_params_streaming(
completion_stream=response
)
)
return translated_response
except Exception as e:
raise e
def adapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseModel]:
def adapter_completion(
*, adapter_id: str, **kwargs
) -> Optional[Union[BaseModel, AdapterCompletionStreamWrapper]]:
translation_obj: Optional[CustomLogger] = None
for item in litellm.adapters:
if item["id"] == adapter_id:
@ -4082,11 +4096,20 @@ def adapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseModel]:
new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)
response: ModelResponse = completion(**new_kwargs) # type: ignore
response: Union[ModelResponse, CustomStreamWrapper] = completion(**new_kwargs) # type: ignore
translated_response: Optional[Union[BaseModel, AdapterCompletionStreamWrapper]] = (
None
)
if isinstance(response, ModelResponse):
translated_response = translation_obj.translate_completion_output_params(
response=response
)
elif isinstance(response, CustomStreamWrapper) or inspect.isgenerator(response):
translated_response = (
translation_obj.translate_completion_output_params_streaming(
completion_stream=response
)
)
return translated_response

View file

@ -1,7 +1,7 @@
model_list:
- model_name: "*"
- model_name: "claude-3-5-sonnet-20240620"
litellm_params:
model: "*"
model: "claude-3-5-sonnet-20240620"
# litellm_settings:
# failure_callback: ["langfuse"]

View file

@ -2396,7 +2396,9 @@ async def async_data_generator(
user_api_key_dict=user_api_key_dict, response=chunk
)
if isinstance(chunk, BaseModel):
chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True)
try:
yield f"data: {chunk}\n\n"
except Exception as e:
@ -2437,6 +2439,59 @@ async def async_data_generator(
yield f"data: {error_returned}\n\n"
async def async_data_generator_anthropic(
response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
):
verbose_proxy_logger.debug("inside generator")
try:
start_time = time.time()
async for chunk in response:
verbose_proxy_logger.debug(
"async_data_generator: received streaming chunk - {}".format(chunk)
)
### CALL HOOKS ### - modify outgoing data
chunk = await proxy_logging_obj.async_post_call_streaming_hook(
user_api_key_dict=user_api_key_dict, response=chunk
)
event_type = chunk.get("type")
try:
yield f"event: {event_type}\ndata:{json.dumps(chunk)}\n\n"
except Exception as e:
yield f"event: {event_type}\ndata:{str(e)}\n\n"
except Exception as e:
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.async_data_generator(): Exception occured - {}\n{}".format(
str(e), traceback.format_exc()
)
)
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict,
original_exception=e,
request_data=request_data,
)
verbose_proxy_logger.debug(
f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
)
router_model_names = llm_router.model_names if llm_router is not None else []
if isinstance(e, HTTPException):
raise e
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}"
proxy_exception = ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
error_returned = json.dumps({"error": proxy_exception.to_dict()})
yield f"data: {error_returned}\n\n"
def select_data_generator(
response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
):
@ -5379,6 +5434,19 @@ async def anthropic_response(
)
)
if (
"stream" in data and data["stream"] is True
): # use generate_responses to stream responses
selected_data_generator = async_data_generator_anthropic(
response=response,
user_api_key_dict=user_api_key_dict,
request_data=data,
)
return StreamingResponse(
selected_data_generator,
media_type="text/event-stream",
)
verbose_proxy_logger.info("\nResponse from Litellm:\n{}".format(response))
return response
except RejectedRequestError as e:
@ -5425,11 +5493,10 @@ async def anthropic_response(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.completion(): Exception occured - {}".format(
str(e)
"litellm.proxy.proxy_server.anthropic_response(): Exception occured - {}\n{}".format(
str(e), traceback.format_exc()
)
)
verbose_proxy_logger.debug(traceback.format_exc())
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),

View file

@ -8,6 +8,9 @@ import traceback
from dotenv import load_dotenv
import litellm.types
import litellm.types.utils
load_dotenv()
import io
import os
@ -15,6 +18,7 @@ import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from typing import Optional
from unittest.mock import MagicMock, patch
import pytest
@ -84,7 +88,22 @@ def test_anthropic_completion_input_translation_with_metadata():
assert translated_input["metadata"] == data["litellm_metadata"]
def test_anthropic_completion_e2e():
def streaming_format_tests(chunk: dict, idx: int):
"""
1st chunk - chunk.get("type") == "message_start"
2nd chunk - chunk.get("type") == "content_block_start"
3rd chunk - chunk.get("type") == "content_block_delta"
"""
if idx == 0:
assert chunk.get("type") == "message_start"
elif idx == 1:
assert chunk.get("type") == "content_block_start"
elif idx == 2:
assert chunk.get("type") == "content_block_delta"
@pytest.mark.parametrize("stream", [True]) # False
def test_anthropic_completion_e2e(stream):
litellm.set_verbose = True
litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]
@ -95,13 +114,40 @@ def test_anthropic_completion_e2e():
messages=messages,
adapter_id="anthropic",
mock_response="This is a fake call",
stream=stream,
)
print("Response: {}".format(response))
assert response is not None
if stream is False:
assert isinstance(response, AnthropicResponse)
else:
"""
- ensure finish reason is returned
- assert content block is started and stopped
- ensure last chunk is 'message_stop'
"""
assert isinstance(response, litellm.types.utils.AdapterCompletionStreamWrapper)
finish_reason: Optional[str] = None
message_stop_received = False
content_block_started = False
content_block_finished = False
for idx, chunk in enumerate(response):
print(chunk)
streaming_format_tests(chunk=chunk, idx=idx)
if chunk.get("delta", {}).get("stop_reason") is not None:
finish_reason = chunk.get("delta", {}).get("stop_reason")
if chunk.get("type") == "message_stop":
message_stop_received = True
if chunk.get("type") == "content_block_stop":
content_block_finished = True
if chunk.get("type") == "content_block_start":
content_block_started = True
assert content_block_started and content_block_finished
assert finish_reason is not None
assert message_stop_received is True
@pytest.mark.asyncio

View file

@ -136,7 +136,7 @@ class ContentJsonBlockDelta(TypedDict):
class ContentBlockDelta(TypedDict):
type: str
type: Literal["content_block_delta"]
index: int
delta: Union[ContentTextBlockDelta, ContentJsonBlockDelta]

View file

@ -6,7 +6,7 @@ from typing import Dict, List, Literal, Optional, Tuple, Union
from openai._models import BaseModel as OpenAIObject
from pydantic import ConfigDict, Field, PrivateAttr
from typing_extensions import Dict, Required, TypedDict, override
from typing_extensions import Callable, Dict, Required, TypedDict, override
from ..litellm_core_utils.core_helpers import map_finish_reason
from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock
@ -1069,3 +1069,36 @@ class LoggedLiteLLMParams(TypedDict, total=False):
output_cost_per_token: Optional[float]
output_cost_per_second: Optional[float]
cooldown_time: Optional[float]
class AdapterCompletionStreamWrapper:
def __init__(self, completion_stream):
self.completion_stream = completion_stream
def __iter__(self):
return self
def __aiter__(self):
return self
def __next__(self):
try:
for chunk in self.completion_stream:
if chunk == "None" or chunk is None:
raise Exception
return chunk
raise StopIteration
except StopIteration:
raise StopIteration
except Exception as e:
print(f"AdapterCompletionStreamWrapper - {e}") # noqa
async def __anext__(self):
try:
async for chunk in self.completion_stream:
if chunk == "None" or chunk is None:
raise Exception
return chunk
raise StopIteration
except StopIteration:
raise StopAsyncIteration