Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 19:24:27 +00:00)

Merge branch 'BerriAI:main' into main

Commit e01d12b878: 317 changed files with 15980 additions and 5207 deletions
@@ -73,8 +73,19 @@ def remove_index_from_tool_calls(
def get_litellm_metadata_from_kwargs(kwargs: dict):
    """
    Helper to get litellm metadata from all litellm request kwargs
+
+   Return `litellm_metadata` if it exists, otherwise return `metadata`
    """
-   return kwargs.get("litellm_params", {}).get("metadata", {})
+   litellm_params = kwargs.get("litellm_params", {})
+   if litellm_params:
+       metadata = litellm_params.get("metadata", {})
+       litellm_metadata = litellm_params.get("litellm_metadata", {})
+       if litellm_metadata:
+           return litellm_metadata
+       elif metadata:
+           return metadata
+
+   return {}


# Helper functions used for OTEL logging
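A minimal usage sketch of the new helper (the kwargs shapes below are illustrative, mirroring what LiteLLM passes to its logging callbacks):

    # `litellm_metadata` wins over the legacy `metadata` key; missing params fall back to {}.
    kwargs_new = {"litellm_params": {"litellm_metadata": {"user_id": "u-1"}, "metadata": {"user_id": "legacy"}}}
    kwargs_old = {"litellm_params": {"metadata": {"user_id": "legacy"}}}

    assert get_litellm_metadata_from_kwargs(kwargs_new) == {"user_id": "u-1"}
    assert get_litellm_metadata_from_kwargs(kwargs_old) == {"user_id": "legacy"}
    assert get_litellm_metadata_from_kwargs({}) == {}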
litellm/litellm_core_utils/credential_accessor.py (new file, 34 lines)

@@ -0,0 +1,34 @@
"""Utils for accessing credentials."""

from typing import List

import litellm
from litellm.types.utils import CredentialItem


class CredentialAccessor:
    @staticmethod
    def get_credential_values(credential_name: str) -> dict:
        """Safe accessor for credentials."""
        if not litellm.credential_list:
            return {}
        for credential in litellm.credential_list:
            if credential.credential_name == credential_name:
                return credential.credential_values.copy()
        return {}

    @staticmethod
    def upsert_credentials(credentials: List[CredentialItem]):
        """Add a credential to the list of credentials."""

        credential_names = [cred.credential_name for cred in litellm.credential_list]

        for credential in credentials:
            if credential.credential_name in credential_names:
                # Find and replace the existing credential in the list
                for i, existing_cred in enumerate(litellm.credential_list):
                    if existing_cred.credential_name == credential.credential_name:
                        litellm.credential_list[i] = credential
                        break
            else:
                litellm.credential_list.append(credential)
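A hedged usage sketch for the new accessor. It assumes `litellm.credential_list` is a plain in-memory list of `CredentialItem` objects; the `credential_info` field is included only because the model likely requires it, so treat the exact constructor arguments as an assumption:

    import litellm
    from litellm.types.utils import CredentialItem
    from litellm.litellm_core_utils.credential_accessor import CredentialAccessor

    # Seed the in-memory store (field names mirror the diff; construction details are assumed).
    litellm.credential_list = [
        CredentialItem(credential_name="my-azure", credential_values={"api_key": "sk-old"}, credential_info={})
    ]

    # Upsert replaces the entry with the same name instead of appending a duplicate.
    CredentialAccessor.upsert_credentials(
        [CredentialItem(credential_name="my-azure", credential_values={"api_key": "sk-new"}, credential_info={})]
    )

    print(CredentialAccessor.get_credential_values("my-azure"))  # {'api_key': 'sk-new'}
    print(CredentialAccessor.get_credential_values("missing"))   # {} -- safe default, never raises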
@@ -331,6 +331,7 @@ def exception_type( # type: ignore # noqa: PLR0915
                    model=model,
                    response=getattr(original_exception, "response", None),
                    litellm_debug_info=extra_information,
+                   body=getattr(original_exception, "body", None),
                )
            elif (
                "Web server is returning an unknown error" in error_str

@@ -421,6 +422,7 @@ def exception_type( # type: ignore # noqa: PLR0915
                    llm_provider=custom_llm_provider,
                    response=getattr(original_exception, "response", None),
                    litellm_debug_info=extra_information,
+                   body=getattr(original_exception, "body", None),
                )
            elif original_exception.status_code == 429:
                exception_mapping_worked = True

@@ -1960,6 +1962,7 @@ def exception_type( # type: ignore # noqa: PLR0915
                    model=model,
                    litellm_debug_info=extra_information,
                    response=getattr(original_exception, "response", None),
+                   body=getattr(original_exception, "body", None),
                )
            elif (
                "The api_key client option must be set either by passing api_key to the client or by setting"

@@ -1991,6 +1994,7 @@ def exception_type( # type: ignore # noqa: PLR0915
                    model=model,
                    litellm_debug_info=extra_information,
                    response=getattr(original_exception, "response", None),
+                   body=getattr(original_exception, "body", None),
                )
            elif original_exception.status_code == 401:
                exception_mapping_worked = True
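All four hunks do the same thing: forward the provider SDK's `body` attribute onto the mapped LiteLLM exception. A hedged sketch of what a caller can now do (whether `body` is populated still depends on the upstream SDK):

    import litellm

    try:
        litellm.completion(model="gpt-4o", messages=[{"role": "user", "content": "hi"}])
    except Exception as e:
        # Mapped exceptions now carry the raw error body when the underlying
        # exception exposed one; otherwise this stays None.
        print(type(e).__name__, getattr(e, "body", None))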
@@ -57,6 +57,9 @@ def get_litellm_params(
    prompt_variables: Optional[dict] = None,
    async_call: Optional[bool] = None,
    ssl_verify: Optional[bool] = None,
+   merge_reasoning_content_in_choices: Optional[bool] = None,
+   api_version: Optional[str] = None,
+   max_retries: Optional[int] = None,
    **kwargs,
) -> dict:
    litellm_params = {

@@ -97,5 +100,15 @@ def get_litellm_params(
        "prompt_variables": prompt_variables,
        "async_call": async_call,
        "ssl_verify": ssl_verify,
+       "merge_reasoning_content_in_choices": merge_reasoning_content_in_choices,
+       "api_version": api_version,
+       "azure_ad_token": kwargs.get("azure_ad_token"),
+       "tenant_id": kwargs.get("tenant_id"),
+       "client_id": kwargs.get("client_id"),
+       "client_secret": kwargs.get("client_secret"),
+       "azure_username": kwargs.get("azure_username"),
+       "azure_password": kwargs.get("azure_password"),
+       "max_retries": max_retries,
+       "timeout": kwargs.get("timeout"),
    }
    return litellm_params
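A small sketch of the expanded plumbing: the three new named parameters land directly in `litellm_params`, while the Azure credential fields ride in through `**kwargs`. The import path and call shape below are assumptions for illustration, not the exact internal call site:

    from litellm.litellm_core_utils.get_litellm_params import get_litellm_params  # assumed path

    params = get_litellm_params(
        merge_reasoning_content_in_choices=True,
        api_version="2024-06-01",
        max_retries=2,
        # Azure fields arrive via **kwargs and are copied into the dict:
        azure_ad_token="eyJ...",
        tenant_id="my-tenant",
        client_id="my-client",
    )
    assert params["merge_reasoning_content_in_choices"] is True
    assert params["tenant_id"] == "my-tenant"
    assert params["timeout"] is None  # not passed, so kwargs.get() yields None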
@@ -25,6 +25,7 @@ from litellm import (
    turn_off_message_logging,
)
from litellm._logging import _is_debugging_on, verbose_logger
+from litellm.batches.batch_utils import _handle_completed_batch
from litellm.caching.caching import DualCache, InMemoryCache
from litellm.caching.caching_handler import LLMCachingHandler
from litellm.cost_calculator import _select_model_name_for_cost_calc

@@ -38,11 +39,14 @@ from litellm.litellm_core_utils.redact_messages import (
    redact_message_input_output_from_custom_logger,
    redact_message_input_output_from_logging,
)
+from litellm.responses.utils import ResponseAPILoggingUtils
from litellm.types.llms.openai import (
    AllMessageValues,
    Batch,
    FineTuningJob,
    HttpxBinaryResponseContent,
+   ResponseCompletedEvent,
+   ResponsesAPIResponse,
)
from litellm.types.rerank import RerankResponse
from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS

@@ -50,9 +54,11 @@ from litellm.types.utils import (
    CallTypes,
    EmbeddingResponse,
    ImageResponse,
+   LiteLLMBatch,
    LiteLLMLoggingBaseClass,
    ModelResponse,
    ModelResponseStream,
+   RawRequestTypedDict,
    StandardCallbackDynamicParams,
    StandardLoggingAdditionalHeaders,
    StandardLoggingHiddenParams,

@@ -203,6 +209,7 @@ class Logging(LiteLLMLoggingBaseClass):
        ] = None,
        applied_guardrails: Optional[List[str]] = None,
        kwargs: Optional[Dict] = None,
+       log_raw_request_response: bool = False,
    ):
        _input: Optional[str] = messages  # save original value of messages
        if messages is not None:

@@ -231,6 +238,7 @@ class Logging(LiteLLMLoggingBaseClass):
        self.sync_streaming_chunks: List[Any] = (
            []
        )  # for generating complete stream response
+       self.log_raw_request_response = log_raw_request_response

        # Initialize dynamic callbacks
        self.dynamic_input_callbacks: Optional[
@@ -451,6 +459,18 @@ class Logging(LiteLLMLoggingBaseClass):

        return model, messages, non_default_params

+   def _get_raw_request_body(self, data: Optional[Union[dict, str]]) -> dict:
+       if data is None:
+           return {"error": "Received empty dictionary for raw request body"}
+       if isinstance(data, str):
+           try:
+               return json.loads(data)
+           except Exception:
+               return {
+                   "error": "Unable to parse raw request body. Got - {}".format(data)
+               }
+       return data
+
    def _pre_call(self, input, api_key, model=None, additional_args={}):
        """
        Common helper function across the sync + async pre-call function
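A quick, hedged illustration of the `_get_raw_request_body` contract (`logging_obj` stands in for a constructed `Logging` instance; shapes are inferred from the diff):

    assert logging_obj._get_raw_request_body({"model": "gpt-4o"}) == {"model": "gpt-4o"}
    assert logging_obj._get_raw_request_body('{"model": "gpt-4o"}') == {"model": "gpt-4o"}
    assert logging_obj._get_raw_request_body(None) == {
        "error": "Received empty dictionary for raw request body"
    }
    assert "Unable to parse raw request body" in logging_obj._get_raw_request_body("not json")["error"]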
@@ -466,6 +486,7 @@ class Logging(LiteLLMLoggingBaseClass):
        self.model_call_details["model"] = model

    def pre_call(self, input, api_key, model=None, additional_args={}):  # noqa: PLR0915

        # Log the exact input to the LLM API
        litellm.error_logs["PRE_CALL"] = locals()
        try:

@@ -483,28 +504,54 @@ class Logging(LiteLLMLoggingBaseClass):
                additional_args=additional_args,
            )
            # log raw request to provider (like LangFuse) -- if opted in.
-           if log_raw_request_response is True:
+           if (
+               self.log_raw_request_response is True
+               or log_raw_request_response is True
+           ):
                _litellm_params = self.model_call_details.get("litellm_params", {})
                _metadata = _litellm_params.get("metadata", {}) or {}
                try:
                    # [Non-blocking Extra Debug Information in metadata]
-                   if (
-                       turn_off_message_logging is not None
-                       and turn_off_message_logging is True
-                   ):
+                   if turn_off_message_logging is True:
                        _metadata["raw_request"] = (
                            "redacted by litellm. \
                            'litellm.turn_off_message_logging=True'"
                        )
                    else:
                        curl_command = self._get_request_curl_command(
                            api_base=additional_args.get("api_base", ""),
                            headers=additional_args.get("headers", {}),
                            additional_args=additional_args,
                            data=additional_args.get("complete_input_dict", {}),
                        )

                        _metadata["raw_request"] = str(curl_command)
                        # split up, so it's easier to parse in the UI
                        self.model_call_details["raw_request_typed_dict"] = (
                            RawRequestTypedDict(
                                raw_request_api_base=str(
                                    additional_args.get("api_base") or ""
                                ),
                                raw_request_body=self._get_raw_request_body(
                                    additional_args.get("complete_input_dict", {})
                                ),
                                raw_request_headers=self._get_masked_headers(
                                    additional_args.get("headers", {}) or {},
                                    ignore_sensitive_headers=True,
                                ),
                                error=None,
                            )
                        )
                except Exception as e:
                    self.model_call_details["raw_request_typed_dict"] = (
                        RawRequestTypedDict(
                            error=str(e),
                        )
                    )
                    traceback.print_exc()
                    _metadata["raw_request"] = (
                        "Unable to Log \
                        raw request: {}".format(
@@ -637,9 +684,14 @@ class Logging(LiteLLMLoggingBaseClass):
            )
            verbose_logger.debug(f"\033[92m{curl_command}\033[0m\n")

+   def _get_request_body(self, data: dict) -> str:
+       return str(data)
+
    def _get_request_curl_command(
-       self, api_base: str, headers: dict, additional_args: dict, data: dict
+       self, api_base: str, headers: Optional[dict], additional_args: dict, data: dict
    ) -> str:
+       if headers is None:
+           headers = {}
        curl_command = "\n\nPOST Request Sent from LiteLLM:\n"
        curl_command += "curl -X POST \\\n"
        curl_command += f"{api_base} \\\n"

@@ -647,11 +699,10 @@ class Logging(LiteLLMLoggingBaseClass):
        formatted_headers = " ".join(
            [f"-H '{k}: {v}'" for k, v in masked_headers.items()]
        )

        curl_command += (
            f"{formatted_headers} \\\n" if formatted_headers.strip() != "" else ""
        )
-       curl_command += f"-d '{str(data)}'\n"
+       curl_command += f"-d '{self._get_request_body(data)}'\n"
        if additional_args.get("request_str", None) is not None:
            # print the sagemaker / bedrock client request
            curl_command = "\nRequest Sent from LiteLLM:\n"

@@ -660,12 +711,20 @@ class Logging(LiteLLMLoggingBaseClass):
            curl_command = str(self.model_call_details)
        return curl_command

-   def _get_masked_headers(self, headers: dict):
+   def _get_masked_headers(
+       self, headers: dict, ignore_sensitive_headers: bool = False
+   ) -> dict:
        """
        Internal debugging helper function

        Masks the headers of the request sent from LiteLLM
        """
+       sensitive_keywords = [
+           "authorization",
+           "token",
+           "key",
+           "secret",
+       ]
        return {
            k: (
                (v[:-44] + "*" * 44)

@@ -673,6 +732,11 @@ class Logging(LiteLLMLoggingBaseClass):
                else "*****"
            )
            for k, v in headers.items()
+           if not ignore_sensitive_headers
+           or not any(
+               sensitive_keyword in k.lower()
+               for sensitive_keyword in sensitive_keywords
+           )
        }

    def post_call(
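A hedged sketch of the new masking behavior: with `ignore_sensitive_headers=True` (the flag used when building the UI-facing `RawRequestTypedDict`), headers whose names contain keywords like `authorization` or `key` are dropped entirely instead of being star-masked. `logging_obj` and the values below are illustrative:

    headers = {
        "Authorization": "Bearer sk-1234567890",
        "x-api-key": "sk-abcdef",
        "Content-Type": "application/json",
    }

    # Default path: every header survives, but its value is masked with '*' characters.
    masked = logging_obj._get_masked_headers(headers)
    # e.g. {'Authorization': '*****', 'x-api-key': '*****', 'Content-Type': '*****'}

    # UI path: sensitive headers are filtered out before logging the raw request.
    safe = logging_obj._get_masked_headers(headers, ignore_sensitive_headers=True)
    # e.g. {'Content-Type': '*****'}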
@@ -790,6 +854,8 @@ class Logging(LiteLLMLoggingBaseClass):
            RerankResponse,
            Batch,
            FineTuningJob,
+           ResponsesAPIResponse,
+           ResponseCompletedEvent,
        ],
        cache_hit: Optional[bool] = None,
    ) -> Optional[float]:

@@ -871,6 +937,24 @@ class Logging(LiteLLMLoggingBaseClass):

        return None

+   async def _response_cost_calculator_async(
+       self,
+       result: Union[
+           ModelResponse,
+           ModelResponseStream,
+           EmbeddingResponse,
+           ImageResponse,
+           TranscriptionResponse,
+           TextCompletionResponse,
+           HttpxBinaryResponseContent,
+           RerankResponse,
+           Batch,
+           FineTuningJob,
+       ],
+       cache_hit: Optional[bool] = None,
+   ) -> Optional[float]:
+       return self._response_cost_calculator(result=result, cache_hit=cache_hit)
+
    def should_run_callback(
        self, callback: litellm.CALLBACK_TYPES, litellm_params: dict, event_hook: str
    ) -> bool:

@@ -912,13 +996,16 @@ class Logging(LiteLLMLoggingBaseClass):
        self.model_call_details["log_event_type"] = "successful_api_call"
        self.model_call_details["end_time"] = end_time
        self.model_call_details["cache_hit"] = cache_hit

+       if self.call_type == CallTypes.anthropic_messages.value:
+           result = self._handle_anthropic_messages_response_logging(result=result)
        ## if model in model cost map - log the response cost
        ## else set cost to None
        if (
            standard_logging_object is None
            and result is not None
            and self.stream is not True
-       ):  # handle streaming separately
+       ):
            if (
                isinstance(result, ModelResponse)
                or isinstance(result, ModelResponseStream)

@@ -928,8 +1015,9 @@ class Logging(LiteLLMLoggingBaseClass):
                or isinstance(result, TextCompletionResponse)
                or isinstance(result, HttpxBinaryResponseContent)  # tts
                or isinstance(result, RerankResponse)
-               or isinstance(result, Batch)
                or isinstance(result, FineTuningJob)
+               or isinstance(result, LiteLLMBatch)
+               or isinstance(result, ResponsesAPIResponse)
            ):
                ## HIDDEN PARAMS ##
                hidden_params = getattr(result, "_hidden_params", {})

@@ -1029,7 +1117,7 @@ class Logging(LiteLLMLoggingBaseClass):

        ## BUILD COMPLETE STREAMED RESPONSE
        complete_streaming_response: Optional[
-           Union[ModelResponse, TextCompletionResponse]
+           Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse]
        ] = None
        if "complete_streaming_response" in self.model_call_details:
            return  # break out of this.

@@ -1525,6 +1613,20 @@ class Logging(LiteLLMLoggingBaseClass):
        print_verbose(
            "Logging Details LiteLLM-Async Success Call, cache_hit={}".format(cache_hit)
        )

+       ## CALCULATE COST FOR BATCH JOBS
+       if self.call_type == CallTypes.aretrieve_batch.value and isinstance(
+           result, LiteLLMBatch
+       ):
+
+           response_cost, batch_usage, batch_models = await _handle_completed_batch(
+               batch=result, custom_llm_provider=self.custom_llm_provider
+           )
+
+           result._hidden_params["response_cost"] = response_cost
+           result._hidden_params["batch_models"] = batch_models
+           result.usage = batch_usage
+
        start_time, end_time, result = self._success_handler_helper_fn(
            start_time=start_time,
            end_time=end_time,

@@ -1532,11 +1634,12 @@ class Logging(LiteLLMLoggingBaseClass):
            cache_hit=cache_hit,
            standard_logging_object=kwargs.get("standard_logging_object", None),
        )

        ## BUILD COMPLETE STREAMED RESPONSE
        if "async_complete_streaming_response" in self.model_call_details:
            return  # break out of this.
        complete_streaming_response: Optional[
-           Union[ModelResponse, TextCompletionResponse]
+           Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse]
        ] = self._get_assembled_streaming_response(
            result=result,
            start_time=start_time,
@@ -2246,16 +2349,24 @@ class Logging(LiteLLMLoggingBaseClass):

    def _get_assembled_streaming_response(
        self,
-       result: Union[ModelResponse, TextCompletionResponse, ModelResponseStream, Any],
+       result: Union[
+           ModelResponse,
+           TextCompletionResponse,
+           ModelResponseStream,
+           ResponseCompletedEvent,
+           Any,
+       ],
        start_time: datetime.datetime,
        end_time: datetime.datetime,
        is_async: bool,
        streaming_chunks: List[Any],
-   ) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
+   ) -> Optional[Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse]]:
        if isinstance(result, ModelResponse):
            return result
        elif isinstance(result, TextCompletionResponse):
            return result
+       elif isinstance(result, ResponseCompletedEvent):
+           return result.response
        elif isinstance(result, ModelResponseStream):
            complete_streaming_response: Optional[
                Union[ModelResponse, TextCompletionResponse]

@@ -2270,6 +2381,37 @@ class Logging(LiteLLMLoggingBaseClass):
            return complete_streaming_response
        return None

+   def _handle_anthropic_messages_response_logging(self, result: Any) -> ModelResponse:
+       """
+       Handles logging for Anthropic messages responses.
+
+       Args:
+           result: The response object from the model call
+
+       Returns:
+           The response object from the model call
+
+       - For Non-streaming responses, we need to transform the response to a ModelResponse object.
+       - For streaming responses, anthropic_messages handler calls success_handler with an assembled ModelResponse.
+       """
+       if self.stream and isinstance(result, ModelResponse):
+           return result
+
+       result = litellm.AnthropicConfig().transform_response(
+           raw_response=self.model_call_details["httpx_response"],
+           model_response=litellm.ModelResponse(),
+           model=self.model,
+           messages=[],
+           logging_obj=self,
+           optional_params={},
+           api_key="",
+           request_data={},
+           encoding=litellm.encoding,
+           json_mode=False,
+           litellm_params={},
+       )
+       return result
+

def set_callbacks(callback_list, function_id=None):  # noqa: PLR0915
    """
@@ -2983,6 +3125,12 @@ class StandardLoggingPayloadSetup:
        elif isinstance(usage, Usage):
            return usage
        elif isinstance(usage, dict):
+           if ResponseAPILoggingUtils._is_response_api_usage(usage):
+               return (
+                   ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(
+                       usage
+                   )
+               )
            return Usage(**usage)

        raise ValueError(f"usage is required, got={usage} of type {type(usage)}")

@@ -3086,6 +3234,7 @@ class StandardLoggingPayloadSetup:
            response_cost=None,
            additional_headers=None,
            litellm_overhead_time_ms=None,
+           batch_models=None,
        )
        if hidden_params is not None:
            for key in StandardLoggingHiddenParams.__annotations__.keys():

@@ -3199,6 +3348,7 @@ def get_standard_logging_object_payload(
                api_base=None,
                response_cost=None,
                litellm_overhead_time_ms=None,
+               batch_models=None,
            )
        )

@@ -3483,6 +3633,7 @@ def create_dummy_standard_logging_payload() -> StandardLoggingPayload:
        response_cost=None,
        additional_headers=None,
        litellm_overhead_time_ms=None,
+       batch_models=None,
    )

    # Convert numeric values to appropriate types
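The usage hook now recognizes Responses API usage dicts and converts them to the chat-style `Usage` object before the standard logging payload is built. A hedged sketch of the idea; the helper names come from the diff, but the usage-dict field names and the resulting attributes are assumptions based on the OpenAI Responses API convention:

    from litellm.responses.utils import ResponseAPILoggingUtils

    responses_api_usage = {"input_tokens": 12, "output_tokens": 34, "total_tokens": 46}  # assumed shape

    if ResponseAPILoggingUtils._is_response_api_usage(responses_api_usage):
        usage = ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(
            responses_api_usage
        )
        # Expected (assumption): usage.prompt_tokens == 12, usage.completion_tokens == 34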
@@ -9,6 +9,7 @@ from typing import Dict, Iterable, List, Literal, Optional, Tuple, Union
import litellm
from litellm._logging import verbose_logger
from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
+from litellm.types.llms.openai import ChatCompletionThinkingBlock
from litellm.types.utils import (
    ChatCompletionDeltaToolCall,
    ChatCompletionMessageToolCall,

@@ -128,12 +129,7 @@ def convert_to_streaming_response(response_object: Optional[dict] = None):
    model_response_object = ModelResponse(stream=True)
    choice_list = []
    for idx, choice in enumerate(response_object["choices"]):
-       delta = Delta(
-           content=choice["message"].get("content", None),
-           role=choice["message"]["role"],
-           function_call=choice["message"].get("function_call", None),
-           tool_calls=choice["message"].get("tool_calls", None),
-       )
+       delta = Delta(**choice["message"])
        finish_reason = choice.get("finish_reason", None)
        if finish_reason is None:
            # gpt-4 vision can return 'finish_reason' or 'finish_details'

@@ -243,6 +239,24 @@ def _parse_content_for_reasoning(
    return None, message_text


+def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[str]]:
+   """
+   Extract reasoning content and main content from a message.
+
+   Args:
+       message (dict): The message dictionary that may contain reasoning_content
+
+   Returns:
+       tuple[Optional[str], Optional[str]]: A tuple of (reasoning_content, content)
+   """
+   if "reasoning_content" in message:
+       return message["reasoning_content"], message["content"]
+   elif "reasoning" in message:
+       return message["reasoning"], message["content"]
+   else:
+       return _parse_content_for_reasoning(message.get("content"))
+
+
class LiteLLMResponseObjectHandler:

    @staticmethod
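A short sketch of the three paths the new helper covers; the `<think>` form is the one `_parse_content_for_reasoning` already handled, the explicit keys are the new fast paths:

    # 1. provider returns an explicit `reasoning_content` field
    _extract_reasoning_content({"reasoning_content": "step 1...", "content": "42"})
    # -> ("step 1...", "42")

    # 2. provider returns a `reasoning` field instead
    _extract_reasoning_content({"reasoning": "step 1...", "content": "42"})
    # -> ("step 1...", "42")

    # 3. reasoning embedded in the content itself, delegated to _parse_content_for_reasoning
    _extract_reasoning_content({"content": "<think>step 1...</think>42"})
    # -> ("step 1...", "42")  (assuming the parser splits on <think> tags)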
@@ -456,11 +470,16 @@ def convert_to_model_response_object( # noqa: PLR0915
                    provider_specific_fields[field] = choice["message"][field]

                # Handle reasoning models that display `reasoning_content` within `content`
-               reasoning_content, content = _parse_content_for_reasoning(
-                   choice["message"].get("content")
+               reasoning_content, content = _extract_reasoning_content(
+                   choice["message"]
                )

+               # Handle thinking models that display `thinking_blocks` within `content`
+               thinking_blocks: Optional[List[ChatCompletionThinkingBlock]] = None
+               if "thinking_blocks" in choice["message"]:
+                   thinking_blocks = choice["message"]["thinking_blocks"]
+                   provider_specific_fields["thinking_blocks"] = thinking_blocks

                if reasoning_content:
                    provider_specific_fields["reasoning_content"] = (
                        reasoning_content

@@ -474,6 +493,7 @@ def convert_to_model_response_object( # noqa: PLR0915
                    audio=choice["message"].get("audio", None),
                    provider_specific_fields=provider_specific_fields,
                    reasoning_content=reasoning_content,
+                   thinking_blocks=thinking_blocks,
                )
                finish_reason = choice.get("finish_reason", None)
                if finish_reason is None:
@@ -187,53 +187,125 @@ def ollama_pt(
            final_prompt_value="### Response:",
            messages=messages,
        )
-   elif "llava" in model:
-       prompt = ""
-       images = []
-       for message in messages:
-           if isinstance(message["content"], str):
-               prompt += message["content"]
-           elif isinstance(message["content"], list):
-               # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
-               for element in message["content"]:
-                   if isinstance(element, dict):
-                       if element["type"] == "text":
-                           prompt += element["text"]
-                       elif element["type"] == "image_url":
-                           base64_image = convert_to_ollama_image(
-                               element["image_url"]["url"]
-                           )
-                           images.append(base64_image)
-       return {"prompt": prompt, "images": images}
    else:
+       user_message_types = {"user", "tool", "function"}
+       msg_i = 0
+       images = []
        prompt = ""
-       for message in messages:
-           role = message["role"]
-           content = message.get("content", "")
+       while msg_i < len(messages):
+           init_msg_i = msg_i
+           user_content_str = ""
+           ## MERGE CONSECUTIVE USER CONTENT ##
+           while (
+               msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
+           ):
+               msg_content = messages[msg_i].get("content")
+               if msg_content:
+                   if isinstance(msg_content, list):
+                       for m in msg_content:
+                           if m.get("type", "") == "image_url":
+                               if isinstance(m["image_url"], str):
+                                   images.append(m["image_url"])
+                               elif isinstance(m["image_url"], dict):
+                                   images.append(m["image_url"]["url"])
+                           elif m.get("type", "") == "text":
+                               user_content_str += m["text"]
+                   else:
+                       # Tool message content will always be a string
+                       user_content_str += msg_content

-           if "tool_calls" in message:
-               tool_calls = []
+               msg_i += 1

-               for call in message["tool_calls"]:
-                   call_id: str = call["id"]
-                   function_name: str = call["function"]["name"]
-                   arguments = json.loads(call["function"]["arguments"])
+           if user_content_str:
+               prompt += f"### User:\n{user_content_str}\n\n"

-               tool_calls.append(
-                   {
-                       "id": call_id,
-                       "type": "function",
-                       "function": {"name": function_name, "arguments": arguments},
-                   }
+           assistant_content_str = ""
+           ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
+           while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
+               msg_content = messages[msg_i].get("content")
+               if msg_content:
+                   if isinstance(msg_content, list):
+                       for m in msg_content:
+                           if m.get("type", "") == "text":
+                               assistant_content_str += m["text"]
+                   elif isinstance(msg_content, str):
+                       # Tool message content will always be a string
+                       assistant_content_str += msg_content

+               tool_calls = messages[msg_i].get("tool_calls")
+               ollama_tool_calls = []
+               if tool_calls:
+                   for call in tool_calls:
+                       call_id: str = call["id"]
+                       function_name: str = call["function"]["name"]
+                       arguments = json.loads(call["function"]["arguments"])

+                       ollama_tool_calls.append(
+                           {
+                               "id": call_id,
+                               "type": "function",
+                               "function": {
+                                   "name": function_name,
+                                   "arguments": arguments,
+                               },
+                           }
+                       )

+               if ollama_tool_calls:
+                   assistant_content_str += (
+                       f"Tool Calls: {json.dumps(ollama_tool_calls, indent=2)}"
+                   )

-               prompt += f"### Assistant:\nTool Calls: {json.dumps(tool_calls, indent=2)}\n\n"
+               msg_i += 1

-           elif "tool_call_id" in message:
-               prompt += f"### User:\n{message['content']}\n\n"
+           if assistant_content_str:
+               prompt += f"### Assistant:\n{assistant_content_str}\n\n"

-           elif content:
-               prompt += f"### {role.capitalize()}:\n{content}\n\n"
+           if msg_i == init_msg_i:  # prevent infinite loops
+               raise litellm.BadRequestError(
+                   message=BAD_MESSAGE_ERROR_STR + f"passed in {messages[msg_i]}",
+                   model=model,
+                   llm_provider="ollama",
+               )
+       # prompt = ""
+       # images = []
+       # for message in messages:
+       #     if isinstance(message["content"], str):
+       #         prompt += message["content"]
+       #     elif isinstance(message["content"], list):
+       #         # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
+       #         for element in message["content"]:
+       #             if isinstance(element, dict):
+       #                 if element["type"] == "text":
+       #                     prompt += element["text"]
+       #                 elif element["type"] == "image_url":
+       #                     base64_image = convert_to_ollama_image(
+       #                         element["image_url"]["url"]
+       #                     )
+       #                     images.append(base64_image)

+       # if "tool_calls" in message:
+       #     tool_calls = []

+       #     for call in message["tool_calls"]:
+       #         call_id: str = call["id"]
+       #         function_name: str = call["function"]["name"]
+       #         arguments = json.loads(call["function"]["arguments"])

+       #         tool_calls.append(
+       #             {
+       #                 "id": call_id,
+       #                 "type": "function",
+       #                 "function": {"name": function_name, "arguments": arguments},
+       #             }
+       #         )

+       #     prompt += f"### Assistant:\nTool Calls: {json.dumps(tool_calls, indent=2)}\n\n"

+       # elif "tool_call_id" in message:
+       #     prompt += f"### User:\n{message['content']}\n\n"

+       return {"prompt": prompt, "images": images}

    return prompt
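The rewritten fallback branch mirrors the Anthropic prompt builder: consecutive user/tool messages are merged into one `### User:` block, consecutive assistant messages (including tool calls) into one `### Assistant:` block, and images are collected separately. A hedged sketch of the resulting shape:

    messages = [
        {"role": "user", "content": "What's the weather in SF?"},
        {"role": "assistant", "tool_calls": [{"id": "call_1", "function": {"name": "get_weather", "arguments": '{"city": "SF"}'}}]},
        {"role": "tool", "content": "72F and sunny"},
        {"role": "assistant", "content": "It's 72F and sunny in SF."},
    ]

    out = ollama_pt(model="llama3", messages=messages)
    # Roughly (exact whitespace / JSON formatting may differ):
    # out["prompt"] ==
    #   "### User:\nWhat's the weather in SF?\n\n"
    #   "### Assistant:\nTool Calls: [...]\n\n"
    #   "### User:\n72F and sunny\n\n"
    #   "### Assistant:\nIt's 72F and sunny in SF.\n\n"
    # out["images"] == []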
@@ -680,12 +752,13 @@ def convert_generic_image_chunk_to_openai_image_obj(
    Return:
    "data:image/jpeg;base64,{base64_image}"
    """
-   return "data:{};{},{}".format(
-       image_chunk["media_type"], image_chunk["type"], image_chunk["data"]
-   )
+   media_type = image_chunk["media_type"]
+   return "data:{};{},{}".format(media_type, image_chunk["type"], image_chunk["data"])


-def convert_to_anthropic_image_obj(openai_image_url: str) -> GenericImageParsingChunk:
+def convert_to_anthropic_image_obj(
+   openai_image_url: str, format: Optional[str]
+) -> GenericImageParsingChunk:
    """
    Input:
    "image_url": "data:image/jpeg;base64,{base64_image}",

@@ -702,7 +775,11 @@ def convert_to_anthropic_image_obj(openai_image_url: str) -> GenericImageParsing
        openai_image_url = convert_url_to_base64(url=openai_image_url)
    # Extract the media type and base64 data
    media_type, base64_data = openai_image_url.split("data:")[1].split(";base64,")
-   media_type = media_type.replace("\\/", "/")
+
+   if format:
+       media_type = format
+   else:
+       media_type = media_type.replace("\\/", "/")

    return GenericImageParsingChunk(
        type="base64",
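The new `format` parameter lets callers pin the media type instead of trusting whatever the data URL claims; the same override is added to the Bedrock image processors further down. A hedged sketch (payload truncated, expected values are assumptions):

    data_url = "data:image/jpeg;base64,iVBORw0KGgo..."  # illustrative payload

    # Without an explicit format, the media type is parsed out of the data URL.
    chunk = convert_to_anthropic_image_obj(data_url, format=None)
    # chunk["media_type"] == "image/jpeg"

    # With format, the user-supplied value wins (useful when the data URL is mislabeled).
    chunk = convert_to_anthropic_image_obj(data_url, format="image/png")
    # chunk["media_type"] == "image/png"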
@@ -820,11 +897,12 @@ def anthropic_messages_pt_xml(messages: list):
                if isinstance(messages[msg_i]["content"], list):
                    for m in messages[msg_i]["content"]:
                        if m.get("type", "") == "image_url":
+                           format = m["image_url"].get("format")
                            user_content.append(
                                {
                                    "type": "image",
                                    "source": convert_to_anthropic_image_obj(
-                                       m["image_url"]["url"]
+                                       m["image_url"]["url"], format=format
                                    ),
                                }
                            )

@@ -1156,10 +1234,13 @@ def convert_to_anthropic_tool_result(
                )
            elif content["type"] == "image_url":
                if isinstance(content["image_url"], str):
-                   image_chunk = convert_to_anthropic_image_obj(content["image_url"])
-               else:
                    image_chunk = convert_to_anthropic_image_obj(
-                       content["image_url"]["url"]
+                       content["image_url"], format=None
                    )
+               else:
+                   format = content["image_url"].get("format")
+                   image_chunk = convert_to_anthropic_image_obj(
+                       content["image_url"]["url"], format=format
+                   )
                anthropic_content_list.append(
                    AnthropicMessagesImageParam(

@@ -1282,6 +1363,7 @@ def add_cache_control_to_content(
        AnthropicMessagesImageParam,
        AnthropicMessagesTextParam,
        AnthropicMessagesDocumentParam,
+       ChatCompletionThinkingBlock,
    ],
    orignal_content_element: Union[dict, AllMessageValues],
):

@@ -1317,6 +1399,7 @@ def _anthropic_content_element_factory(
            data=image_chunk["data"],
        ),
    )

    return _anthropic_content_element


@@ -1368,13 +1451,16 @@ def anthropic_messages_pt( # noqa: PLR0915
                    for m in user_message_types_block["content"]:
                        if m.get("type", "") == "image_url":
                            m = cast(ChatCompletionImageObject, m)
+                           format: Optional[str] = None
                            if isinstance(m["image_url"], str):
                                image_chunk = convert_to_anthropic_image_obj(
-                                   openai_image_url=m["image_url"]
+                                   openai_image_url=m["image_url"], format=None
                                )
                            else:
+                               format = m["image_url"].get("format")
                                image_chunk = convert_to_anthropic_image_obj(
-                                   openai_image_url=m["image_url"]["url"]
+                                   openai_image_url=m["image_url"]["url"],
+                                   format=format,
                                )

                            _anthropic_content_element = (

@@ -1454,12 +1540,23 @@ def anthropic_messages_pt( # noqa: PLR0915
                    assistant_content_block["content"], list
                ):
                    for m in assistant_content_block["content"]:
-                       # handle text
+                       # handle thinking blocks
+                       thinking_block = cast(str, m.get("thinking", ""))
+                       text_block = cast(str, m.get("text", ""))
                        if (
-                           m.get("type", "") == "text" and len(m.get("text", "")) > 0
+                           m.get("type", "") == "thinking" and len(thinking_block) > 0
                        ):  # don't pass empty text blocks. anthropic api raises errors.
+                           anthropic_message: Union[
+                               ChatCompletionThinkingBlock,
+                               AnthropicMessagesTextParam,
+                           ] = cast(ChatCompletionThinkingBlock, m)
+                           assistant_content.append(anthropic_message)
+                       # handle text
+                       elif (
+                           m.get("type", "") == "text" and len(text_block) > 0
+                       ):  # don't pass empty text blocks. anthropic api raises errors.
                            anthropic_message = AnthropicMessagesTextParam(
-                               type="text", text=m.get("text")
+                               type="text", text=text_block
                            )
                            _cached_message = add_cache_control_to_content(
                                anthropic_content_element=anthropic_message,

@@ -1512,6 +1609,7 @@ def anthropic_messages_pt( # noqa: PLR0915
                msg_i += 1

        if assistant_content:

            new_messages.append({"role": "assistant", "content": assistant_content})

        if msg_i == init_msg_i:  # prevent infinite loops

@@ -1520,17 +1618,6 @@ def anthropic_messages_pt( # noqa: PLR0915
                model=model,
                llm_provider=llm_provider,
            )
-       if not new_messages or new_messages[0]["role"] != "user":
-           if litellm.modify_params:
-               new_messages.insert(
-                   0, {"role": "user", "content": [{"type": "text", "text": "."}]}
-               )
-           else:
-               raise Exception(
-                   "Invalid first message={}. Should always start with 'role'='user' for Anthropic. System prompt is sent separately for Anthropic. set 'litellm.modify_params = True' or 'litellm_settings:modify_params = True' on proxy, to insert a placeholder user message - '.' as the first message, ".format(
-                       new_messages
-                   )
-               )

    if new_messages[-1]["role"] == "assistant":
        if isinstance(new_messages[-1]["content"], str):
@@ -2301,8 +2388,11 @@ class BedrockImageProcessor:
        )

    @classmethod
-   def process_image_sync(cls, image_url: str) -> BedrockContentBlock:
+   def process_image_sync(
+       cls, image_url: str, format: Optional[str] = None
+   ) -> BedrockContentBlock:
        """Synchronous image processing."""

        if "base64" in image_url:
            img_bytes, mime_type, image_format = cls._parse_base64_image(image_url)
        elif "http://" in image_url or "https://" in image_url:

@@ -2313,11 +2403,17 @@ class BedrockImageProcessor:
                "Unsupported image type. Expected either image url or base64 encoded string"
            )

+       if format:
+           mime_type = format
+           image_format = mime_type.split("/")[1]
+
        image_format = cls._validate_format(mime_type, image_format)
        return cls._create_bedrock_block(img_bytes, mime_type, image_format)

    @classmethod
-   async def process_image_async(cls, image_url: str) -> BedrockContentBlock:
+   async def process_image_async(
+       cls, image_url: str, format: Optional[str]
+   ) -> BedrockContentBlock:
        """Asynchronous image processing."""

        if "base64" in image_url:

@@ -2332,6 +2428,10 @@ class BedrockImageProcessor:
                "Unsupported image type. Expected either image url or base64 encoded string"
            )

+       if format:  # override with user-defined params
+           mime_type = format
+           image_format = mime_type.split("/")[1]
+
        image_format = cls._validate_format(mime_type, image_format)
        return cls._create_bedrock_block(img_bytes, mime_type, image_format)
@@ -2819,12 +2919,14 @@ class BedrockConverseMessagesProcessor:
                            _part = BedrockContentBlock(text=element["text"])
                            _parts.append(_part)
                        elif element["type"] == "image_url":
+                           format: Optional[str] = None
                            if isinstance(element["image_url"], dict):
                                image_url = element["image_url"]["url"]
+                               format = element["image_url"].get("format")
                            else:
                                image_url = element["image_url"]
                            _part = await BedrockImageProcessor.process_image_async(  # type: ignore
-                               image_url=image_url
+                               image_url=image_url, format=format
                            )
                            _parts.append(_part)  # type: ignore
                        _cache_point_block = (

@@ -2924,7 +3026,14 @@ class BedrockConverseMessagesProcessor:
                        assistants_parts: List[BedrockContentBlock] = []
                        for element in _assistant_content:
                            if isinstance(element, dict):
-                               if element["type"] == "text":
+                               if element["type"] == "thinking":
+                                   thinking_block = BedrockConverseMessagesProcessor.translate_thinking_blocks_to_reasoning_content_blocks(
+                                       thinking_blocks=[
+                                           cast(ChatCompletionThinkingBlock, element)
+                                       ]
+                                   )
+                                   assistants_parts.extend(thinking_block)
+                               elif element["type"] == "text":
                                    assistants_part = BedrockContentBlock(
                                        text=element["text"]
                                    )

@@ -2974,7 +3083,7 @@ class BedrockConverseMessagesProcessor:
        reasoning_content_blocks: List[BedrockContentBlock] = []
        for thinking_block in thinking_blocks:
            reasoning_text = thinking_block.get("thinking")
-           reasoning_signature = thinking_block.get("signature_delta")
+           reasoning_signature = thinking_block.get("signature")
            text_block = BedrockConverseReasoningTextBlock(
                text=reasoning_text or "",
            )

@@ -3050,12 +3159,15 @@ def _bedrock_converse_messages_pt( # noqa: PLR0915
                        _part = BedrockContentBlock(text=element["text"])
                        _parts.append(_part)
                    elif element["type"] == "image_url":
+                       format: Optional[str] = None
                        if isinstance(element["image_url"], dict):
                            image_url = element["image_url"]["url"]
+                           format = element["image_url"].get("format")
                        else:
                            image_url = element["image_url"]
                        _part = BedrockImageProcessor.process_image_sync(  # type: ignore
-                           image_url=image_url
+                           image_url=image_url,
+                           format=format,
                        )
                        _parts.append(_part)  # type: ignore
                    _cache_point_block = (

@@ -3157,7 +3269,14 @@ def _bedrock_converse_messages_pt( # noqa: PLR0915
                    assistants_parts: List[BedrockContentBlock] = []
                    for element in _assistant_content:
                        if isinstance(element, dict):
-                           if element["type"] == "text":
+                           if element["type"] == "thinking":
+                               thinking_block = BedrockConverseMessagesProcessor.translate_thinking_blocks_to_reasoning_content_blocks(
+                                   thinking_blocks=[
+                                       cast(ChatCompletionThinkingBlock, element)
+                                   ]
+                               )
+                               assistants_parts.extend(thinking_block)
+                           elif element["type"] == "text":
                                assistants_part = BedrockContentBlock(text=element["text"])
                                assistants_parts.append(assistants_part)
                            elif element["type"] == "image_url":
@@ -15,6 +15,7 @@ from litellm import verbose_logger
from litellm.litellm_core_utils.redact_messages import LiteLLMLoggingObject
from litellm.litellm_core_utils.thread_pool_executor import executor
from litellm.types.llms.openai import ChatCompletionChunk
+from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import Delta
from litellm.types.utils import GenericStreamingChunk as GChunk
from litellm.types.utils import (

@@ -70,6 +71,17 @@ class CustomStreamWrapper:
        self.completion_stream = completion_stream
        self.sent_first_chunk = False
        self.sent_last_chunk = False

+       litellm_params: GenericLiteLLMParams = GenericLiteLLMParams(
+           **self.logging_obj.model_call_details.get("litellm_params", {})
+       )
+       self.merge_reasoning_content_in_choices: bool = (
+           litellm_params.merge_reasoning_content_in_choices or False
+       )
+       self.sent_first_thinking_block = False
+       self.sent_last_thinking_block = False
+       self.thinking_content = ""

        self.system_fingerprint: Optional[str] = None
        self.received_finish_reason: Optional[str] = None
        self.intermittent_finish_reason: Optional[str] = (

@@ -87,12 +99,7 @@ class CustomStreamWrapper:
        self.holding_chunk = ""
        self.complete_response = ""
        self.response_uptil_now = ""
-       _model_info = (
-           self.logging_obj.model_call_details.get("litellm_params", {}).get(
-               "model_info", {}
-           )
-           or {}
-       )
+       _model_info: Dict = litellm_params.model_info or {}

        _api_base = get_api_base(
            model=model or "",

@@ -630,7 +637,10 @@ class CustomStreamWrapper:
            if isinstance(chunk, bytes):
                chunk = chunk.decode("utf-8")
            if "text_output" in chunk:
-               response = chunk.replace("data: ", "").strip()
+               response = (
+                   CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
+               )
+               response = response.strip()
                parsed_response = json.loads(response)
            else:
                return {

@@ -873,6 +883,10 @@ class CustomStreamWrapper:
                _index: Optional[int] = completion_obj.get("index")
                if _index is not None:
                    model_response.choices[0].index = _index

+               self._optional_combine_thinking_block_in_choices(
+                   model_response=model_response
+               )
                print_verbose(f"returning model_response: {model_response}")
                return model_response
            else:

@@ -929,6 +943,48 @@ class CustomStreamWrapper:
            self.chunks.append(model_response)
            return

+   def _optional_combine_thinking_block_in_choices(
+       self, model_response: ModelResponseStream
+   ) -> None:
+       """
+       UI's Like OpenWebUI expect to get 1 chunk with <think>...</think> tags in the chunk content
+
+       In place updates the model_response object with reasoning_content in content with <think>...</think> tags
+
+       Enabled when `merge_reasoning_content_in_choices=True` passed in request params
+       """
+       if self.merge_reasoning_content_in_choices is True:
+           reasoning_content = getattr(
+               model_response.choices[0].delta, "reasoning_content", None
+           )
+           if reasoning_content:
+               if self.sent_first_thinking_block is False:
+                   model_response.choices[0].delta.content += (
+                       "<think>" + reasoning_content
+                   )
+                   self.sent_first_thinking_block = True
+               elif (
+                   self.sent_first_thinking_block is True
+                   and hasattr(model_response.choices[0].delta, "reasoning_content")
+                   and model_response.choices[0].delta.reasoning_content
+               ):
+                   model_response.choices[0].delta.content = reasoning_content
+           elif (
+               self.sent_first_thinking_block is True
+               and not self.sent_last_thinking_block
+               and model_response.choices[0].delta.content
+           ):
+               model_response.choices[0].delta.content = (
+                   "</think>" + model_response.choices[0].delta.content
+               )
+               self.sent_last_thinking_block = True
+
+           if hasattr(model_response.choices[0].delta, "reasoning_content"):
+               del model_response.choices[0].delta.reasoning_content
+       return
+
    def chunk_creator(self, chunk: Any):  # type: ignore  # noqa: PLR0915
        model_response = self.model_response_creator()
        response_obj: Dict[str, Any] = {}
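A hedged sketch of the end-to-end behavior: with `merge_reasoning_content_in_choices=True`, the wrapper rewrites each streamed delta so reasoning arrives inside `<think>...</think>` in `content`, and the separate `reasoning_content` field is dropped. The request below is illustrative; the model name and exact chunk text are assumptions:

    import litellm

    stream = litellm.completion(
        model="deepseek/deepseek-reasoner",          # any reasoning-capable model
        messages=[{"role": "user", "content": "Why is the sky blue?"}],
        merge_reasoning_content_in_choices=True,     # new request param plumbed via litellm_params
        stream=True,
    )

    for chunk in stream:
        delta = chunk.choices[0].delta
        # First reasoning chunk:   content starts with "<think>..."
        # Later reasoning chunks:  content carries the reasoning text
        # First answer chunk:      content starts with "</think>..." then the answer
        print(delta.content or "", end="")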
@@ -1775,6 +1831,42 @@ class CustomStreamWrapper:
                extra_kwargs={},
            )

+   @staticmethod
+   def _strip_sse_data_from_chunk(chunk: Optional[str]) -> Optional[str]:
+       """
+       Strips the 'data: ' prefix from Server-Sent Events (SSE) chunks.
+
+       Some providers like sagemaker send it as `data:`, need to handle both
+
+       SSE messages are prefixed with 'data: ' which is part of the protocol,
+       not the actual content from the LLM. This method removes that prefix
+       and returns the actual content.
+
+       Args:
+           chunk: The SSE chunk that may contain the 'data: ' prefix (string or bytes)
+
+       Returns:
+           The chunk with the 'data: ' prefix removed, or the original chunk
+           if no prefix was found. Returns None if input is None.
+
+       See OpenAI Python Ref for this: https://github.com/openai/openai-python/blob/041bf5a8ec54da19aad0169671793c2078bd6173/openai/api_requestor.py#L100
+       """
+       if chunk is None:
+           return None
+
+       if isinstance(chunk, str):
+           # OpenAI sends `data: `
+           if chunk.startswith("data: "):
+               # Strip the prefix and any leading whitespace that might follow it
+               _length_of_sse_data_prefix = len("data: ")
+               return chunk[_length_of_sse_data_prefix:]
+           elif chunk.startswith("data:"):
+               # Sagemaker sends `data:`, no trailing whitespace
+               _length_of_sse_data_prefix = len("data:")
+               return chunk[_length_of_sse_data_prefix:]
+
+       return chunk
+

def calculate_total_usage(chunks: List[ModelResponse]) -> Usage:
    """Assume most recent usage chunk has total usage uptil then."""
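A quick illustration of the new helper's behavior (values are illustrative):

    assert CustomStreamWrapper._strip_sse_data_from_chunk('data: {"text_output": "hi"}') == '{"text_output": "hi"}'
    assert CustomStreamWrapper._strip_sse_data_from_chunk('data:{"text_output": "hi"}') == '{"text_output": "hi"}'
    assert CustomStreamWrapper._strip_sse_data_from_chunk("no prefix here") == "no prefix here"
    assert CustomStreamWrapper._strip_sse_data_from_chunk(None) is None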