mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
feat(anthropic_adapter.py): support streaming requests for /v1/messages
endpoint
Fixes https://github.com/BerriAI/litellm/issues/5011
This commit is contained in:
parent
39a98a2882
commit
ac6c39c283
9 changed files with 425 additions and 35 deletions
|
@ -4,7 +4,7 @@ import json
|
|||
import os
|
||||
import traceback
|
||||
import uuid
|
||||
from typing import Literal, Optional
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import dotenv
|
||||
import httpx
|
||||
|
@ -13,7 +13,12 @@ from pydantic import BaseModel
|
|||
import litellm
|
||||
from litellm import ChatCompletionRequest, verbose_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.types.llms.anthropic import AnthropicMessagesRequest, AnthropicResponse
|
||||
from litellm.types.llms.anthropic import (
|
||||
AnthropicMessagesRequest,
|
||||
AnthropicResponse,
|
||||
ContentBlockDelta,
|
||||
)
|
||||
from litellm.types.utils import AdapterCompletionStreamWrapper
|
||||
|
||||
|
||||
class AnthropicAdapter(CustomLogger):
|
||||
|
@ -43,8 +48,147 @@ class AnthropicAdapter(CustomLogger):
|
|||
response=response
|
||||
)
|
||||
|
||||
def translate_completion_output_params_streaming(self) -> Optional[BaseModel]:
|
||||
return super().translate_completion_output_params_streaming()
|
||||
def translate_completion_output_params_streaming(
|
||||
self, completion_stream: Any
|
||||
) -> AdapterCompletionStreamWrapper | None:
|
||||
return AnthropicStreamWrapper(completion_stream=completion_stream)
|
||||
|
||||
|
||||
anthropic_adapter = AnthropicAdapter()
|
||||
|
||||
|
||||
class AnthropicStreamWrapper(AdapterCompletionStreamWrapper):
|
||||
"""
|
||||
- first chunk return 'message_start'
|
||||
- content block must be started and stopped
|
||||
- finish_reason must map exactly to anthropic reason, else anthropic client won't be able to parse it.
|
||||
"""
|
||||
|
||||
sent_first_chunk: bool = False
|
||||
sent_content_block_start: bool = False
|
||||
sent_content_block_finish: bool = False
|
||||
sent_last_message: bool = False
|
||||
holding_chunk: Optional[Any] = None
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
if self.sent_first_chunk is False:
|
||||
self.sent_first_chunk = True
|
||||
return {
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [],
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"stop_reason": None,
|
||||
"stop_sequence": None,
|
||||
"usage": {"input_tokens": 25, "output_tokens": 1},
|
||||
},
|
||||
}
|
||||
if self.sent_content_block_start is False:
|
||||
self.sent_content_block_start = True
|
||||
return {
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": {"type": "text", "text": ""},
|
||||
}
|
||||
|
||||
for chunk in self.completion_stream:
|
||||
if chunk == "None" or chunk is None:
|
||||
raise Exception
|
||||
|
||||
processed_chunk = litellm.AnthropicConfig().translate_streaming_openai_response_to_anthropic(
|
||||
response=chunk
|
||||
)
|
||||
if (
|
||||
processed_chunk["type"] == "message_delta"
|
||||
and self.sent_content_block_finish is False
|
||||
):
|
||||
self.holding_chunk = processed_chunk
|
||||
self.sent_content_block_finish = True
|
||||
return {
|
||||
"type": "content_block_stop",
|
||||
"index": 0,
|
||||
}
|
||||
elif self.holding_chunk is not None:
|
||||
return_chunk = self.holding_chunk
|
||||
self.holding_chunk = processed_chunk
|
||||
return return_chunk
|
||||
else:
|
||||
return processed_chunk
|
||||
|
||||
if self.sent_last_message is False:
|
||||
self.sent_last_message = True
|
||||
return {"type": "message_stop"}
|
||||
raise StopIteration
|
||||
except StopIteration:
|
||||
if self.sent_last_message is False:
|
||||
self.sent_last_message = True
|
||||
return {"type": "message_stop"}
|
||||
raise StopIteration
|
||||
except Exception as e:
|
||||
verbose_logger.error(
|
||||
"Anthropic Adapter - {}\n{}".format(e, traceback.format_exc())
|
||||
)
|
||||
|
||||
async def __anext__(self):
|
||||
try:
|
||||
if self.sent_first_chunk is False:
|
||||
self.sent_first_chunk = True
|
||||
return {
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [],
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"stop_reason": None,
|
||||
"stop_sequence": None,
|
||||
"usage": {"input_tokens": 25, "output_tokens": 1},
|
||||
},
|
||||
}
|
||||
if self.sent_content_block_start is False:
|
||||
self.sent_content_block_start = True
|
||||
return {
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": {"type": "text", "text": ""},
|
||||
}
|
||||
async for chunk in self.completion_stream:
|
||||
if chunk == "None" or chunk is None:
|
||||
raise Exception
|
||||
processed_chunk = litellm.AnthropicConfig().translate_streaming_openai_response_to_anthropic(
|
||||
response=chunk
|
||||
)
|
||||
if (
|
||||
processed_chunk["type"] == "message_delta"
|
||||
and self.sent_content_block_finish is False
|
||||
):
|
||||
self.holding_chunk = processed_chunk
|
||||
self.sent_content_block_finish = True
|
||||
return {
|
||||
"type": "content_block_stop",
|
||||
"index": 0,
|
||||
}
|
||||
elif self.holding_chunk is not None:
|
||||
return_chunk = self.holding_chunk
|
||||
self.holding_chunk = processed_chunk
|
||||
return return_chunk
|
||||
else:
|
||||
return processed_chunk
|
||||
if self.holding_chunk is not None:
|
||||
return_chunk = self.holding_chunk
|
||||
self.holding_chunk = None
|
||||
return return_chunk
|
||||
if self.sent_last_message is False:
|
||||
self.sent_last_message = True
|
||||
return {"type": "message_stop"}
|
||||
raise StopIteration
|
||||
except StopIteration:
|
||||
if self.sent_last_message is False:
|
||||
self.sent_last_message = True
|
||||
return {"type": "message_stop"}
|
||||
raise StopAsyncIteration
|
||||
|
|
|
@ -10,7 +10,7 @@ from pydantic import BaseModel
|
|||
from litellm.caching import DualCache
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.types.llms.openai import ChatCompletionRequest
|
||||
from litellm.types.utils import ModelResponse
|
||||
from litellm.types.utils import AdapterCompletionStreamWrapper, ModelResponse
|
||||
|
||||
|
||||
class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
|
||||
|
@ -76,7 +76,9 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
|
|||
"""
|
||||
pass
|
||||
|
||||
def translate_completion_output_params_streaming(self) -> Optional[BaseModel]:
|
||||
def translate_completion_output_params_streaming(
|
||||
self, completion_stream: Any
|
||||
) -> Optional[AdapterCompletionStreamWrapper]:
|
||||
"""
|
||||
Translates the streaming chunk, from the OpenAI format to the custom format.
|
||||
"""
|
||||
|
|
|
@ -5,13 +5,16 @@ import time
|
|||
import types
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Callable, List, Optional, Union
|
||||
from typing import Callable, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import httpx # type: ignore
|
||||
import requests # type: ignore
|
||||
from openai.types.chat.chat_completion_chunk import Choice as OpenAIStreamingChoice
|
||||
|
||||
import litellm
|
||||
import litellm.litellm_core_utils
|
||||
import litellm.types
|
||||
import litellm.types.utils
|
||||
from litellm import verbose_logger
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
|
@ -33,8 +36,12 @@ from litellm.types.llms.anthropic import (
|
|||
AnthropicResponseUsageBlock,
|
||||
ContentBlockDelta,
|
||||
ContentBlockStart,
|
||||
ContentJsonBlockDelta,
|
||||
ContentTextBlockDelta,
|
||||
MessageBlockDelta,
|
||||
MessageDelta,
|
||||
MessageStartBlock,
|
||||
UsageDelta,
|
||||
)
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
|
@ -480,6 +487,74 @@ class AnthropicConfig:
|
|||
|
||||
return translated_obj
|
||||
|
||||
def _translate_streaming_openai_chunk_to_anthropic(
|
||||
self, choices: List[OpenAIStreamingChoice]
|
||||
) -> Tuple[
|
||||
Literal["text_delta", "input_json_delta"],
|
||||
Union[ContentTextBlockDelta, ContentJsonBlockDelta],
|
||||
]:
|
||||
text: str = ""
|
||||
partial_json: Optional[str] = None
|
||||
for choice in choices:
|
||||
if choice.delta.content is not None:
|
||||
text += choice.delta.content
|
||||
elif choice.delta.tool_calls is not None:
|
||||
partial_json = ""
|
||||
for tool in choice.delta.tool_calls:
|
||||
if (
|
||||
tool.function is not None
|
||||
and tool.function.arguments is not None
|
||||
):
|
||||
partial_json += tool.function.arguments
|
||||
|
||||
if partial_json is not None:
|
||||
return "input_json_delta", ContentJsonBlockDelta(
|
||||
type="input_json_delta", partial_json=partial_json
|
||||
)
|
||||
else:
|
||||
return "text_delta", ContentTextBlockDelta(type="text_delta", text=text)
|
||||
|
||||
def translate_streaming_openai_response_to_anthropic(
|
||||
self, response: litellm.ModelResponse
|
||||
) -> Union[ContentBlockDelta, MessageBlockDelta]:
|
||||
## base case - final chunk w/ finish reason
|
||||
if response.choices[0].finish_reason is not None:
|
||||
delta = MessageDelta(
|
||||
stop_reason=self._translate_openai_finish_reason_to_anthropic(
|
||||
response.choices[0].finish_reason
|
||||
),
|
||||
)
|
||||
if getattr(response, "usage", None) is not None:
|
||||
litellm_usage_chunk: Optional[litellm.Usage] = response.usage # type: ignore
|
||||
elif (
|
||||
hasattr(response, "_hidden_params")
|
||||
and "usage" in response._hidden_params
|
||||
):
|
||||
litellm_usage_chunk = response._hidden_params["usage"]
|
||||
else:
|
||||
litellm_usage_chunk = None
|
||||
if litellm_usage_chunk is not None:
|
||||
usage_delta = UsageDelta(
|
||||
input_tokens=litellm_usage_chunk.prompt_tokens or 0,
|
||||
output_tokens=litellm_usage_chunk.completion_tokens or 0,
|
||||
)
|
||||
else:
|
||||
usage_delta = UsageDelta(input_tokens=0, output_tokens=0)
|
||||
return MessageBlockDelta(
|
||||
type="message_delta", delta=delta, usage=usage_delta
|
||||
)
|
||||
(
|
||||
type_of_content,
|
||||
content_block_delta,
|
||||
) = self._translate_streaming_openai_chunk_to_anthropic(
|
||||
choices=response.choices # type: ignore
|
||||
)
|
||||
return ContentBlockDelta(
|
||||
type="content_block_delta",
|
||||
index=response.choices[0].index,
|
||||
delta=content_block_delta,
|
||||
)
|
||||
|
||||
|
||||
# makes headers for API call
|
||||
def validate_environment(api_key, user_headers, model):
|
||||
|
|
|
@ -125,7 +125,7 @@ from .llms.vertex_ai_partner import VertexAIPartnerModels
|
|||
from .llms.vertex_httpx import VertexLLM
|
||||
from .llms.watsonx import IBMWatsonXAI
|
||||
from .types.llms.openai import HttpxBinaryResponseContent
|
||||
from .types.utils import ChatCompletionMessageToolCall
|
||||
from .types.utils import AdapterCompletionStreamWrapper, ChatCompletionMessageToolCall
|
||||
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
from litellm.utils import (
|
||||
|
@ -515,7 +515,7 @@ def mock_completion(
|
|||
model_response = ModelResponse(stream=stream)
|
||||
if stream is True:
|
||||
# don't try to access stream object,
|
||||
if kwargs.get("acompletion", False) == True:
|
||||
if kwargs.get("acompletion", False) is True:
|
||||
return CustomStreamWrapper(
|
||||
completion_stream=async_mock_completion_streaming_obj(
|
||||
model_response, mock_response=mock_response, model=model, n=n
|
||||
|
@ -524,13 +524,14 @@ def mock_completion(
|
|||
custom_llm_provider="openai",
|
||||
logging_obj=logging,
|
||||
)
|
||||
response = mock_completion_streaming_obj(
|
||||
model_response,
|
||||
mock_response=mock_response,
|
||||
return CustomStreamWrapper(
|
||||
completion_stream=mock_completion_streaming_obj(
|
||||
model_response, mock_response=mock_response, model=model, n=n
|
||||
),
|
||||
model=model,
|
||||
n=n,
|
||||
custom_llm_provider="openai",
|
||||
logging_obj=logging,
|
||||
)
|
||||
return response
|
||||
if n is None:
|
||||
model_response.choices[0].message.content = mock_response # type: ignore
|
||||
else:
|
||||
|
@ -4037,7 +4038,9 @@ def text_completion(
|
|||
###### Adapter Completion ################
|
||||
|
||||
|
||||
async def aadapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseModel]:
|
||||
async def aadapter_completion(
|
||||
*, adapter_id: str, **kwargs
|
||||
) -> Optional[Union[BaseModel, AdapterCompletionStreamWrapper]]:
|
||||
"""
|
||||
Implemented to handle async calls for adapter_completion()
|
||||
"""
|
||||
|
@ -4056,18 +4059,29 @@ async def aadapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseMode
|
|||
|
||||
new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)
|
||||
|
||||
response: ModelResponse = await acompletion(**new_kwargs) # type: ignore
|
||||
|
||||
response: Union[ModelResponse, CustomStreamWrapper] = await acompletion(**new_kwargs) # type: ignore
|
||||
translated_response: Optional[
|
||||
Union[BaseModel, AdapterCompletionStreamWrapper]
|
||||
] = None
|
||||
if isinstance(response, ModelResponse):
|
||||
translated_response = translation_obj.translate_completion_output_params(
|
||||
response=response
|
||||
)
|
||||
if isinstance(response, CustomStreamWrapper):
|
||||
translated_response = (
|
||||
translation_obj.translate_completion_output_params_streaming(
|
||||
completion_stream=response
|
||||
)
|
||||
)
|
||||
|
||||
return translated_response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def adapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseModel]:
|
||||
def adapter_completion(
|
||||
*, adapter_id: str, **kwargs
|
||||
) -> Optional[Union[BaseModel, AdapterCompletionStreamWrapper]]:
|
||||
translation_obj: Optional[CustomLogger] = None
|
||||
for item in litellm.adapters:
|
||||
if item["id"] == adapter_id:
|
||||
|
@ -4082,11 +4096,20 @@ def adapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseModel]:
|
|||
|
||||
new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)
|
||||
|
||||
response: ModelResponse = completion(**new_kwargs) # type: ignore
|
||||
|
||||
response: Union[ModelResponse, CustomStreamWrapper] = completion(**new_kwargs) # type: ignore
|
||||
translated_response: Optional[Union[BaseModel, AdapterCompletionStreamWrapper]] = (
|
||||
None
|
||||
)
|
||||
if isinstance(response, ModelResponse):
|
||||
translated_response = translation_obj.translate_completion_output_params(
|
||||
response=response
|
||||
)
|
||||
elif isinstance(response, CustomStreamWrapper) or inspect.isgenerator(response):
|
||||
translated_response = (
|
||||
translation_obj.translate_completion_output_params_streaming(
|
||||
completion_stream=response
|
||||
)
|
||||
)
|
||||
|
||||
return translated_response
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
model_list:
|
||||
- model_name: "*"
|
||||
- model_name: "claude-3-5-sonnet-20240620"
|
||||
litellm_params:
|
||||
model: "*"
|
||||
model: "claude-3-5-sonnet-20240620"
|
||||
|
||||
# litellm_settings:
|
||||
# failure_callback: ["langfuse"]
|
||||
|
|
|
@ -2396,7 +2396,9 @@ async def async_data_generator(
|
|||
user_api_key_dict=user_api_key_dict, response=chunk
|
||||
)
|
||||
|
||||
if isinstance(chunk, BaseModel):
|
||||
chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True)
|
||||
|
||||
try:
|
||||
yield f"data: {chunk}\n\n"
|
||||
except Exception as e:
|
||||
|
@ -2437,6 +2439,59 @@ async def async_data_generator(
|
|||
yield f"data: {error_returned}\n\n"
|
||||
|
||||
|
||||
async def async_data_generator_anthropic(
|
||||
response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
|
||||
):
|
||||
verbose_proxy_logger.debug("inside generator")
|
||||
try:
|
||||
start_time = time.time()
|
||||
async for chunk in response:
|
||||
verbose_proxy_logger.debug(
|
||||
"async_data_generator: received streaming chunk - {}".format(chunk)
|
||||
)
|
||||
### CALL HOOKS ### - modify outgoing data
|
||||
chunk = await proxy_logging_obj.async_post_call_streaming_hook(
|
||||
user_api_key_dict=user_api_key_dict, response=chunk
|
||||
)
|
||||
|
||||
event_type = chunk.get("type")
|
||||
|
||||
try:
|
||||
yield f"event: {event_type}\ndata:{json.dumps(chunk)}\n\n"
|
||||
except Exception as e:
|
||||
yield f"event: {event_type}\ndata:{str(e)}\n\n"
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.async_data_generator(): Exception occured - {}\n{}".format(
|
||||
str(e), traceback.format_exc()
|
||||
)
|
||||
)
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
original_exception=e,
|
||||
request_data=request_data,
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
|
||||
)
|
||||
router_model_names = llm_router.model_names if llm_router is not None else []
|
||||
|
||||
if isinstance(e, HTTPException):
|
||||
raise e
|
||||
else:
|
||||
error_traceback = traceback.format_exc()
|
||||
error_msg = f"{str(e)}\n\n{error_traceback}"
|
||||
|
||||
proxy_exception = ProxyException(
|
||||
message=getattr(e, "message", error_msg),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", 500),
|
||||
)
|
||||
error_returned = json.dumps({"error": proxy_exception.to_dict()})
|
||||
yield f"data: {error_returned}\n\n"
|
||||
|
||||
|
||||
def select_data_generator(
|
||||
response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
|
||||
):
|
||||
|
@ -5379,6 +5434,19 @@ async def anthropic_response(
|
|||
)
|
||||
)
|
||||
|
||||
if (
|
||||
"stream" in data and data["stream"] is True
|
||||
): # use generate_responses to stream responses
|
||||
selected_data_generator = async_data_generator_anthropic(
|
||||
response=response,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
request_data=data,
|
||||
)
|
||||
return StreamingResponse(
|
||||
selected_data_generator,
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
verbose_proxy_logger.info("\nResponse from Litellm:\n{}".format(response))
|
||||
return response
|
||||
except RejectedRequestError as e:
|
||||
|
@ -5425,11 +5493,10 @@ async def anthropic_response(
|
|||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.completion(): Exception occured - {}".format(
|
||||
str(e)
|
||||
"litellm.proxy.proxy_server.anthropic_response(): Exception occured - {}\n{}".format(
|
||||
str(e), traceback.format_exc()
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
error_msg = f"{str(e)}"
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", error_msg),
|
||||
|
|
|
@ -8,6 +8,9 @@ import traceback
|
|||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import litellm.types
|
||||
import litellm.types.utils
|
||||
|
||||
load_dotenv()
|
||||
import io
|
||||
import os
|
||||
|
@ -15,6 +18,7 @@ import os
|
|||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
from typing import Optional
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
@ -84,7 +88,22 @@ def test_anthropic_completion_input_translation_with_metadata():
|
|||
assert translated_input["metadata"] == data["litellm_metadata"]
|
||||
|
||||
|
||||
def test_anthropic_completion_e2e():
|
||||
def streaming_format_tests(chunk: dict, idx: int):
|
||||
"""
|
||||
1st chunk - chunk.get("type") == "message_start"
|
||||
2nd chunk - chunk.get("type") == "content_block_start"
|
||||
3rd chunk - chunk.get("type") == "content_block_delta"
|
||||
"""
|
||||
if idx == 0:
|
||||
assert chunk.get("type") == "message_start"
|
||||
elif idx == 1:
|
||||
assert chunk.get("type") == "content_block_start"
|
||||
elif idx == 2:
|
||||
assert chunk.get("type") == "content_block_delta"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("stream", [True]) # False
|
||||
def test_anthropic_completion_e2e(stream):
|
||||
litellm.set_verbose = True
|
||||
|
||||
litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]
|
||||
|
@ -95,13 +114,40 @@ def test_anthropic_completion_e2e():
|
|||
messages=messages,
|
||||
adapter_id="anthropic",
|
||||
mock_response="This is a fake call",
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
print("Response: {}".format(response))
|
||||
|
||||
assert response is not None
|
||||
|
||||
if stream is False:
|
||||
assert isinstance(response, AnthropicResponse)
|
||||
else:
|
||||
"""
|
||||
- ensure finish reason is returned
|
||||
- assert content block is started and stopped
|
||||
- ensure last chunk is 'message_stop'
|
||||
"""
|
||||
assert isinstance(response, litellm.types.utils.AdapterCompletionStreamWrapper)
|
||||
finish_reason: Optional[str] = None
|
||||
message_stop_received = False
|
||||
content_block_started = False
|
||||
content_block_finished = False
|
||||
for idx, chunk in enumerate(response):
|
||||
print(chunk)
|
||||
streaming_format_tests(chunk=chunk, idx=idx)
|
||||
if chunk.get("delta", {}).get("stop_reason") is not None:
|
||||
finish_reason = chunk.get("delta", {}).get("stop_reason")
|
||||
if chunk.get("type") == "message_stop":
|
||||
message_stop_received = True
|
||||
if chunk.get("type") == "content_block_stop":
|
||||
content_block_finished = True
|
||||
if chunk.get("type") == "content_block_start":
|
||||
content_block_started = True
|
||||
assert content_block_started and content_block_finished
|
||||
assert finish_reason is not None
|
||||
assert message_stop_received is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
@ -136,7 +136,7 @@ class ContentJsonBlockDelta(TypedDict):
|
|||
|
||||
|
||||
class ContentBlockDelta(TypedDict):
|
||||
type: str
|
||||
type: Literal["content_block_delta"]
|
||||
index: int
|
||||
delta: Union[ContentTextBlockDelta, ContentJsonBlockDelta]
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ from typing import Dict, List, Literal, Optional, Tuple, Union
|
|||
|
||||
from openai._models import BaseModel as OpenAIObject
|
||||
from pydantic import ConfigDict, Field, PrivateAttr
|
||||
from typing_extensions import Dict, Required, TypedDict, override
|
||||
from typing_extensions import Callable, Dict, Required, TypedDict, override
|
||||
|
||||
from ..litellm_core_utils.core_helpers import map_finish_reason
|
||||
from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock
|
||||
|
@ -1069,3 +1069,36 @@ class LoggedLiteLLMParams(TypedDict, total=False):
|
|||
output_cost_per_token: Optional[float]
|
||||
output_cost_per_second: Optional[float]
|
||||
cooldown_time: Optional[float]
|
||||
|
||||
|
||||
class AdapterCompletionStreamWrapper:
|
||||
def __init__(self, completion_stream):
|
||||
self.completion_stream = completion_stream
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
for chunk in self.completion_stream:
|
||||
if chunk == "None" or chunk is None:
|
||||
raise Exception
|
||||
return chunk
|
||||
raise StopIteration
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
except Exception as e:
|
||||
print(f"AdapterCompletionStreamWrapper - {e}") # noqa
|
||||
|
||||
async def __anext__(self):
|
||||
try:
|
||||
async for chunk in self.completion_stream:
|
||||
if chunk == "None" or chunk is None:
|
||||
raise Exception
|
||||
return chunk
|
||||
raise StopIteration
|
||||
except StopIteration:
|
||||
raise StopAsyncIteration
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue