Merge branch 'BerriAI:main' into main

Authored by Sunny Wan on 2025-03-13 19:37:22 -04:00, committed by GitHub
commit e01d12b878
317 changed files with 15980 additions and 5207 deletions

View file

@ -73,8 +73,19 @@ def remove_index_from_tool_calls(
def get_litellm_metadata_from_kwargs(kwargs: dict):
"""
Helper to get litellm metadata from all litellm request kwargs
Return `litellm_metadata` if it exists, otherwise return `metadata`
"""
return kwargs.get("litellm_params", {}).get("metadata", {})
litellm_params = kwargs.get("litellm_params", {})
if litellm_params:
metadata = litellm_params.get("metadata", {})
litellm_metadata = litellm_params.get("litellm_metadata", {})
if litellm_metadata:
return litellm_metadata
elif metadata:
return metadata
return {}
# Helper functions used for OTEL logging
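# Usage sketch (illustrative, not part of this commit): the new lookup prefers
# `litellm_metadata` over `metadata`, and falls back to an empty dict.
example_kwargs = {
    "litellm_params": {
        "metadata": {"user_api_key": "hashed-key-old"},
        "litellm_metadata": {"user_api_key": "hashed-key-new"},
    }
}
assert get_litellm_metadata_from_kwargs(example_kwargs) == {"user_api_key": "hashed-key-new"}
assert get_litellm_metadata_from_kwargs({"litellm_params": {}}) == {}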

View file

@ -0,0 +1,34 @@
"""Utils for accessing credentials."""
from typing import List
import litellm
from litellm.types.utils import CredentialItem
class CredentialAccessor:
@staticmethod
def get_credential_values(credential_name: str) -> dict:
"""Safe accessor for credentials."""
if not litellm.credential_list:
return {}
for credential in litellm.credential_list:
if credential.credential_name == credential_name:
return credential.credential_values.copy()
return {}
@staticmethod
def upsert_credentials(credentials: List[CredentialItem]):
"""Add a credential to the list of credentials."""
credential_names = [cred.credential_name for cred in litellm.credential_list]
for credential in credentials:
if credential.credential_name in credential_names:
# Find and replace the existing credential in the list
for i, existing_cred in enumerate(litellm.credential_list):
if existing_cred.credential_name == credential.credential_name:
litellm.credential_list[i] = credential
break
else:
litellm.credential_list.append(credential)
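# Usage sketch (illustrative, not part of this commit); CredentialItem field
# names other than credential_name / credential_values are assumptions here.
from litellm.types.utils import CredentialItem

example_cred = CredentialItem(
    credential_name="azure-prod",
    credential_values={"api_key": "sk-...", "api_base": "https://example.openai.azure.com"},
    credential_info={},
)
CredentialAccessor.upsert_credentials([example_cred])  # insert, or replace the entry with the same name
assert "api_base" in CredentialAccessor.get_credential_values("azure-prod")
assert CredentialAccessor.get_credential_values("missing") == {}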

View file

@ -331,6 +331,7 @@ def exception_type( # type: ignore # noqa: PLR0915
model=model,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
body=getattr(original_exception, "body", None),
)
elif (
"Web server is returning an unknown error" in error_str
@ -421,6 +422,7 @@ def exception_type( # type: ignore # noqa: PLR0915
llm_provider=custom_llm_provider,
response=getattr(original_exception, "response", None),
litellm_debug_info=extra_information,
body=getattr(original_exception, "body", None),
)
elif original_exception.status_code == 429:
exception_mapping_worked = True
@ -1960,6 +1962,7 @@ def exception_type( # type: ignore # noqa: PLR0915
model=model,
litellm_debug_info=extra_information,
response=getattr(original_exception, "response", None),
body=getattr(original_exception, "body", None),
)
elif (
"The api_key client option must be set either by passing api_key to the client or by setting"
@ -1991,6 +1994,7 @@ def exception_type( # type: ignore # noqa: PLR0915
model=model,
litellm_debug_info=extra_information,
response=getattr(original_exception, "response", None),
body=getattr(original_exception, "body", None),
)
elif original_exception.status_code == 401:
exception_mapping_worked = True

View file

@ -57,6 +57,9 @@ def get_litellm_params(
prompt_variables: Optional[dict] = None,
async_call: Optional[bool] = None,
ssl_verify: Optional[bool] = None,
merge_reasoning_content_in_choices: Optional[bool] = None,
api_version: Optional[str] = None,
max_retries: Optional[int] = None,
**kwargs,
) -> dict:
litellm_params = {
@ -97,5 +100,15 @@ def get_litellm_params(
"prompt_variables": prompt_variables,
"async_call": async_call,
"ssl_verify": ssl_verify,
"merge_reasoning_content_in_choices": merge_reasoning_content_in_choices,
"api_version": api_version,
"azure_ad_token": kwargs.get("azure_ad_token"),
"tenant_id": kwargs.get("tenant_id"),
"client_id": kwargs.get("client_id"),
"client_secret": kwargs.get("client_secret"),
"azure_username": kwargs.get("azure_username"),
"azure_password": kwargs.get("azure_password"),
"max_retries": max_retries,
"timeout": kwargs.get("timeout"),
}
return litellm_params
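# Usage sketch (illustrative, not part of this commit); assumes the parameters
# not shown in this hunk keep their defaults.
example_params = get_litellm_params(
    merge_reasoning_content_in_choices=True,
    api_version="2024-02-01",
    max_retries=2,
    azure_ad_token="eyJ0eXAi...",  # forwarded via **kwargs
    timeout=30,                    # forwarded via **kwargs
)
assert example_params["max_retries"] == 2
assert example_params["azure_ad_token"] == "eyJ0eXAi..."
assert example_params["tenant_id"] is None  # unset kwargs come through as None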

View file

@ -25,6 +25,7 @@ from litellm import (
turn_off_message_logging,
)
from litellm._logging import _is_debugging_on, verbose_logger
from litellm.batches.batch_utils import _handle_completed_batch
from litellm.caching.caching import DualCache, InMemoryCache
from litellm.caching.caching_handler import LLMCachingHandler
from litellm.cost_calculator import _select_model_name_for_cost_calc
@ -38,11 +39,14 @@ from litellm.litellm_core_utils.redact_messages import (
redact_message_input_output_from_custom_logger,
redact_message_input_output_from_logging,
)
from litellm.responses.utils import ResponseAPILoggingUtils
from litellm.types.llms.openai import (
AllMessageValues,
Batch,
FineTuningJob,
HttpxBinaryResponseContent,
ResponseCompletedEvent,
ResponsesAPIResponse,
)
from litellm.types.rerank import RerankResponse
from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
@ -50,9 +54,11 @@ from litellm.types.utils import (
CallTypes,
EmbeddingResponse,
ImageResponse,
LiteLLMBatch,
LiteLLMLoggingBaseClass,
ModelResponse,
ModelResponseStream,
RawRequestTypedDict,
StandardCallbackDynamicParams,
StandardLoggingAdditionalHeaders,
StandardLoggingHiddenParams,
@ -203,6 +209,7 @@ class Logging(LiteLLMLoggingBaseClass):
] = None,
applied_guardrails: Optional[List[str]] = None,
kwargs: Optional[Dict] = None,
log_raw_request_response: bool = False,
):
_input: Optional[str] = messages # save original value of messages
if messages is not None:
@ -231,6 +238,7 @@ class Logging(LiteLLMLoggingBaseClass):
self.sync_streaming_chunks: List[Any] = (
[]
) # for generating complete stream response
self.log_raw_request_response = log_raw_request_response
# Initialize dynamic callbacks
self.dynamic_input_callbacks: Optional[
@ -451,6 +459,18 @@ class Logging(LiteLLMLoggingBaseClass):
return model, messages, non_default_params
def _get_raw_request_body(self, data: Optional[Union[dict, str]]) -> dict:
if data is None:
return {"error": "Received empty dictionary for raw request body"}
if isinstance(data, str):
try:
return json.loads(data)
except Exception:
return {
"error": "Unable to parse raw request body. Got - {}".format(data)
}
return data
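# Behavior sketch (illustrative, not part of this commit). The helper does not
# read `self`, so it can be exercised directly off the class for a quick check.
assert Logging._get_raw_request_body(None, '{"model": "gpt-4o"}') == {"model": "gpt-4o"}
assert Logging._get_raw_request_body(None, "not-json") == {
    "error": "Unable to parse raw request body. Got - not-json"
}
assert Logging._get_raw_request_body(None, None) == {
    "error": "Received empty dictionary for raw request body"
}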
def _pre_call(self, input, api_key, model=None, additional_args={}):
"""
Common helper function across the sync + async pre-call function
@ -466,6 +486,7 @@ class Logging(LiteLLMLoggingBaseClass):
self.model_call_details["model"] = model
def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR0915
# Log the exact input to the LLM API
litellm.error_logs["PRE_CALL"] = locals()
try:
@ -483,28 +504,54 @@ class Logging(LiteLLMLoggingBaseClass):
additional_args=additional_args,
)
# log raw request to provider (like LangFuse) -- if opted in.
if log_raw_request_response is True:
if (
self.log_raw_request_response is True
or log_raw_request_response is True
):
_litellm_params = self.model_call_details.get("litellm_params", {})
_metadata = _litellm_params.get("metadata", {}) or {}
try:
# [Non-blocking Extra Debug Information in metadata]
if (
turn_off_message_logging is not None
and turn_off_message_logging is True
):
if turn_off_message_logging is True:
_metadata["raw_request"] = (
"redacted by litellm. \
'litellm.turn_off_message_logging=True'"
)
else:
curl_command = self._get_request_curl_command(
api_base=additional_args.get("api_base", ""),
headers=additional_args.get("headers", {}),
additional_args=additional_args,
data=additional_args.get("complete_input_dict", {}),
)
_metadata["raw_request"] = str(curl_command)
# split up, so it's easier to parse in the UI
self.model_call_details["raw_request_typed_dict"] = (
RawRequestTypedDict(
raw_request_api_base=str(
additional_args.get("api_base") or ""
),
raw_request_body=self._get_raw_request_body(
additional_args.get("complete_input_dict", {})
),
raw_request_headers=self._get_masked_headers(
additional_args.get("headers", {}) or {},
ignore_sensitive_headers=True,
),
error=None,
)
)
except Exception as e:
self.model_call_details["raw_request_typed_dict"] = (
RawRequestTypedDict(
error=str(e),
)
)
traceback.print_exc()
_metadata["raw_request"] = (
"Unable to Log \
raw request: {}".format(
@ -637,9 +684,14 @@ class Logging(LiteLLMLoggingBaseClass):
)
verbose_logger.debug(f"\033[92m{curl_command}\033[0m\n")
def _get_request_body(self, data: dict) -> str:
return str(data)
def _get_request_curl_command(
self, api_base: str, headers: dict, additional_args: dict, data: dict
self, api_base: str, headers: Optional[dict], additional_args: dict, data: dict
) -> str:
if headers is None:
headers = {}
curl_command = "\n\nPOST Request Sent from LiteLLM:\n"
curl_command += "curl -X POST \\\n"
curl_command += f"{api_base} \\\n"
@ -647,11 +699,10 @@ class Logging(LiteLLMLoggingBaseClass):
formatted_headers = " ".join(
[f"-H '{k}: {v}'" for k, v in masked_headers.items()]
)
curl_command += (
f"{formatted_headers} \\\n" if formatted_headers.strip() != "" else ""
)
curl_command += f"-d '{str(data)}'\n"
curl_command += f"-d '{self._get_request_body(data)}'\n"
if additional_args.get("request_str", None) is not None:
# print the sagemaker / bedrock client request
curl_command = "\nRequest Sent from LiteLLM:\n"
@ -660,12 +711,20 @@ class Logging(LiteLLMLoggingBaseClass):
curl_command = str(self.model_call_details)
return curl_command
def _get_masked_headers(self, headers: dict):
def _get_masked_headers(
self, headers: dict, ignore_sensitive_headers: bool = False
) -> dict:
"""
Internal debugging helper function
Masks the headers of the request sent from LiteLLM
"""
sensitive_keywords = [
"authorization",
"token",
"key",
"secret",
]
return {
k: (
(v[:-44] + "*" * 44)
@ -673,6 +732,11 @@ class Logging(LiteLLMLoggingBaseClass):
else "*****"
)
for k, v in headers.items()
if not ignore_sensitive_headers
or not any(
sensitive_keyword in k.lower()
for sensitive_keyword in sensitive_keywords
)
}
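# Behavior sketch (illustrative, not part of this commit); exercised off the
# class since the method does not read `self`. Exact masking of kept values
# depends on their length.
masked = Logging._get_masked_headers(
    None,
    {"Authorization": "Bearer sk-1234", "x-api-key": "sk-secret", "Content-Type": "application/json"},
    ignore_sensitive_headers=True,
)
assert list(masked) == ["Content-Type"]  # authorization/token/key/secret keys are dropped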
def post_call(
@ -790,6 +854,8 @@ class Logging(LiteLLMLoggingBaseClass):
RerankResponse,
Batch,
FineTuningJob,
ResponsesAPIResponse,
ResponseCompletedEvent,
],
cache_hit: Optional[bool] = None,
) -> Optional[float]:
@ -871,6 +937,24 @@ class Logging(LiteLLMLoggingBaseClass):
return None
async def _response_cost_calculator_async(
self,
result: Union[
ModelResponse,
ModelResponseStream,
EmbeddingResponse,
ImageResponse,
TranscriptionResponse,
TextCompletionResponse,
HttpxBinaryResponseContent,
RerankResponse,
Batch,
FineTuningJob,
],
cache_hit: Optional[bool] = None,
) -> Optional[float]:
return self._response_cost_calculator(result=result, cache_hit=cache_hit)
def should_run_callback(
self, callback: litellm.CALLBACK_TYPES, litellm_params: dict, event_hook: str
) -> bool:
@ -912,13 +996,16 @@ class Logging(LiteLLMLoggingBaseClass):
self.model_call_details["log_event_type"] = "successful_api_call"
self.model_call_details["end_time"] = end_time
self.model_call_details["cache_hit"] = cache_hit
if self.call_type == CallTypes.anthropic_messages.value:
result = self._handle_anthropic_messages_response_logging(result=result)
## if model in model cost map - log the response cost
## else set cost to None
if (
standard_logging_object is None
and result is not None
and self.stream is not True
): # handle streaming separately
):
if (
isinstance(result, ModelResponse)
or isinstance(result, ModelResponseStream)
@ -928,8 +1015,9 @@ class Logging(LiteLLMLoggingBaseClass):
or isinstance(result, TextCompletionResponse)
or isinstance(result, HttpxBinaryResponseContent) # tts
or isinstance(result, RerankResponse)
or isinstance(result, Batch)
or isinstance(result, FineTuningJob)
or isinstance(result, LiteLLMBatch)
or isinstance(result, ResponsesAPIResponse)
):
## HIDDEN PARAMS ##
hidden_params = getattr(result, "_hidden_params", {})
@ -1029,7 +1117,7 @@ class Logging(LiteLLMLoggingBaseClass):
## BUILD COMPLETE STREAMED RESPONSE
complete_streaming_response: Optional[
Union[ModelResponse, TextCompletionResponse]
Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse]
] = None
if "complete_streaming_response" in self.model_call_details:
return # break out of this.
@ -1525,6 +1613,20 @@ class Logging(LiteLLMLoggingBaseClass):
print_verbose(
"Logging Details LiteLLM-Async Success Call, cache_hit={}".format(cache_hit)
)
## CALCULATE COST FOR BATCH JOBS
if self.call_type == CallTypes.aretrieve_batch.value and isinstance(
result, LiteLLMBatch
):
response_cost, batch_usage, batch_models = await _handle_completed_batch(
batch=result, custom_llm_provider=self.custom_llm_provider
)
result._hidden_params["response_cost"] = response_cost
result._hidden_params["batch_models"] = batch_models
result.usage = batch_usage
start_time, end_time, result = self._success_handler_helper_fn(
start_time=start_time,
end_time=end_time,
@ -1532,11 +1634,12 @@ class Logging(LiteLLMLoggingBaseClass):
cache_hit=cache_hit,
standard_logging_object=kwargs.get("standard_logging_object", None),
)
## BUILD COMPLETE STREAMED RESPONSE
if "async_complete_streaming_response" in self.model_call_details:
return # break out of this.
complete_streaming_response: Optional[
Union[ModelResponse, TextCompletionResponse]
Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse]
] = self._get_assembled_streaming_response(
result=result,
start_time=start_time,
@ -2246,16 +2349,24 @@ class Logging(LiteLLMLoggingBaseClass):
def _get_assembled_streaming_response(
self,
result: Union[ModelResponse, TextCompletionResponse, ModelResponseStream, Any],
result: Union[
ModelResponse,
TextCompletionResponse,
ModelResponseStream,
ResponseCompletedEvent,
Any,
],
start_time: datetime.datetime,
end_time: datetime.datetime,
is_async: bool,
streaming_chunks: List[Any],
) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
) -> Optional[Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse]]:
if isinstance(result, ModelResponse):
return result
elif isinstance(result, TextCompletionResponse):
return result
elif isinstance(result, ResponseCompletedEvent):
return result.response
elif isinstance(result, ModelResponseStream):
complete_streaming_response: Optional[
Union[ModelResponse, TextCompletionResponse]
@ -2270,6 +2381,37 @@ class Logging(LiteLLMLoggingBaseClass):
return complete_streaming_response
return None
def _handle_anthropic_messages_response_logging(self, result: Any) -> ModelResponse:
"""
Handles logging for Anthropic messages responses.
Args:
result: The response object from the model call
Returns:
The response object from the model call
- For non-streaming responses, the raw response is transformed into a ModelResponse object.
- For streaming responses, the anthropic_messages handler calls success_handler with an already-assembled ModelResponse.
"""
if self.stream and isinstance(result, ModelResponse):
return result
result = litellm.AnthropicConfig().transform_response(
raw_response=self.model_call_details["httpx_response"],
model_response=litellm.ModelResponse(),
model=self.model,
messages=[],
logging_obj=self,
optional_params={},
api_key="",
request_data={},
encoding=litellm.encoding,
json_mode=False,
litellm_params={},
)
return result
def set_callbacks(callback_list, function_id=None): # noqa: PLR0915
"""
@ -2983,6 +3125,12 @@ class StandardLoggingPayloadSetup:
elif isinstance(usage, Usage):
return usage
elif isinstance(usage, dict):
if ResponseAPILoggingUtils._is_response_api_usage(usage):
return (
ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(
usage
)
)
return Usage(**usage)
raise ValueError(f"usage is required, got={usage} of type {type(usage)}")
@ -3086,6 +3234,7 @@ class StandardLoggingPayloadSetup:
response_cost=None,
additional_headers=None,
litellm_overhead_time_ms=None,
batch_models=None,
)
if hidden_params is not None:
for key in StandardLoggingHiddenParams.__annotations__.keys():
@ -3199,6 +3348,7 @@ def get_standard_logging_object_payload(
api_base=None,
response_cost=None,
litellm_overhead_time_ms=None,
batch_models=None,
)
)
@ -3483,6 +3633,7 @@ def create_dummy_standard_logging_payload() -> StandardLoggingPayload:
response_cost=None,
additional_headers=None,
litellm_overhead_time_ms=None,
batch_models=None,
)
# Convert numeric values to appropriate types

View file

@ -9,6 +9,7 @@ from typing import Dict, Iterable, List, Literal, Optional, Tuple, Union
import litellm
from litellm._logging import verbose_logger
from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
from litellm.types.llms.openai import ChatCompletionThinkingBlock
from litellm.types.utils import (
ChatCompletionDeltaToolCall,
ChatCompletionMessageToolCall,
@ -128,12 +129,7 @@ def convert_to_streaming_response(response_object: Optional[dict] = None):
model_response_object = ModelResponse(stream=True)
choice_list = []
for idx, choice in enumerate(response_object["choices"]):
delta = Delta(
content=choice["message"].get("content", None),
role=choice["message"]["role"],
function_call=choice["message"].get("function_call", None),
tool_calls=choice["message"].get("tool_calls", None),
)
delta = Delta(**choice["message"])
finish_reason = choice.get("finish_reason", None)
if finish_reason is None:
# gpt-4 vision can return 'finish_reason' or 'finish_details'
@ -243,6 +239,24 @@ def _parse_content_for_reasoning(
return None, message_text
def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[str]]:
"""
Extract reasoning content and main content from a message.
Args:
message (dict): The message dictionary that may contain reasoning_content
Returns:
tuple[Optional[str], Optional[str]]: A tuple of (reasoning_content, content)
"""
if "reasoning_content" in message:
return message["reasoning_content"], message["content"]
elif "reasoning" in message:
return message["reasoning"], message["content"]
else:
return _parse_content_for_reasoning(message.get("content"))
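# Behavior sketch (illustrative, not part of this commit): explicit reasoning
# fields win; otherwise content is handed to _parse_content_for_reasoning.
_extract_reasoning_content({"role": "assistant", "reasoning_content": "step 1", "content": "42"})
# -> ("step 1", "42")
_extract_reasoning_content({"role": "assistant", "reasoning": "step 1", "content": "42"})
# -> ("step 1", "42")
_extract_reasoning_content({"role": "assistant", "content": "plain answer"})
# -> falls back to _parse_content_for_reasoning("plain answer")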
class LiteLLMResponseObjectHandler:
@staticmethod
@ -456,11 +470,16 @@ def convert_to_model_response_object( # noqa: PLR0915
provider_specific_fields[field] = choice["message"][field]
# Handle reasoning models that display `reasoning_content` within `content`
reasoning_content, content = _parse_content_for_reasoning(
choice["message"].get("content")
reasoning_content, content = _extract_reasoning_content(
choice["message"]
)
# Handle thinking models that display `thinking_blocks` within `content`
thinking_blocks: Optional[List[ChatCompletionThinkingBlock]] = None
if "thinking_blocks" in choice["message"]:
thinking_blocks = choice["message"]["thinking_blocks"]
provider_specific_fields["thinking_blocks"] = thinking_blocks
if reasoning_content:
provider_specific_fields["reasoning_content"] = (
reasoning_content
@ -474,6 +493,7 @@ def convert_to_model_response_object( # noqa: PLR0915
audio=choice["message"].get("audio", None),
provider_specific_fields=provider_specific_fields,
reasoning_content=reasoning_content,
thinking_blocks=thinking_blocks,
)
finish_reason = choice.get("finish_reason", None)
if finish_reason is None:

View file

@ -187,53 +187,125 @@ def ollama_pt(
final_prompt_value="### Response:",
messages=messages,
)
elif "llava" in model:
prompt = ""
images = []
for message in messages:
if isinstance(message["content"], str):
prompt += message["content"]
elif isinstance(message["content"], list):
# see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
for element in message["content"]:
if isinstance(element, dict):
if element["type"] == "text":
prompt += element["text"]
elif element["type"] == "image_url":
base64_image = convert_to_ollama_image(
element["image_url"]["url"]
)
images.append(base64_image)
return {"prompt": prompt, "images": images}
else:
user_message_types = {"user", "tool", "function"}
msg_i = 0
images = []
prompt = ""
for message in messages:
role = message["role"]
content = message.get("content", "")
while msg_i < len(messages):
init_msg_i = msg_i
user_content_str = ""
## MERGE CONSECUTIVE USER CONTENT ##
while (
msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
):
msg_content = messages[msg_i].get("content")
if msg_content:
if isinstance(msg_content, list):
for m in msg_content:
if m.get("type", "") == "image_url":
if isinstance(m["image_url"], str):
images.append(m["image_url"])
elif isinstance(m["image_url"], dict):
images.append(m["image_url"]["url"])
elif m.get("type", "") == "text":
user_content_str += m["text"]
else:
# Tool message content will always be a string
user_content_str += msg_content
if "tool_calls" in message:
tool_calls = []
msg_i += 1
for call in message["tool_calls"]:
call_id: str = call["id"]
function_name: str = call["function"]["name"]
arguments = json.loads(call["function"]["arguments"])
if user_content_str:
prompt += f"### User:\n{user_content_str}\n\n"
tool_calls.append(
{
"id": call_id,
"type": "function",
"function": {"name": function_name, "arguments": arguments},
}
assistant_content_str = ""
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
msg_content = messages[msg_i].get("content")
if msg_content:
if isinstance(msg_content, list):
for m in msg_content:
if m.get("type", "") == "text":
assistant_content_str += m["text"]
elif isinstance(msg_content, str):
# Tool message content will always be a string
assistant_content_str += msg_content
tool_calls = messages[msg_i].get("tool_calls")
ollama_tool_calls = []
if tool_calls:
for call in tool_calls:
call_id: str = call["id"]
function_name: str = call["function"]["name"]
arguments = json.loads(call["function"]["arguments"])
ollama_tool_calls.append(
{
"id": call_id,
"type": "function",
"function": {
"name": function_name,
"arguments": arguments,
},
}
)
if ollama_tool_calls:
assistant_content_str += (
f"Tool Calls: {json.dumps(ollama_tool_calls, indent=2)}"
)
prompt += f"### Assistant:\nTool Calls: {json.dumps(tool_calls, indent=2)}\n\n"
msg_i += 1
elif "tool_call_id" in message:
prompt += f"### User:\n{message['content']}\n\n"
if assistant_content_str:
prompt += f"### Assistant:\n{assistant_content_str}\n\n"
elif content:
prompt += f"### {role.capitalize()}:\n{content}\n\n"
if msg_i == init_msg_i: # prevent infinite loops
raise litellm.BadRequestError(
message=BAD_MESSAGE_ERROR_STR + f"passed in {messages[msg_i]}",
model=model,
llm_provider="ollama",
)
# prompt = ""
# images = []
# for message in messages:
# if isinstance(message["content"], str):
# prompt += message["content"]
# elif isinstance(message["content"], list):
# # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
# for element in message["content"]:
# if isinstance(element, dict):
# if element["type"] == "text":
# prompt += element["text"]
# elif element["type"] == "image_url":
# base64_image = convert_to_ollama_image(
# element["image_url"]["url"]
# )
# images.append(base64_image)
# if "tool_calls" in message:
# tool_calls = []
# for call in message["tool_calls"]:
# call_id: str = call["id"]
# function_name: str = call["function"]["name"]
# arguments = json.loads(call["function"]["arguments"])
# tool_calls.append(
# {
# "id": call_id,
# "type": "function",
# "function": {"name": function_name, "arguments": arguments},
# }
# )
# prompt += f"### Assistant:\nTool Calls: {json.dumps(tool_calls, indent=2)}\n\n"
# elif "tool_call_id" in message:
# prompt += f"### User:\n{message['content']}\n\n"
return {"prompt": prompt, "images": images}
return prompt
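# Usage sketch (illustrative, not part of this commit): consecutive user/tool
# messages are merged into one "### User:" block, assistant turns become
# "### Assistant:" blocks, and images are returned alongside the prompt.
example_messages = [
    {"role": "user", "content": "What's the weather in Paris?"},
    {"role": "tool", "content": " 22C and sunny."},
    {"role": "assistant", "content": "It is 22C and sunny in Paris."},
]
example_out = ollama_pt(model="llama3", messages=example_messages)
# example_out["prompt"] ==
#   "### User:\nWhat's the weather in Paris? 22C and sunny.\n\n"
#   "### Assistant:\nIt is 22C and sunny in Paris.\n\n"
# example_out["images"] == []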
@ -680,12 +752,13 @@ def convert_generic_image_chunk_to_openai_image_obj(
Return:
"data:image/jpeg;base64,{base64_image}"
"""
return "data:{};{},{}".format(
image_chunk["media_type"], image_chunk["type"], image_chunk["data"]
)
media_type = image_chunk["media_type"]
return "data:{};{},{}".format(media_type, image_chunk["type"], image_chunk["data"])
def convert_to_anthropic_image_obj(openai_image_url: str) -> GenericImageParsingChunk:
def convert_to_anthropic_image_obj(
openai_image_url: str, format: Optional[str]
) -> GenericImageParsingChunk:
"""
Input:
"image_url": "data:image/jpeg;base64,{base64_image}",
@ -702,7 +775,11 @@ def convert_to_anthropic_image_obj(openai_image_url: str) -> GenericImageParsing
openai_image_url = convert_url_to_base64(url=openai_image_url)
# Extract the media type and base64 data
media_type, base64_data = openai_image_url.split("data:")[1].split(";base64,")
media_type = media_type.replace("\\/", "/")
if format:
media_type = format
else:
media_type = media_type.replace("\\/", "/")
return GenericImageParsingChunk(
type="base64",
@ -820,11 +897,12 @@ def anthropic_messages_pt_xml(messages: list):
if isinstance(messages[msg_i]["content"], list):
for m in messages[msg_i]["content"]:
if m.get("type", "") == "image_url":
format = m["image_url"].get("format")
user_content.append(
{
"type": "image",
"source": convert_to_anthropic_image_obj(
m["image_url"]["url"]
m["image_url"]["url"], format=format
),
}
)
@ -1156,10 +1234,13 @@ def convert_to_anthropic_tool_result(
)
elif content["type"] == "image_url":
if isinstance(content["image_url"], str):
image_chunk = convert_to_anthropic_image_obj(content["image_url"])
else:
image_chunk = convert_to_anthropic_image_obj(
content["image_url"]["url"]
content["image_url"], format=None
)
else:
format = content["image_url"].get("format")
image_chunk = convert_to_anthropic_image_obj(
content["image_url"]["url"], format=format
)
anthropic_content_list.append(
AnthropicMessagesImageParam(
@ -1282,6 +1363,7 @@ def add_cache_control_to_content(
AnthropicMessagesImageParam,
AnthropicMessagesTextParam,
AnthropicMessagesDocumentParam,
ChatCompletionThinkingBlock,
],
orignal_content_element: Union[dict, AllMessageValues],
):
@ -1317,6 +1399,7 @@ def _anthropic_content_element_factory(
data=image_chunk["data"],
),
)
return _anthropic_content_element
@ -1368,13 +1451,16 @@ def anthropic_messages_pt( # noqa: PLR0915
for m in user_message_types_block["content"]:
if m.get("type", "") == "image_url":
m = cast(ChatCompletionImageObject, m)
format: Optional[str] = None
if isinstance(m["image_url"], str):
image_chunk = convert_to_anthropic_image_obj(
openai_image_url=m["image_url"]
openai_image_url=m["image_url"], format=None
)
else:
format = m["image_url"].get("format")
image_chunk = convert_to_anthropic_image_obj(
openai_image_url=m["image_url"]["url"]
openai_image_url=m["image_url"]["url"],
format=format,
)
_anthropic_content_element = (
@ -1454,12 +1540,23 @@ def anthropic_messages_pt( # noqa: PLR0915
assistant_content_block["content"], list
):
for m in assistant_content_block["content"]:
# handle text
# handle thinking blocks
thinking_block = cast(str, m.get("thinking", ""))
text_block = cast(str, m.get("text", ""))
if (
m.get("type", "") == "text" and len(m.get("text", "")) > 0
m.get("type", "") == "thinking" and len(thinking_block) > 0
): # don't pass empty text blocks. anthropic api raises errors.
anthropic_message: Union[
ChatCompletionThinkingBlock,
AnthropicMessagesTextParam,
] = cast(ChatCompletionThinkingBlock, m)
assistant_content.append(anthropic_message)
# handle text
elif (
m.get("type", "") == "text" and len(text_block) > 0
): # don't pass empty text blocks. anthropic api raises errors.
anthropic_message = AnthropicMessagesTextParam(
type="text", text=m.get("text")
type="text", text=text_block
)
_cached_message = add_cache_control_to_content(
anthropic_content_element=anthropic_message,
@ -1512,6 +1609,7 @@ def anthropic_messages_pt( # noqa: PLR0915
msg_i += 1
if assistant_content:
new_messages.append({"role": "assistant", "content": assistant_content})
if msg_i == init_msg_i: # prevent infinite loops
@ -1520,17 +1618,6 @@ def anthropic_messages_pt( # noqa: PLR0915
model=model,
llm_provider=llm_provider,
)
if not new_messages or new_messages[0]["role"] != "user":
if litellm.modify_params:
new_messages.insert(
0, {"role": "user", "content": [{"type": "text", "text": "."}]}
)
else:
raise Exception(
"Invalid first message={}. Should always start with 'role'='user' for Anthropic. System prompt is sent separately for Anthropic. set 'litellm.modify_params = True' or 'litellm_settings:modify_params = True' on proxy, to insert a placeholder user message - '.' as the first message, ".format(
new_messages
)
)
if new_messages[-1]["role"] == "assistant":
if isinstance(new_messages[-1]["content"], str):
@ -2301,8 +2388,11 @@ class BedrockImageProcessor:
)
@classmethod
def process_image_sync(cls, image_url: str) -> BedrockContentBlock:
def process_image_sync(
cls, image_url: str, format: Optional[str] = None
) -> BedrockContentBlock:
"""Synchronous image processing."""
if "base64" in image_url:
img_bytes, mime_type, image_format = cls._parse_base64_image(image_url)
elif "http://" in image_url or "https://" in image_url:
@ -2313,11 +2403,17 @@ class BedrockImageProcessor:
"Unsupported image type. Expected either image url or base64 encoded string"
)
if format:
mime_type = format
image_format = mime_type.split("/")[1]
image_format = cls._validate_format(mime_type, image_format)
return cls._create_bedrock_block(img_bytes, mime_type, image_format)
@classmethod
async def process_image_async(cls, image_url: str) -> BedrockContentBlock:
async def process_image_async(
cls, image_url: str, format: Optional[str]
) -> BedrockContentBlock:
"""Asynchronous image processing."""
if "base64" in image_url:
@ -2332,6 +2428,10 @@ class BedrockImageProcessor:
"Unsupported image type. Expected either image url or base64 encoded string"
)
if format: # override with user-defined params
mime_type = format
image_format = mime_type.split("/")[1]
image_format = cls._validate_format(mime_type, image_format)
return cls._create_bedrock_block(img_bytes, mime_type, image_format)
@ -2819,12 +2919,14 @@ class BedrockConverseMessagesProcessor:
_part = BedrockContentBlock(text=element["text"])
_parts.append(_part)
elif element["type"] == "image_url":
format: Optional[str] = None
if isinstance(element["image_url"], dict):
image_url = element["image_url"]["url"]
format = element["image_url"].get("format")
else:
image_url = element["image_url"]
_part = await BedrockImageProcessor.process_image_async( # type: ignore
image_url=image_url
image_url=image_url, format=format
)
_parts.append(_part) # type: ignore
_cache_point_block = (
@ -2924,7 +3026,14 @@ class BedrockConverseMessagesProcessor:
assistants_parts: List[BedrockContentBlock] = []
for element in _assistant_content:
if isinstance(element, dict):
if element["type"] == "text":
if element["type"] == "thinking":
thinking_block = BedrockConverseMessagesProcessor.translate_thinking_blocks_to_reasoning_content_blocks(
thinking_blocks=[
cast(ChatCompletionThinkingBlock, element)
]
)
assistants_parts.extend(thinking_block)
elif element["type"] == "text":
assistants_part = BedrockContentBlock(
text=element["text"]
)
@ -2974,7 +3083,7 @@ class BedrockConverseMessagesProcessor:
reasoning_content_blocks: List[BedrockContentBlock] = []
for thinking_block in thinking_blocks:
reasoning_text = thinking_block.get("thinking")
reasoning_signature = thinking_block.get("signature_delta")
reasoning_signature = thinking_block.get("signature")
text_block = BedrockConverseReasoningTextBlock(
text=reasoning_text or "",
)
@ -3050,12 +3159,15 @@ def _bedrock_converse_messages_pt( # noqa: PLR0915
_part = BedrockContentBlock(text=element["text"])
_parts.append(_part)
elif element["type"] == "image_url":
format: Optional[str] = None
if isinstance(element["image_url"], dict):
image_url = element["image_url"]["url"]
format = element["image_url"].get("format")
else:
image_url = element["image_url"]
_part = BedrockImageProcessor.process_image_sync( # type: ignore
image_url=image_url
image_url=image_url,
format=format,
)
_parts.append(_part) # type: ignore
_cache_point_block = (
@ -3157,7 +3269,14 @@ def _bedrock_converse_messages_pt( # noqa: PLR0915
assistants_parts: List[BedrockContentBlock] = []
for element in _assistant_content:
if isinstance(element, dict):
if element["type"] == "text":
if element["type"] == "thinking":
thinking_block = BedrockConverseMessagesProcessor.translate_thinking_blocks_to_reasoning_content_blocks(
thinking_blocks=[
cast(ChatCompletionThinkingBlock, element)
]
)
assistants_parts.extend(thinking_block)
elif element["type"] == "text":
assistants_part = BedrockContentBlock(text=element["text"])
assistants_parts.append(assistants_part)
elif element["type"] == "image_url":

View file

@ -15,6 +15,7 @@ from litellm import verbose_logger
from litellm.litellm_core_utils.redact_messages import LiteLLMLoggingObject
from litellm.litellm_core_utils.thread_pool_executor import executor
from litellm.types.llms.openai import ChatCompletionChunk
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import Delta
from litellm.types.utils import GenericStreamingChunk as GChunk
from litellm.types.utils import (
@ -70,6 +71,17 @@ class CustomStreamWrapper:
self.completion_stream = completion_stream
self.sent_first_chunk = False
self.sent_last_chunk = False
litellm_params: GenericLiteLLMParams = GenericLiteLLMParams(
**self.logging_obj.model_call_details.get("litellm_params", {})
)
self.merge_reasoning_content_in_choices: bool = (
litellm_params.merge_reasoning_content_in_choices or False
)
self.sent_first_thinking_block = False
self.sent_last_thinking_block = False
self.thinking_content = ""
self.system_fingerprint: Optional[str] = None
self.received_finish_reason: Optional[str] = None
self.intermittent_finish_reason: Optional[str] = (
@ -87,12 +99,7 @@ class CustomStreamWrapper:
self.holding_chunk = ""
self.complete_response = ""
self.response_uptil_now = ""
_model_info = (
self.logging_obj.model_call_details.get("litellm_params", {}).get(
"model_info", {}
)
or {}
)
_model_info: Dict = litellm_params.model_info or {}
_api_base = get_api_base(
model=model or "",
@ -630,7 +637,10 @@ class CustomStreamWrapper:
if isinstance(chunk, bytes):
chunk = chunk.decode("utf-8")
if "text_output" in chunk:
response = chunk.replace("data: ", "").strip()
response = (
CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
)
response = response.strip()
parsed_response = json.loads(response)
else:
return {
@ -873,6 +883,10 @@ class CustomStreamWrapper:
_index: Optional[int] = completion_obj.get("index")
if _index is not None:
model_response.choices[0].index = _index
self._optional_combine_thinking_block_in_choices(
model_response=model_response
)
print_verbose(f"returning model_response: {model_response}")
return model_response
else:
@ -929,6 +943,48 @@ class CustomStreamWrapper:
self.chunks.append(model_response)
return
def _optional_combine_thinking_block_in_choices(
self, model_response: ModelResponseStream
) -> None:
"""
UIs like OpenWebUI expect reasoning to arrive inside <think>...</think> tags in the chunk content.
Updates the model_response object in place, merging reasoning_content into content wrapped in <think>...</think> tags.
Enabled when `merge_reasoning_content_in_choices=True` is passed in the request params.
"""
if self.merge_reasoning_content_in_choices is True:
reasoning_content = getattr(
model_response.choices[0].delta, "reasoning_content", None
)
if reasoning_content:
if self.sent_first_thinking_block is False:
model_response.choices[0].delta.content += (
"<think>" + reasoning_content
)
self.sent_first_thinking_block = True
elif (
self.sent_first_thinking_block is True
and hasattr(model_response.choices[0].delta, "reasoning_content")
and model_response.choices[0].delta.reasoning_content
):
model_response.choices[0].delta.content = reasoning_content
elif (
self.sent_first_thinking_block is True
and not self.sent_last_thinking_block
and model_response.choices[0].delta.content
):
model_response.choices[0].delta.content = (
"</think>" + model_response.choices[0].delta.content
)
self.sent_last_thinking_block = True
if hasattr(model_response.choices[0].delta, "reasoning_content"):
del model_response.choices[0].delta.reasoning_content
return
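# Standalone sketch (not part of this commit) of the merge behavior above, with
# each streamed delta reduced to a (reasoning_content, content) pair.
def _sketch_merge_think_tags(deltas):
    sent_first, sent_last, parts = False, False, []
    for reasoning, content in deltas:
        if reasoning:
            parts.append(("<think>" + reasoning) if not sent_first else reasoning)
            sent_first = True
        elif sent_first and not sent_last and content:
            parts.append("</think>" + content)
            sent_last = True
        elif content:
            parts.append(content)
    return "".join(parts)

assert _sketch_merge_think_tags(
    [("Let me think", ""), (" harder.", None), (None, "The answer is 42.")]
) == "<think>Let me think harder.</think>The answer is 42."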
def chunk_creator(self, chunk: Any): # type: ignore # noqa: PLR0915
model_response = self.model_response_creator()
response_obj: Dict[str, Any] = {}
@ -1775,6 +1831,42 @@ class CustomStreamWrapper:
extra_kwargs={},
)
@staticmethod
def _strip_sse_data_from_chunk(chunk: Optional[str]) -> Optional[str]:
"""
Strips the 'data: ' prefix from Server-Sent Events (SSE) chunks.
Some providers, like SageMaker, send it as `data:` with no trailing space, so both forms need to be handled.
SSE messages are prefixed with 'data: ' which is part of the protocol,
not the actual content from the LLM. This method removes that prefix
and returns the actual content.
Args:
chunk: The SSE chunk that may contain the 'data: ' prefix (string or None)
Returns:
The chunk with the 'data: ' prefix removed, or the original chunk
if no prefix was found. Returns None if input is None.
See OpenAI Python Ref for this: https://github.com/openai/openai-python/blob/041bf5a8ec54da19aad0169671793c2078bd6173/openai/api_requestor.py#L100
"""
if chunk is None:
return None
if isinstance(chunk, str):
# OpenAI sends `data: `
if chunk.startswith("data: "):
# Strip the 'data: ' prefix; the trailing space is part of the prefix
_length_of_sse_data_prefix = len("data: ")
return chunk[_length_of_sse_data_prefix:]
elif chunk.startswith("data:"):
# Sagemaker sends `data:`, no trailing whitespace
_length_of_sse_data_prefix = len("data:")
return chunk[_length_of_sse_data_prefix:]
return chunk
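# Behavior sketch (illustrative, not part of this commit):
assert CustomStreamWrapper._strip_sse_data_from_chunk('data: {"text_output": "hi"}') == '{"text_output": "hi"}'
assert CustomStreamWrapper._strip_sse_data_from_chunk('data:{"text_output": "hi"}') == '{"text_output": "hi"}'
assert CustomStreamWrapper._strip_sse_data_from_chunk("no prefix") == "no prefix"
assert CustomStreamWrapper._strip_sse_data_from_chunk(None) is None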
def calculate_total_usage(chunks: List[ModelResponse]) -> Usage:
"""Assume most recent usage chunk has total usage uptil then."""