fix: add exception mapping + langfuse exception logging for streaming exceptions

Fixes https://github.com/BerriAI/litellm/issues/4338
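In practice: when a streaming call fails, the stream wrapper now logs the failure to the registered failure callbacks (sync handler in a background thread, async handler on the running event loop) and re-raises the error through exception_type(), so callers see a mapped OpenAI-style exception instead of a raw provider/httpx error. A rough caller-side sketch of the new behavior (model name and prompt are placeholders taken from the test config below):

```python
import litellm

# streaming failures now reach failure callbacks such as Langfuse
litellm.failure_callback = ["langfuse"]

try:
    response = litellm.completion(
        model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
        messages=[{"role": "user", "content": "hello"}],
        stream=True,
    )
    for chunk in response:
        print(chunk)
except litellm.exceptions.Timeout as e:
    # e.g. a Bedrock read timeout (mapped from a 408 BedrockError)
    print("timeout:", e)
except Exception as e:
    # other mid-stream failures arrive as mapped OpenAI-style exceptions
    # (BadRequestError, RateLimitError, APIConnectionError, ...)
    print(f"{type(e).__name__}: {e}")
```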
Krrish Dholakia 2024-06-22 21:26:15 -07:00 committed by Ishaan Jaff
parent e20e8c2e74
commit 7f54c90459
4 changed files with 89 additions and 65 deletions


@@ -1,63 +1,64 @@
 # What is this?
 ## Initial implementation of calling bedrock via httpx client (allows for async calls).
 ## V1 - covers cohere + anthropic claude-3 support
-from functools import partial
-import os, types
+import copy
 import json
-from enum import Enum
-import requests, copy  # type: ignore
+import os
 import time
+import types
+import urllib.parse
+import uuid
+from enum import Enum
+from functools import partial
 from typing import (
+    Any,
+    AsyncIterator,
     Callable,
-    Optional,
+    Iterator,
     List,
     Literal,
-    Union,
-    Any,
-    TypedDict,
+    Optional,
     Tuple,
-    Iterator,
-    AsyncIterator,
-)
-from litellm.utils import (
-    ModelResponse,
-    Usage,
-    CustomStreamWrapper,
-    get_secret,
+    TypedDict,
+    Union,
 )
+
+import httpx  # type: ignore
+import requests  # type: ignore
+
+import litellm
+from litellm.caching import DualCache
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.litellm_core_utils.litellm_logging import Logging
-from litellm.types.utils import Message, Choices
-import litellm, uuid
-from .prompt_templates.factory import (
-    prompt_factory,
-    custom_prompt,
-    cohere_message_pt,
-    construct_tool_use_system_prompt,
-    extract_between_tags,
-    parse_xml_params,
-    contains_tag,
-    _bedrock_converse_messages_pt,
-    _bedrock_tools_pt,
-)
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
     _get_async_httpx_client,
     _get_httpx_client,
 )
-from .base import BaseLLM
-import httpx  # type: ignore
-from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator
 from litellm.types.llms.bedrock import *
-import urllib.parse
 from litellm.types.llms.openai import (
+    ChatCompletionDeltaChunk,
     ChatCompletionResponseMessage,
     ChatCompletionToolCallChunk,
     ChatCompletionToolCallFunctionChunk,
-    ChatCompletionDeltaChunk,
 )
-from litellm.caching import DualCache
+from litellm.types.utils import Choices, Message
+from litellm.utils import CustomStreamWrapper, ModelResponse, Usage, get_secret
+
+from .base import BaseLLM
+from .bedrock import BedrockError, ModelResponseIterator, convert_messages_to_prompt
+from .prompt_templates.factory import (
+    _bedrock_converse_messages_pt,
+    _bedrock_tools_pt,
+    cohere_message_pt,
+    construct_tool_use_system_prompt,
+    contains_tag,
+    custom_prompt,
+    extract_between_tags,
+    parse_xml_params,
+    prompt_factory,
+)
 
 iam_cache = DualCache()
@@ -171,26 +172,34 @@ async def make_call(
     messages: list,
     logging_obj,
 ):
-    if client is None:
-        client = _get_async_httpx_client()  # Create a new client if none provided
+    try:
+        if client is None:
+            client = _get_async_httpx_client()  # Create a new client if none provided
 
-    response = await client.post(api_base, headers=headers, data=data, stream=True)
+        response = await client.post(api_base, headers=headers, data=data, stream=True)
 
-    if response.status_code != 200:
-        raise BedrockError(status_code=response.status_code, message=response.text)
+        if response.status_code != 200:
+            raise BedrockError(status_code=response.status_code, message=response.text)
 
-    decoder = AWSEventStreamDecoder(model=model)
-    completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
+        decoder = AWSEventStreamDecoder(model=model)
+        completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
 
-    # LOGGING
-    logging_obj.post_call(
-        input=messages,
-        api_key="",
-        original_response="first stream response received",
-        additional_args={"complete_input_dict": data},
-    )
+        # LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key="",
+            original_response="first stream response received",
+            additional_args={"complete_input_dict": data},
+        )
 
-    return completion_stream
+        return completion_stream
+    except httpx.HTTPStatusError as err:
+        error_code = err.response.status_code
+        raise BedrockError(status_code=error_code, message=str(err))
+    except httpx.TimeoutException as e:
+        raise BedrockError(status_code=408, message="Timeout error occurred.")
+    except Exception as e:
+        raise BedrockError(status_code=500, message=str(e))
 
 
 def make_sync_call(
@@ -704,7 +713,6 @@ class BedrockLLM(BaseLLM):
     ) -> Union[ModelResponse, CustomStreamWrapper]:
         try:
             import boto3
-
             from botocore.auth import SigV4Auth
             from botocore.awsrequest import AWSRequest
             from botocore.credentials import Credentials
@@ -1650,7 +1658,6 @@ class BedrockConverseLLM(BaseLLM):
     ):
         try:
             import boto3
-
             from botocore.auth import SigV4Auth
             from botocore.awsrequest import AWSRequest
             from botocore.credentials import Credentials
@@ -1904,8 +1911,8 @@ class BedrockConverseLLM(BaseLLM):
 def get_response_stream_shape():
-    from botocore.model import ServiceModel
     from botocore.loaders import Loader
+    from botocore.model import ServiceModel
 
     loader = Loader()
     bedrock_service_dict = loader.load_service_model("bedrock-runtime", "service-2")
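Design note on the make_call change above: every transport-level failure is normalized into a BedrockError carrying an HTTP status code (408 for timeouts, 500 for anything unexpected), which is what the downstream exception mapping keys on. A minimal, self-contained sketch of that pattern; BedrockError and fetch_bedrock_stream here are illustrative stand-ins, not litellm's actual classes:

```python
import httpx


class BedrockError(Exception):
    """Illustrative stand-in for the provider error used above."""

    def __init__(self, status_code: int, message: str):
        self.status_code = status_code
        self.message = message
        super().__init__(message)


async def fetch_bedrock_stream(
    client: httpx.AsyncClient, url: str, payload: dict
) -> httpx.Response:
    # Hypothetical helper: normalize httpx failures into BedrockError so a
    # single status-code-based mapper can translate them for callers.
    try:
        response = await client.post(url, json=payload)
        response.raise_for_status()
        return response
    except httpx.HTTPStatusError as err:
        # non-2xx responses keep their original status code
        raise BedrockError(status_code=err.response.status_code, message=str(err))
    except httpx.TimeoutException:
        # timeouts are normalized to 408 so they map to a Timeout error
        raise BedrockError(status_code=408, message="Timeout error occurred.")
    except Exception as err:
        # anything else becomes a generic 500
        raise BedrockError(status_code=500, message=str(err))
```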


@@ -1,10 +1,10 @@
 model_list:
   - model_name: my-fake-model
     litellm_params:
-      model: gpt-3.5-turbo
+      model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
       api_key: my-fake-key
-      mock_response: hello-world
-      tpm: 60
+      aws_bedrock_runtime_endpoint: http://127.0.0.1:8000
 
 litellm_settings:
-  callbacks: ["dynamic_rate_limiter"]
+  success_callback: ["langfuse"]
+  failure_callback: ["langfuse"]
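With the proxy started against this config (Langfuse keys set in the environment), the same behavior can be checked through the OpenAI client; the base_url, port, and api_key below are placeholders for a local proxy:

```python
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

try:
    stream = client.chat.completions.create(
        model="my-fake-model",
        messages=[{"role": "user", "content": "ping"}],
        stream=True,
    )
    for chunk in stream:
        print(chunk.choices[0].delta)
except openai.APIError as e:
    # a Bedrock failure mid-stream is now logged to the Langfuse
    # failure_callback and re-raised as a mapped OpenAI-style error
    print("stream failed:", e)
```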


@@ -2526,11 +2526,10 @@ async def async_data_generator(
        yield f"data: {done_message}\n\n"
    except Exception as e:
        verbose_proxy_logger.error(
-            "litellm.proxy.proxy_server.async_data_generator(): Exception occured - {}".format(
-                str(e)
+            "litellm.proxy.proxy_server.async_data_generator(): Exception occured - {}\n{}".format(
+                str(e), traceback.format_exc()
            )
        )
-        verbose_proxy_logger.debug(traceback.format_exc())
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict,
            original_exception=e,


@@ -9595,6 +9595,11 @@ class CustomStreamWrapper:
                litellm.request_timeout
            )
            if self.logging_obj is not None:
+                ## LOGGING
+                threading.Thread(
+                    target=self.logging_obj.failure_handler,
+                    args=(e, traceback_exception),
+                ).start()  # log response
                # Handle any exceptions that might occur during streaming
                asyncio.create_task(
                    self.logging_obj.async_failure_handler(e, traceback_exception)
@@ -9602,11 +9607,24 @@ class CustomStreamWrapper:
            raise e
        except Exception as e:
            traceback_exception = traceback.format_exc()
-            # Handle any exceptions that might occur during streaming
-            asyncio.create_task(
-                self.logging_obj.async_failure_handler(e, traceback_exception)  # type: ignore
+            if self.logging_obj is not None:
+                ## LOGGING
+                threading.Thread(
+                    target=self.logging_obj.failure_handler,
+                    args=(e, traceback_exception),
+                ).start()  # log response
+                # Handle any exceptions that might occur during streaming
+                asyncio.create_task(
+                    self.logging_obj.async_failure_handler(e, traceback_exception)  # type: ignore
+                )
+            ## Map to OpenAI Exception
+            raise exception_type(
+                model=self.model,
+                custom_llm_provider=self.custom_llm_provider,
+                original_exception=e,
+                completion_kwargs={},
+                extra_kwargs={},
            )
-            raise e


class TextCompletionStreamWrapper:
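For reference, the failure-logging pattern added in the two hunks above reduces to: run the synchronous failure_handler in a background thread, schedule the async handler on the running event loop, then re-raise through exception_type(). A minimal illustrative sketch; log_stream_failure is a hypothetical helper, not litellm code:

```python
import asyncio
import threading
import traceback


def log_stream_failure(logging_obj, exc: Exception) -> None:
    """Mirror of the logging pattern in CustomStreamWrapper (illustrative only)."""
    if logging_obj is None:
        return
    tb = traceback.format_exc()
    # sync callbacks (e.g. Langfuse) run in a background thread so they
    # cannot block or break the stream that is being torn down
    threading.Thread(target=logging_obj.failure_handler, args=(exc, tb)).start()
    # async callbacks are scheduled on the running loop; this assumes the
    # helper is called from async code, as CustomStreamWrapper.__anext__ is
    asyncio.create_task(logging_obj.async_failure_handler(exc, tb))
```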