Merge branch 'BerriAI:main' into main

Sunny Wan authored on 2025-03-13 19:37:22 -04:00, committed by GitHub
commit e01d12b878
317 changed files with 15980 additions and 5207 deletions


@@ -66,6 +66,7 @@ from litellm.litellm_core_utils.core_helpers import (
map_finish_reason,
process_response_headers,
)
from litellm.litellm_core_utils.credential_accessor import CredentialAccessor
from litellm.litellm_core_utils.default_encoding import encoding
from litellm.litellm_core_utils.exception_mapping_utils import (
_get_response_headers,
@@ -141,6 +142,7 @@ from litellm.types.utils import (
ChatCompletionMessageToolCall,
Choices,
CostPerToken,
CredentialItem,
CustomHuggingfaceTokenizer,
Delta,
Embedding,
@@ -156,6 +158,7 @@ from litellm.types.utils import (
ModelResponseStream,
ProviderField,
ProviderSpecificModelInfo,
RawRequestTypedDict,
SelectTokenizerResponse,
StreamingChoices,
TextChoices,
@@ -191,6 +194,9 @@ from typing import (
from openai import OpenAIError as OriginalError
from litellm.litellm_core_utils.thread_pool_executor import executor
from litellm.llms.base_llm.anthropic_messages.transformation import (
BaseAnthropicMessagesConfig,
)
from litellm.llms.base_llm.audio_transcription.transformation import (
BaseAudioTranscriptionConfig,
)
@@ -205,6 +211,7 @@ from litellm.llms.base_llm.image_variations.transformation import (
BaseImageVariationConfig,
)
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
from ._logging import _is_debugging_on, verbose_logger
from .caching.caching import (
@@ -451,6 +458,18 @@ def get_applied_guardrails(kwargs: Dict[str, Any]) -> List[str]:
return applied_guardrails
def load_credentials_from_list(kwargs: dict):
"""
Updates kwargs in place with the stored credential values when `litellm_credential_name` is present in kwargs.
"""
credential_name = kwargs.get("litellm_credential_name")
if credential_name and litellm.credential_list:
credential_accessor = CredentialAccessor.get_credential_values(credential_name)
for key, value in credential_accessor.items():
if key not in kwargs:
kwargs[key] = value
def get_dynamic_callbacks(
dynamic_callbacks: Optional[List[Union[str, Callable, CustomLogger]]]
) -> List:
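For context, a minimal usage sketch of the new credential loading: the credential name and the CredentialItem field names below are assumptions for illustration, not something this diff confirms.

import litellm
from litellm.types.utils import CredentialItem

# Register a named credential once; the values dict (assumed field name:
# credential_values) holds the keys load_credentials_from_list() copies into kwargs.
litellm.credential_list = [
    CredentialItem(
        credential_name="my-openai-creds",        # hypothetical name
        credential_values={"api_key": "sk-..."},  # injected only if absent from kwargs
        credential_info={},
    )
]

# Any client()-wrapped call can then reference the credential by name instead
# of passing api_key inline; existing kwargs are never overwritten.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    litellm_credential_name="my-openai-creds",
)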
@@ -711,6 +730,11 @@ def function_setup( # noqa: PLR0915
call_type == CallTypes.aspeech.value or call_type == CallTypes.speech.value
):
messages = kwargs.get("input", "speech")
elif (
call_type == CallTypes.aresponses.value
or call_type == CallTypes.responses.value
):
messages = args[0] if len(args) > 0 else kwargs["input"]
else:
messages = "default-message-value"
stream = True if "stream" in kwargs and kwargs["stream"] is True else False
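A hedged sketch of the call shape that new branch handles, assuming litellm exposes responses()/aresponses() mirroring the OpenAI Responses API (this diff only references them via CallTypes, so the exact signature is an assumption):

import litellm

# For the Responses API call types, function_setup() uses the `input`
# argument (positional or keyword) as the "messages" value for logging.
resp = litellm.responses(
    model="gpt-4o",          # illustrative model name
    input="Tell me a joke",  # picked up via kwargs["input"] above
)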
@@ -979,6 +1003,8 @@ def client(original_function): # noqa: PLR0915
logging_obj, kwargs = function_setup(
original_function.__name__, rules_obj, start_time, *args, **kwargs
)
## LOAD CREDENTIALS
load_credentials_from_list(kwargs)
kwargs["litellm_logging_obj"] = logging_obj
_llm_caching_handler: LLMCachingHandler = LLMCachingHandler(
original_function=original_function,
@@ -1048,6 +1074,7 @@ def client(original_function): # noqa: PLR0915
)
if caching_handler_response.cached_result is not None:
verbose_logger.debug("Cache hit!")
return caching_handler_response.cached_result
# CHECK MAX TOKENS
@@ -1234,6 +1261,8 @@ def client(original_function): # noqa: PLR0915
original_function.__name__, rules_obj, start_time, *args, **kwargs
)
kwargs["litellm_logging_obj"] = logging_obj
## LOAD CREDENTIALS
load_credentials_from_list(kwargs)
logging_obj._llm_caching_handler = _llm_caching_handler
# [OPTIONAL] CHECK BUDGET
if litellm.max_budget:
@@ -2422,6 +2451,7 @@ def get_optional_params_image_gen(
config_class = (
litellm.AmazonStability3Config
if litellm.AmazonStability3Config._is_stability_3_model(model=model)
else litellm.AmazonNovaCanvasConfig if litellm.AmazonNovaCanvasConfig._is_nova_model(model=model)
else litellm.AmazonStabilityConfig
)
supported_params = config_class.get_supported_openai_params(model=model)
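The chained conditional above is easier to read spelled out; a behavior-equivalent sketch of the selection, using only names that appear in the diff:

# Equivalent if/elif form of the config_class selection above:
if litellm.AmazonStability3Config._is_stability_3_model(model=model):
    config_class = litellm.AmazonStability3Config
elif litellm.AmazonNovaCanvasConfig._is_nova_model(model=model):
    config_class = litellm.AmazonNovaCanvasConfig
else:
    config_class = litellm.AmazonStabilityConfig
supported_params = config_class.get_supported_openai_params(model=model)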
@@ -4403,6 +4433,12 @@ def _get_model_info_helper( # noqa: PLR0915
input_cost_per_audio_token=_model_info.get(
"input_cost_per_audio_token", None
),
input_cost_per_token_batches=_model_info.get(
"input_cost_per_token_batches"
),
output_cost_per_token_batches=_model_info.get(
"output_cost_per_token_batches"
),
output_cost_per_token=_output_cost_per_token,
output_cost_per_audio_token=_model_info.get(
"output_cost_per_audio_token", None
@@ -5092,7 +5128,7 @@ def prompt_token_calculator(model, messages):
from anthropic import AI_PROMPT, HUMAN_PROMPT, Anthropic
anthropic_obj = Anthropic()
num_tokens = anthropic_obj.count_tokens(text)
num_tokens = anthropic_obj.count_tokens(text) # type: ignore
else:
num_tokens = len(encoding.encode(text))
return num_tokens
@@ -6246,6 +6282,15 @@ class ProviderConfigManager:
return litellm.JinaAIRerankConfig()
return litellm.CohereRerankConfig()
@staticmethod
def get_provider_anthropic_messages_config(
model: str,
provider: LlmProviders,
) -> Optional[BaseAnthropicMessagesConfig]:
if litellm.LlmProviders.ANTHROPIC == provider:
return litellm.AnthropicMessagesConfig()
return None
@staticmethod
def get_provider_audio_transcription_config(
model: str,
@@ -6257,6 +6302,15 @@ class ProviderConfigManager:
return litellm.DeepgramAudioTranscriptionConfig()
return None
@staticmethod
def get_provider_responses_api_config(
model: str,
provider: LlmProviders,
) -> Optional[BaseResponsesAPIConfig]:
if litellm.LlmProviders.OPENAI == provider:
return litellm.OpenAIResponsesAPIConfig()
return None
@staticmethod
def get_provider_text_completion_config(
model: str,
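A hedged usage sketch of the two new ProviderConfigManager lookups added in this diff, assuming the file being patched is litellm/utils.py (where ProviderConfigManager is defined) and using illustrative model names:

import litellm
from litellm.utils import ProviderConfigManager

# Anthropic Messages API config (returns None for other providers).
anthropic_cfg = ProviderConfigManager.get_provider_anthropic_messages_config(
    model="claude-3-5-sonnet",               # illustrative
    provider=litellm.LlmProviders.ANTHROPIC,
)

# OpenAI Responses API config (returns None for other providers).
responses_cfg = ProviderConfigManager.get_provider_responses_api_config(
    model="gpt-4o",                          # illustrative
    provider=litellm.LlmProviders.OPENAI,
)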
@@ -6466,3 +6520,48 @@ def add_openai_metadata(metadata: dict) -> dict:
}
return visible_metadata.copy()
def return_raw_request(endpoint: CallTypes, kwargs: dict) -> RawRequestTypedDict:
"""
Return the raw request (as a RawRequestTypedDict) that litellm would send for the given endpoint.
This is currently in BETA, and tested for `/chat/completions` -> `litellm.completion` calls.
"""
from datetime import datetime
from litellm.litellm_core_utils.litellm_logging import Logging
litellm_logging_obj = Logging(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "hi"}],
stream=False,
call_type="acompletion",
litellm_call_id="1234",
start_time=datetime.now(),
function_id="1234",
log_raw_request_response=True,
)
llm_api_endpoint = getattr(litellm, endpoint.value)
received_exception = ""
try:
llm_api_endpoint(
**kwargs,
litellm_logging_obj=litellm_logging_obj,
api_key="my-fake-api-key", # 👈 ensure the request fails
)
except Exception as e:
received_exception = str(e)
raw_request_typed_dict = litellm_logging_obj.model_call_details.get(
"raw_request_typed_dict"
)
if raw_request_typed_dict:
return cast(RawRequestTypedDict, raw_request_typed_dict)
else:
return RawRequestTypedDict(
error=received_exception,
)
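Finally, a minimal sketch of how the new return_raw_request helper might be called, assuming the file being patched is litellm/utils.py; per the docstring, only the litellm.completion path is tested:

import litellm
from litellm.types.utils import CallTypes
from litellm.utils import return_raw_request

# Inspect the exact request litellm would send for a /chat/completions call
# without reaching the provider: the fake api_key makes the call fail after
# the raw request has been captured on the logging object.
raw = return_raw_request(
    endpoint=CallTypes.completion,
    kwargs={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "hi"}],
    },
)
print(raw)  # RawRequestTypedDict with the raw request details, or an error string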