build(pyproject.toml): add new dev dependencies - for type checking (#9631)

* build(pyproject.toml): add new dev dependencies - for type checking

* build: reformat files to fit black

* ci: reformat to fit black

* ci(test-litellm.yml): make tests run clear

* build(pyproject.toml): add ruff

* fix: fix ruff checks

* build(mypy/): fix mypy linting errors

* fix(hashicorp_secret_manager.py): fix passing cert for tls auth

* build(mypy/): resolve all mypy errors

* test: update test

* fix: fix black formatting

* build(pre-commit-config.yaml): use poetry run black

* fix(proxy_server.py): fix linting error

* fix: fix ruff safe representation error
Krish Dholakia, 2025-03-29 11:02:13 -07:00 (committed by GitHub)
parent 95e5dfae5a
commit 9b7ebb6a7d
214 changed files with 1553 additions and 1433 deletions

.github/workflows/test-linting.yml (new file, 53 lines added)

@@ -0,0 +1,53 @@
name: LiteLLM Linting
on:
  pull_request:
    branches: [ main ]
jobs:
  lint:
    runs-on: ubuntu-latest
    timeout-minutes: 5
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.12'
      - name: Install Poetry
        uses: snok/install-poetry@v1
      - name: Install dependencies
        run: |
          poetry install --with dev
      - name: Run Black formatting check
        run: |
          cd litellm
          poetry run black . --check
          cd ..
      - name: Run Ruff linting
        run: |
          cd litellm
          poetry run ruff check .
          cd ..
      - name: Run MyPy type checking
        run: |
          cd litellm
          poetry run mypy . --ignore-missing-imports
          cd ..
      - name: Check for circular imports
        run: |
          cd litellm
          poetry run python ../tests/documentation_tests/test_circular_imports.py
          cd ..
      - name: Check import safety
        run: |
          poetry run python -c "from litellm import *" || (echo '🚨 import failed, this means you introduced unprotected imports! 🚨'; exit 1)
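The last step works because `from litellm import *` executes every module-level import in the package, so an optional dependency imported without a guard fails the job immediately. A minimal standalone sketch of the same idea (hypothetical file name, not part of this commit; the workflow runs the one-liner inline instead):

# check_import_safety.py - hedged sketch of the "Check import safety" step above
import sys

try:
    # the star-import executes all module-level imports in litellm;
    # an unguarded optional dependency raises right here
    from litellm import *  # noqa: F401,F403
except Exception as exc:
    print(f"import failed, an unprotected import was introduced: {exc}")
    sys.exit(1)
print("import safety check passed")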

View file

@@ -1,4 +1,4 @@
-name: LiteLLM Tests
+name: LiteLLM Mock Tests (folder - tests/litellm)
 on:
   pull_request:

View file

@@ -14,10 +14,12 @@ repos:
         types: [python]
         files: litellm/.*\.py
         exclude: ^litellm/__init__.py$
-  - repo: https://github.com/psf/black
-    rev: 24.2.0
-    hooks:
       - id: black
+        name: black
+        entry: poetry run black
+        language: system
+        types: [python]
+        files: litellm/.*\.py
   - repo: https://github.com/pycqa/flake8
     rev: 7.0.0 # The version of flake8 to use
     hooks:

View file

@@ -444,9 +444,7 @@ class _ENTERPRISE_SecretDetection(CustomGuardrail):
         detected_secrets = []
         for file in secrets.files:
             for found_secret in secrets[file]:
                 if found_secret.secret_value is None:
                     continue
                 detected_secrets.append(
@@ -471,14 +469,12 @@ class _ENTERPRISE_SecretDetection(CustomGuardrail):
         data: dict,
         call_type: str, # "completion", "embeddings", "image_generation", "moderation"
     ):
         if await self.should_run_check(user_api_key_dict) is False:
             return
         if "messages" in data and isinstance(data["messages"], list):
             for message in data["messages"]:
                 if "content" in message and isinstance(message["content"], str):
                     detected_secrets = self.scan_message_for_secrets(message["content"])
                     for secret in detected_secrets:

View file

@@ -122,19 +122,19 @@ langsmith_batch_size: Optional[int] = None
 prometheus_initialize_budget_metrics: Optional[bool] = False
 argilla_batch_size: Optional[int] = None
 datadog_use_v1: Optional[bool] = False # if you want to use v1 datadog logged payload
-gcs_pub_sub_use_v1: Optional[bool] = (
-    False # if you want to use v1 gcs pubsub logged payload
-)
+gcs_pub_sub_use_v1: Optional[
+    bool
+] = False # if you want to use v1 gcs pubsub logged payload
 argilla_transformation_object: Optional[Dict[str, Any]] = None
-_async_input_callback: List[Union[str, Callable, CustomLogger]] = (
-    []
-) # internal variable - async custom callbacks are routed here.
+_async_input_callback: List[
+    Union[str, Callable, CustomLogger]
+] = [] # internal variable - async custom callbacks are routed here.
-_async_success_callback: List[Union[str, Callable, CustomLogger]] = (
-    []
-) # internal variable - async custom callbacks are routed here.
+_async_success_callback: List[
+    Union[str, Callable, CustomLogger]
+] = [] # internal variable - async custom callbacks are routed here.
-_async_failure_callback: List[Union[str, Callable, CustomLogger]] = (
-    []
-) # internal variable - async custom callbacks are routed here.
+_async_failure_callback: List[
+    Union[str, Callable, CustomLogger]
+] = [] # internal variable - async custom callbacks are routed here.
 pre_call_rules: List[Callable] = []
 post_call_rules: List[Callable] = []
 turn_off_message_logging: Optional[bool] = False
@@ -142,18 +142,18 @@ log_raw_request_response: bool = False
 redact_messages_in_exceptions: Optional[bool] = False
 redact_user_api_key_info: Optional[bool] = False
 filter_invalid_headers: Optional[bool] = False
-add_user_information_to_llm_headers: Optional[bool] = (
-    None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
-)
+add_user_information_to_llm_headers: Optional[
+    bool
+] = None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
 store_audit_logs = False # Enterprise feature, allow users to see audit logs
 ### end of callbacks #############
-email: Optional[str] = (
-    None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
-)
+email: Optional[
+    str
+] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
-token: Optional[str] = (
-    None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
-)
+token: Optional[
+    str
+] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
 telemetry = True
 max_tokens = 256 # OpenAI Defaults
 drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False))
@@ -229,24 +229,20 @@ enable_loadbalancing_on_batch_endpoints: Optional[bool] = None
 enable_caching_on_provider_specific_optional_params: bool = (
     False # feature-flag for caching on optional params - e.g. 'top_k'
 )
-caching: bool = (
-    False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
-)
+caching: bool = False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
-caching_with_models: bool = (
-    False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
-)
+caching_with_models: bool = False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
-cache: Optional[Cache] = (
-    None # cache object <- use this - https://docs.litellm.ai/docs/caching
-)
+cache: Optional[
+    Cache
+] = None # cache object <- use this - https://docs.litellm.ai/docs/caching
 default_in_memory_ttl: Optional[float] = None
 default_redis_ttl: Optional[float] = None
 default_redis_batch_cache_expiry: Optional[float] = None
 model_alias_map: Dict[str, str] = {}
 model_group_alias_map: Dict[str, str] = {}
 max_budget: float = 0.0 # set the max budget across all providers
-budget_duration: Optional[str] = (
-    None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
-)
+budget_duration: Optional[
+    str
+] = None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
 default_soft_budget: float = (
     50.0 # by default all litellm proxy keys have a soft budget of 50.0
 )
@@ -255,15 +251,11 @@ forward_traceparent_to_llm_provider: bool = False
 _current_cost = 0.0 # private variable, used if max budget is set
 error_logs: Dict = {}
-add_function_to_prompt: bool = (
-    False # if function calling not supported by api, append function call details to system prompt
-)
+add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt
 client_session: Optional[httpx.Client] = None
 aclient_session: Optional[httpx.AsyncClient] = None
 model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks'
-model_cost_map_url: str = (
-    "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
-)
+model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
 suppress_debug_info = False
 dynamodb_table_name: Optional[str] = None
 s3_callback_params: Optional[Dict] = None
@@ -285,9 +277,7 @@ disable_end_user_cost_tracking_prometheus_only: Optional[bool] = None
 custom_prometheus_metadata_labels: List[str] = []
 #### REQUEST PRIORITIZATION ####
 priority_reservation: Optional[Dict[str, float]] = None
-force_ipv4: bool = (
-    False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
-)
+force_ipv4: bool = False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
 module_level_aclient = AsyncHTTPHandler(
     timeout=request_timeout, client_alias="module level aclient"
 )
@@ -301,13 +291,13 @@ fallbacks: Optional[List] = None
 context_window_fallbacks: Optional[List] = None
 content_policy_fallbacks: Optional[List] = None
 allowed_fails: int = 3
-num_retries_per_request: Optional[int] = (
-    None # for the request overall (incl. fallbacks + model retries)
-)
+num_retries_per_request: Optional[
+    int
+] = None # for the request overall (incl. fallbacks + model retries)
 ####### SECRET MANAGERS #####################
-secret_manager_client: Optional[Any] = (
-    None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
-)
+secret_manager_client: Optional[
+    Any
+] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
 _google_kms_resource_name: Optional[str] = None
 _key_management_system: Optional[KeyManagementSystem] = None
 _key_management_settings: KeyManagementSettings = KeyManagementSettings()
@@ -1056,10 +1046,10 @@ from .types.llms.custom_llm import CustomLLMItem
 from .types.utils import GenericStreamingChunk
 custom_provider_map: List[CustomLLMItem] = []
-_custom_providers: List[str] = (
-    []
-) # internal helper util, used to track names of custom providers
+_custom_providers: List[
+    str
+] = [] # internal helper util, used to track names of custom providers
-disable_hf_tokenizer_download: Optional[bool] = (
-    None # disable huggingface tokenizer download. Defaults to openai clk100
-)
+disable_hf_tokenizer_download: Optional[
+    bool
+] = None # disable huggingface tokenizer download. Defaults to openai clk100
 global_disable_no_log_param: bool = False

View file

@@ -15,7 +15,7 @@ from .types.services import ServiceLoggerPayload, ServiceTypes
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
-    Span = _Span
+    Span = Union[_Span, Any]
     OTELClass = OpenTelemetry
 else:
     Span = Any
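This `Span = Union[_Span, Any]` rewrite (repeated across many files below) is what lets mypy accept the annotation while opentelemetry stays a type-checking-only, optional import. A minimal, self-contained sketch of the pattern as used in this commit (standalone example, not a repo file):

from typing import TYPE_CHECKING, Any, Optional, Union

if TYPE_CHECKING:
    # evaluated only by the type checker; never imported at runtime
    from opentelemetry.trace import Span as _Span

    Span = Union[_Span, Any]
else:
    Span = Any


def record(span: Optional[Span] = None) -> None:
    # mypy sees Optional[Union[_Span, Any]]; at runtime the alias is just Any,
    # so opentelemetry does not need to be installed
    if span is not None:
        print("received a span")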

View file

@@ -153,7 +153,6 @@ def create_batch(
     )
     api_base: Optional[str] = None
     if custom_llm_provider == "openai":
         # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
         api_base = (
             optional_params.api_base
@@ -358,7 +357,6 @@ def retrieve_batch(
     _is_async = kwargs.pop("aretrieve_batch", False) is True
     api_base: Optional[str] = None
     if custom_llm_provider == "openai":
         # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
         api_base = (
             optional_params.api_base

View file

@@ -9,12 +9,12 @@ Has 4 methods:
 """
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Optional, Union
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     Span = Any

View file

@@ -66,9 +66,7 @@ class CachingHandlerResponse(BaseModel):
     cached_result: Optional[Any] = None
     final_embedding_cached_response: Optional[EmbeddingResponse] = None
-    embedding_all_elements_cache_hit: bool = (
-        False # this is set to True when all elements in the list have a cache hit in the embedding cache, if true return the final_embedding_cached_response no need to make an API call
-    )
+    embedding_all_elements_cache_hit: bool = False # this is set to True when all elements in the list have a cache hit in the embedding cache, if true return the final_embedding_cached_response no need to make an API call
 class LLMCachingHandler:
@@ -738,7 +736,6 @@ class LLMCachingHandler:
         if self._should_store_result_in_cache(
             original_function=self.original_function, kwargs=new_kwargs
         ):
             litellm.cache.add_cache(result, **new_kwargs)
         return
@@ -865,9 +862,9 @@ class LLMCachingHandler:
             }
             if litellm.cache is not None:
-                litellm_params["preset_cache_key"] = (
-                    litellm.cache._get_preset_cache_key_from_kwargs(**kwargs)
-                )
+                litellm_params[
+                    "preset_cache_key"
+                ] = litellm.cache._get_preset_cache_key_from_kwargs(**kwargs)
             else:
                 litellm_params["preset_cache_key"] = None

View file

@@ -1,12 +1,12 @@
 import json
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Optional, Union
 from .base_cache import BaseCache
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     Span = Any

View file

@@ -12,7 +12,7 @@ import asyncio
 import time
 import traceback
 from concurrent.futures import ThreadPoolExecutor
-from typing import TYPE_CHECKING, Any, List, Optional
+from typing import TYPE_CHECKING, Any, List, Optional, Union
 import litellm
 from litellm._logging import print_verbose, verbose_logger
@@ -24,7 +24,7 @@ from .redis_cache import RedisCache
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     Span = Any

View file

@@ -8,7 +8,6 @@ from .in_memory_cache import InMemoryCache
 class LLMClientCache(InMemoryCache):
     def update_cache_key_with_event_loop(self, key):
         """
         Add the event loop to the cache key, to prevent event loop closed errors.

View file

@@ -34,7 +34,7 @@ if TYPE_CHECKING:
     cluster_pipeline = ClusterPipeline
     async_redis_client = Redis
     async_redis_cluster_client = RedisCluster
-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     pipeline = Any
     cluster_pipeline = Any
@@ -57,7 +57,6 @@ class RedisCache(BaseCache):
         socket_timeout: Optional[float] = 5.0, # default 5 second timeout
         **kwargs,
     ):
         from litellm._service_logger import ServiceLogging
         from .._redis import get_redis_client, get_redis_connection_pool

View file

@@ -5,7 +5,7 @@ Key differences:
 - RedisClient NEEDs to be re-used across requests, adds 3000ms latency if it's re-created
 """
-from typing import TYPE_CHECKING, Any, List, Optional
+from typing import TYPE_CHECKING, Any, List, Optional, Union
 from litellm.caching.redis_cache import RedisCache
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
     pipeline = Pipeline
     async_redis_client = Redis
-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     pipeline = Any
     async_redis_client = Any

View file

@@ -13,11 +13,15 @@ import ast
 import asyncio
 import json
 import os
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, cast
 import litellm
 from litellm._logging import print_verbose
-from litellm.litellm_core_utils.prompt_templates.common_utils import get_str_from_messages
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    get_str_from_messages,
+)
+from litellm.types.utils import EmbeddingResponse
 from .base_cache import BaseCache
@@ -87,14 +91,16 @@ class RedisSemanticCache(BaseCache):
         if redis_url is None:
             try:
                 # Attempt to use provided parameters or fallback to environment variables
-                host = host or os.environ['REDIS_HOST']
-                port = port or os.environ['REDIS_PORT']
-                password = password or os.environ['REDIS_PASSWORD']
+                host = host or os.environ["REDIS_HOST"]
+                port = port or os.environ["REDIS_PORT"]
+                password = password or os.environ["REDIS_PASSWORD"]
             except KeyError as e:
                 # Raise a more informative exception if any of the required keys are missing
                 missing_var = e.args[0]
-                raise ValueError(f"Missing required Redis configuration: {missing_var}. "
-                                 f"Provide {missing_var} or redis_url.") from e
+                raise ValueError(
+                    f"Missing required Redis configuration: {missing_var}. "
+                    f"Provide {missing_var} or redis_url."
+                ) from e
             redis_url = f"redis://:{password}@{host}:{port}"
@@ -137,10 +143,13 @@ class RedisSemanticCache(BaseCache):
             List[float]: The embedding vector
         """
         # Create an embedding from prompt
-        embedding_response = litellm.embedding(
-            model=self.embedding_model,
-            input=prompt,
-            cache={"no-store": True, "no-cache": True},
-        )
+        embedding_response = cast(
+            EmbeddingResponse,
+            litellm.embedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+            ),
+        )
         embedding = embedding_response["data"][0]["embedding"]
         return embedding
@@ -186,6 +195,7 @@ class RedisSemanticCache(BaseCache):
         """
         print_verbose(f"Redis semantic-cache set_cache, kwargs: {kwargs}")
+        value_str: Optional[str] = None
         try:
             # Extract the prompt from messages
             messages = kwargs.get("messages", [])
@@ -203,7 +213,9 @@ class RedisSemanticCache(BaseCache):
             else:
                 self.llmcache.store(prompt, value_str)
         except Exception as e:
-            print_verbose(f"Error setting {value_str} in the Redis semantic cache: {str(e)}")
+            print_verbose(
+                f"Error setting {value_str or value} in the Redis semantic cache: {str(e)}"
+            )
     def get_cache(self, key: str, **kwargs) -> Any:
         """
@@ -336,13 +348,13 @@ class RedisSemanticCache(BaseCache):
                     prompt,
                     value_str,
                     vector=prompt_embedding, # Pass through custom embedding
-                    ttl=ttl
+                    ttl=ttl,
                 )
             else:
                 await self.llmcache.astore(
                     prompt,
                     value_str,
-                    vector=prompt_embedding # Pass through custom embedding
+                    vector=prompt_embedding, # Pass through custom embedding
                 )
         except Exception as e:
             print_verbose(f"Error in async_set_cache: {str(e)}")
@@ -374,14 +386,13 @@ class RedisSemanticCache(BaseCache):
         prompt_embedding = await self._get_async_embedding(prompt, **kwargs)
         # Check the cache for semantically similar prompts
-        results = await self.llmcache.acheck(
-            prompt=prompt,
-            vector=prompt_embedding
-        )
+        results = await self.llmcache.acheck(prompt=prompt, vector=prompt_embedding)
         # handle results / cache hit
         if not results:
-            kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 # TODO why here but not above??
+            kwargs.setdefault("metadata", {})[
+                "semantic-similarity"
+            ] = 0.0 # TODO why here but not above??
             return None
         cache_hit = results[0]
@@ -420,7 +431,9 @@ class RedisSemanticCache(BaseCache):
         aindex = await self.llmcache._get_async_index()
         return await aindex.info()
-    async def async_set_cache_pipeline(self, cache_list: List[Tuple[str, Any]], **kwargs) -> None:
+    async def async_set_cache_pipeline(
+        self, cache_list: List[Tuple[str, Any]], **kwargs
+    ) -> None:
         """
         Asynchronously store multiple values in the semantic cache.
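The `cast(EmbeddingResponse, litellm.embedding(...))` change above narrows a loosely typed return value for mypy without any runtime cost. A hedged, generic sketch of that pattern (illustrative names, not the repo's code):

from typing import Any, cast


class EmbeddingResult:
    """Stand-in for a concrete response type such as EmbeddingResponse."""

    data: list = []


def call_provider() -> Any:
    # stands in for a call like litellm.embedding(), whose declared return
    # type is broader than what the caller actually receives
    return EmbeddingResult()


result = cast(EmbeddingResult, call_provider())  # type-only narrowing, no runtime check
print(type(result).__name__)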

View file

@@ -123,7 +123,7 @@ class S3Cache(BaseCache):
                 ) # Convert string to dictionary
             except Exception:
                 cached_response = ast.literal_eval(cached_response)
-            if type(cached_response) is not dict:
+            if not isinstance(cached_response, dict):
                 cached_response = dict(cached_response)
             verbose_logger.debug(
                 f"Got S3 Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
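The `type(cached_response) is not dict` to `isinstance(...)` switch above is the usual fix for ruff's E721 comparison rule; isinstance also accepts dict subclasses. A small illustration of the difference:

from collections import OrderedDict

cached = OrderedDict(hit=True)

print(type(cached) is dict)      # False - exact-type check rejects dict subclasses
print(isinstance(cached, dict))  # True - subclass-aware check, the form used above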

View file

@@ -580,7 +580,6 @@ def completion_cost( # noqa: PLR0915
     - For un-mapped Replicate models, the cost is calculated based on the total time used for the request.
     """
     try:
         call_type = _infer_call_type(call_type, completion_response) or "completion"
         if (

View file

@@ -138,7 +138,6 @@ def create_fine_tuning_job(
     # OpenAI
     if custom_llm_provider == "openai":
         # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
         api_base = (
             optional_params.api_base
@@ -360,7 +359,6 @@ def cancel_fine_tuning_job(
     # OpenAI
     if custom_llm_provider == "openai":
         # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
         api_base = (
             optional_params.api_base
@@ -522,7 +520,6 @@ def list_fine_tuning_jobs(
     # OpenAI
     if custom_llm_provider == "openai":
         # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
         api_base = (
             optional_params.api_base

View file

@@ -19,7 +19,6 @@ else:
 def squash_payloads(queue):
     squashed = {}
     if len(queue) == 0:
         return squashed

View file

@@ -195,13 +195,16 @@ class SlackAlerting(CustomBatchLogger):
         if self.alerting is None or self.alert_types is None:
             return
-        time_difference_float, model, api_base, messages = (
-            self._response_taking_too_long_callback_helper(
+        (
+            time_difference_float,
+            model,
+            api_base,
+            messages,
+        ) = self._response_taking_too_long_callback_helper(
             kwargs=kwargs,
             start_time=start_time,
             end_time=end_time,
         )
-        )
         if litellm.turn_off_message_logging or litellm.redact_messages_in_exceptions:
             messages = "Message not logged. litellm.redact_messages_in_exceptions=True"
         request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
@@ -819,9 +822,9 @@ class SlackAlerting(CustomBatchLogger):
         ### UNIQUE CACHE KEY ###
         cache_key = provider + region_name
-        outage_value: Optional[ProviderRegionOutageModel] = (
-            await self.internal_usage_cache.async_get_cache(key=cache_key)
-        )
+        outage_value: Optional[
+            ProviderRegionOutageModel
+        ] = await self.internal_usage_cache.async_get_cache(key=cache_key)
         if (
             getattr(exception, "status_code", None) is None
@@ -1402,9 +1405,9 @@ Model Info:
             self.alert_to_webhook_url is not None
             and alert_type in self.alert_to_webhook_url
         ):
-            slack_webhook_url: Optional[Union[str, List[str]]] = (
-                self.alert_to_webhook_url[alert_type]
-            )
+            slack_webhook_url: Optional[
+                Union[str, List[str]]
+            ] = self.alert_to_webhook_url[alert_type]
         elif self.default_webhook_url is not None:
             slack_webhook_url = self.default_webhook_url
         else:
@@ -1768,7 +1771,6 @@ Model Info:
         - Team Created, Updated, Deleted
         """
         try:
             message = f"`{event_name}`\n"
             key_event_dict = key_event.model_dump()

View file

@@ -98,7 +98,6 @@ class ArgillaLogger(CustomBatchLogger):
         argilla_dataset_name: Optional[str],
         argilla_base_url: Optional[str],
     ) -> ArgillaCredentialsObject:
         _credentials_api_key = argilla_api_key or os.getenv("ARGILLA_API_KEY")
         if _credentials_api_key is None:
             raise Exception("Invalid Argilla API Key given. _credentials_api_key=None.")

View file

@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Optional, Union
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
@@ -7,7 +7,7 @@ from litellm.types.utils import StandardLoggingPayload
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     Span = Any

View file

@@ -19,14 +19,13 @@ if TYPE_CHECKING:
     from litellm.types.integrations.arize import Protocol as _Protocol
     Protocol = _Protocol
-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     Protocol = Any
     Span = Any
 class ArizeLogger(OpenTelemetry):
     def set_attributes(self, span: Span, kwargs, response_obj: Optional[Any]):
         ArizeLogger.set_arize_attributes(span, kwargs, response_obj)
         return

View file

@@ -1,17 +1,20 @@
 import os
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Union
-from litellm.integrations.arize import _utils
 from litellm._logging import verbose_logger
+from litellm.integrations.arize import _utils
 from litellm.types.integrations.arize_phoenix import ArizePhoenixConfig
 if TYPE_CHECKING:
-    from .opentelemetry import OpenTelemetryConfig as _OpenTelemetryConfig
-    from litellm.types.integrations.arize import Protocol as _Protocol
     from opentelemetry.trace import Span as _Span
+    from litellm.types.integrations.arize import Protocol as _Protocol
+    from .opentelemetry import OpenTelemetryConfig as _OpenTelemetryConfig
     Protocol = _Protocol
     OpenTelemetryConfig = _OpenTelemetryConfig
-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     Protocol = Any
     OpenTelemetryConfig = Any
@@ -20,6 +23,7 @@ else:
 ARIZE_HOSTED_PHOENIX_ENDPOINT = "https://app.phoenix.arize.com/v1/traces"
 class ArizePhoenixLogger:
     @staticmethod
     def set_arize_phoenix_attributes(span: Span, kwargs, response_obj):
@@ -59,15 +63,14 @@ class ArizePhoenixLogger:
         # a slightly different auth header format than self hosted phoenix
         if endpoint == ARIZE_HOSTED_PHOENIX_ENDPOINT:
             if api_key is None:
-                raise ValueError("PHOENIX_API_KEY must be set when the Arize hosted Phoenix endpoint is used.")
+                raise ValueError(
+                    "PHOENIX_API_KEY must be set when the Arize hosted Phoenix endpoint is used."
+                )
             otlp_auth_headers = f"api_key={api_key}"
         elif api_key is not None:
             # api_key/auth is optional for self hosted phoenix
             otlp_auth_headers = f"Authorization=Bearer {api_key}"
         return ArizePhoenixConfig(
-            otlp_auth_headers=otlp_auth_headers,
-            protocol=protocol,
-            endpoint=endpoint
+            otlp_auth_headers=otlp_auth_headers, protocol=protocol, endpoint=endpoint
         )

View file

@@ -12,7 +12,10 @@ class AthinaLogger:
             "athina-api-key": self.athina_api_key,
             "Content-Type": "application/json",
         }
-        self.athina_logging_url = os.getenv("ATHINA_BASE_URL", "https://log.athina.ai") + "/api/v1/log/inference"
+        self.athina_logging_url = (
+            os.getenv("ATHINA_BASE_URL", "https://log.athina.ai")
+            + "/api/v1/log/inference"
+        )
         self.additional_keys = [
             "environment",
             "prompt_slug",

View file

@@ -50,12 +50,12 @@ class AzureBlobStorageLogger(CustomBatchLogger):
         self.azure_storage_file_system: str = _azure_storage_file_system
         # Internal variables used for Token based authentication
-        self.azure_auth_token: Optional[str] = (
-            None # the Azure AD token to use for Azure Storage API requests
-        )
-        self.token_expiry: Optional[datetime] = (
-            None # the expiry time of the currentAzure AD token
-        )
+        self.azure_auth_token: Optional[
+            str
+        ] = None # the Azure AD token to use for Azure Storage API requests
+        self.token_expiry: Optional[
+            datetime
+        ] = None # the expiry time of the currentAzure AD token
         asyncio.create_task(self.periodic_flush())
         self.flush_lock = asyncio.Lock()
@@ -153,7 +153,6 @@ class AzureBlobStorageLogger(CustomBatchLogger):
         3. Flush the data
         """
         try:
             if self.azure_storage_account_key:
                 await self.upload_to_azure_data_lake_with_azure_account_key(
                     payload=payload

View file

@@ -4,7 +4,7 @@
 import copy
 import os
 from datetime import datetime
-from typing import Optional, Dict
+from typing import Dict, Optional
 import httpx
 from pydantic import BaseModel
@@ -19,7 +19,9 @@ from litellm.llms.custom_httpx.http_handler import (
 )
 from litellm.utils import print_verbose
-global_braintrust_http_handler = get_async_httpx_client(llm_provider=httpxSpecialProvider.LoggingCallback)
+global_braintrust_http_handler = get_async_httpx_client(
+    llm_provider=httpxSpecialProvider.LoggingCallback
+)
 global_braintrust_sync_http_handler = HTTPHandler()
 API_BASE = "https://api.braintrustdata.com/v1"
@@ -35,7 +37,9 @@ def get_utc_datetime():
 class BraintrustLogger(CustomLogger):
-    def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None) -> None:
+    def __init__(
+        self, api_key: Optional[str] = None, api_base: Optional[str] = None
+    ) -> None:
         super().__init__()
         self.validate_environment(api_key=api_key)
         self.api_base = api_base or API_BASE
@@ -45,7 +49,9 @@ class BraintrustLogger(CustomLogger):
             "Authorization": "Bearer " + self.api_key,
             "Content-Type": "application/json",
         }
-        self._project_id_cache: Dict[str, str] = {} # Cache mapping project names to IDs
+        self._project_id_cache: Dict[
+            str, str
+        ] = {} # Cache mapping project names to IDs
     def validate_environment(self, api_key: Optional[str]):
         """
@@ -71,7 +77,9 @@ class BraintrustLogger(CustomLogger):
         try:
             response = global_braintrust_sync_http_handler.post(
-                f"{self.api_base}/project", headers=self.headers, json={"name": project_name}
+                f"{self.api_base}/project",
+                headers=self.headers,
+                json={"name": project_name},
             )
             project_dict = response.json()
             project_id = project_dict["id"]
@@ -89,7 +97,9 @@ class BraintrustLogger(CustomLogger):
         try:
             response = await global_braintrust_http_handler.post(
-                f"{self.api_base}/project/register", headers=self.headers, json={"name": project_name}
+                f"{self.api_base}/project/register",
+                headers=self.headers,
+                json={"name": project_name},
             )
             project_dict = response.json()
             project_id = project_dict["id"]
@@ -116,15 +126,21 @@ class BraintrustLogger(CustomLogger):
         if metadata is None:
             metadata = {}
-        proxy_headers = litellm_params.get("proxy_server_request", {}).get("headers", {}) or {}
+        proxy_headers = (
+            litellm_params.get("proxy_server_request", {}).get("headers", {}) or {}
+        )
         for metadata_param_key in proxy_headers:
             if metadata_param_key.startswith("braintrust"):
                 trace_param_key = metadata_param_key.replace("braintrust", "", 1)
                 if trace_param_key in metadata:
-                    verbose_logger.warning(f"Overwriting Braintrust `{trace_param_key}` from request header")
+                    verbose_logger.warning(
+                        f"Overwriting Braintrust `{trace_param_key}` from request header"
+                    )
                 else:
-                    verbose_logger.debug(f"Found Braintrust `{trace_param_key}` in request header")
+                    verbose_logger.debug(
+                        f"Found Braintrust `{trace_param_key}` in request header"
+                    )
                 metadata[trace_param_key] = proxy_headers.get(metadata_param_key)
         return metadata
@@ -157,24 +173,35 @@ class BraintrustLogger(CustomLogger):
         output = None
         choices = []
         if response_obj is not None and (
-            kwargs.get("call_type", None) == "embedding" or isinstance(response_obj, litellm.EmbeddingResponse)
+            kwargs.get("call_type", None) == "embedding"
+            or isinstance(response_obj, litellm.EmbeddingResponse)
         ):
             output = None
-        elif response_obj is not None and isinstance(response_obj, litellm.ModelResponse):
+        elif response_obj is not None and isinstance(
+            response_obj, litellm.ModelResponse
+        ):
            output = response_obj["choices"][0]["message"].json()
            choices = response_obj["choices"]
-        elif response_obj is not None and isinstance(response_obj, litellm.TextCompletionResponse):
+        elif response_obj is not None and isinstance(
+            response_obj, litellm.TextCompletionResponse
+        ):
            output = response_obj.choices[0].text
            choices = response_obj.choices
-        elif response_obj is not None and isinstance(response_obj, litellm.ImageResponse):
+        elif response_obj is not None and isinstance(
+            response_obj, litellm.ImageResponse
+        ):
            output = response_obj["data"]
         litellm_params = kwargs.get("litellm_params", {})
-        metadata = litellm_params.get("metadata", {}) or {} # if litellm_params['metadata'] == None
+        metadata = (
+            litellm_params.get("metadata", {}) or {}
+        ) # if litellm_params['metadata'] == None
         metadata = self.add_metadata_from_header(litellm_params, metadata)
         clean_metadata = {}
         try:
-            metadata = copy.deepcopy(metadata) # Avoid modifying the original metadata
+            metadata = copy.deepcopy(
+                metadata
+            ) # Avoid modifying the original metadata
         except Exception:
             new_metadata = {}
             for key, value in metadata.items():
@@ -192,7 +219,9 @@ class BraintrustLogger(CustomLogger):
         project_id = metadata.get("project_id")
         if project_id is None:
             project_name = metadata.get("project_name")
-            project_id = self.get_project_id_sync(project_name) if project_name else None
+            project_id = (
+                self.get_project_id_sync(project_name) if project_name else None
+            )
             if project_id is None:
                 if self.default_project_id is None:
@@ -234,7 +263,8 @@ class BraintrustLogger(CustomLogger):
                 "completion_tokens": usage_obj.completion_tokens,
                 "total_tokens": usage_obj.total_tokens,
                 "total_cost": cost,
-                "time_to_first_token": end_time.timestamp() - start_time.timestamp(),
+                "time_to_first_token": end_time.timestamp()
+                - start_time.timestamp(),
                 "start": start_time.timestamp(),
                 "end": end_time.timestamp(),
             }
@@ -255,7 +285,9 @@ class BraintrustLogger(CustomLogger):
         request_data["metrics"] = metrics
         try:
-            print_verbose(f"global_braintrust_sync_http_handler.post: {global_braintrust_sync_http_handler.post}")
+            print_verbose(
+                f"global_braintrust_sync_http_handler.post: {global_braintrust_sync_http_handler.post}"
+            )
             global_braintrust_sync_http_handler.post(
                 url=f"{self.api_base}/project_logs/{project_id}/insert",
                 json={"events": [request_data]},
@@ -276,20 +308,29 @@ class BraintrustLogger(CustomLogger):
         output = None
         choices = []
         if response_obj is not None and (
-            kwargs.get("call_type", None) == "embedding" or isinstance(response_obj, litellm.EmbeddingResponse)
+            kwargs.get("call_type", None) == "embedding"
+            or isinstance(response_obj, litellm.EmbeddingResponse)
        ):
            output = None
-        elif response_obj is not None and isinstance(response_obj, litellm.ModelResponse):
+        elif response_obj is not None and isinstance(
+            response_obj, litellm.ModelResponse
+        ):
            output = response_obj["choices"][0]["message"].json()
            choices = response_obj["choices"]
-        elif response_obj is not None and isinstance(response_obj, litellm.TextCompletionResponse):
+        elif response_obj is not None and isinstance(
+            response_obj, litellm.TextCompletionResponse
+        ):
            output = response_obj.choices[0].text
            choices = response_obj.choices
-        elif response_obj is not None and isinstance(response_obj, litellm.ImageResponse):
+        elif response_obj is not None and isinstance(
+            response_obj, litellm.ImageResponse
+        ):
            output = response_obj["data"]
         litellm_params = kwargs.get("litellm_params", {})
-        metadata = litellm_params.get("metadata", {}) or {} # if litellm_params['metadata'] == None
+        metadata = (
+            litellm_params.get("metadata", {}) or {}
+        ) # if litellm_params['metadata'] == None
         metadata = self.add_metadata_from_header(litellm_params, metadata)
         clean_metadata = {}
         new_metadata = {}
@@ -313,7 +354,11 @@ class BraintrustLogger(CustomLogger):
         project_id = metadata.get("project_id")
         if project_id is None:
             project_name = metadata.get("project_name")
-            project_id = await self.get_project_id_async(project_name) if project_name else None
+            project_id = (
+                await self.get_project_id_async(project_name)
+                if project_name
+                else None
+            )
             if project_id is None:
                 if self.default_project_id is None:
@@ -362,8 +407,14 @@ class BraintrustLogger(CustomLogger):
         api_call_start_time = kwargs.get("api_call_start_time")
         completion_start_time = kwargs.get("completion_start_time")
-        if api_call_start_time is not None and completion_start_time is not None:
-            metrics["time_to_first_token"] = completion_start_time.timestamp() - api_call_start_time.timestamp()
+        if (
+            api_call_start_time is not None
+            and completion_start_time is not None
+        ):
+            metrics["time_to_first_token"] = (
+                completion_start_time.timestamp()
+                - api_call_start_time.timestamp()
+            )
         request_data = {
             "id": litellm_call_id,

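The Braintrust hunks above mostly reflow `add_metadata_from_header`, which copies any request header starting with "braintrust" into the trace metadata after stripping that prefix. A hedged sketch of that mapping in isolation (simplified signature; the dict values here are made up):

from typing import Any, Dict


def add_metadata_from_headers(proxy_headers: Dict[str, str], metadata: Dict[str, Any]) -> Dict[str, Any]:
    for header_key, header_value in proxy_headers.items():
        if header_key.startswith("braintrust"):
            # strip only the first occurrence of the prefix, as the logger does
            trace_key = header_key.replace("braintrust", "", 1)
            # the logger warns when an existing key is overwritten, then assigns either way
            metadata[trace_key] = header_value
    return metadata


print(add_metadata_from_headers({"braintrust_span_name": "demo-call"}, {}))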
View file

@@ -14,7 +14,6 @@ from litellm.integrations.custom_logger import CustomLogger
 class CustomBatchLogger(CustomLogger):
     def __init__(
         self,
         flush_lock: Optional[asyncio.Lock] = None,

View file

@@ -7,7 +7,6 @@ from litellm.types.utils import StandardLoggingGuardrailInformation
 class CustomGuardrail(CustomLogger):
     def __init__(
         self,
         guardrail_name: Optional[str] = None,

View file

@@ -31,7 +31,7 @@ from litellm.types.utils import (
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     Span = Any

View file

@@ -233,7 +233,6 @@ class DataDogLogger(
         pass
     async def _log_async_event(self, kwargs, response_obj, start_time, end_time):
         dd_payload = self.create_datadog_logging_payload(
             kwargs=kwargs,
             response_obj=response_obj,

View file

@@ -125,9 +125,9 @@ class GCSBucketBase(CustomBatchLogger):
         if kwargs is None:
             kwargs = {}
-        standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
-            kwargs.get("standard_callback_dynamic_params", None)
-        )
+        standard_callback_dynamic_params: Optional[
+            StandardCallbackDynamicParams
+        ] = kwargs.get("standard_callback_dynamic_params", None)
         bucket_name: str
         path_service_account: Optional[str]

View file

@@ -70,13 +70,14 @@ class GcsPubSubLogger(CustomBatchLogger):
         """Construct authorization headers using Vertex AI auth"""
         from litellm import vertex_chat_completion
-        _auth_header, vertex_project = (
-            await vertex_chat_completion._ensure_access_token_async(
+        (
+            _auth_header,
+            vertex_project,
+        ) = await vertex_chat_completion._ensure_access_token_async(
             credentials=self.path_service_account_json,
             project_id=None,
             custom_llm_provider="vertex_ai",
         )
-        )
         auth_header, _ = vertex_chat_completion._get_token_and_url(
             model="pub-sub",

View file

@@ -155,11 +155,7 @@ class HumanloopLogger(CustomLogger):
         prompt_id: str,
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
-    ) -> Tuple[
-        str,
-        List[AllMessageValues],
-        dict,
-    ]:
+    ) -> Tuple[str, List[AllMessageValues], dict,]:
         humanloop_api_key = dynamic_callback_params.get(
             "humanloop_api_key"
         ) or get_secret_str("HUMANLOOP_API_KEY")

View file

@@ -471,9 +471,9 @@ class LangFuseLogger:
         # we clean out all extra litellm metadata params before logging
         clean_metadata: Dict[str, Any] = {}
         if prompt_management_metadata is not None:
-            clean_metadata["prompt_management_metadata"] = (
-                prompt_management_metadata
-            )
+            clean_metadata[
+                "prompt_management_metadata"
+            ] = prompt_management_metadata
         if isinstance(metadata, dict):
             for key, value in metadata.items():
                 # generate langfuse tags - Default Tags sent to Langfuse from LiteLLM Proxy

View file

@@ -19,7 +19,6 @@ else:
 class LangFuseHandler:
     @staticmethod
     def get_langfuse_logger_for_request(
         standard_callback_dynamic_params: StandardCallbackDynamicParams,
@@ -87,7 +86,9 @@ class LangFuseHandler:
         if globalLangfuseLogger is not None:
             return globalLangfuseLogger
-        credentials_dict: Dict[str, Any] = (
+        credentials_dict: Dict[
+            str, Any
+        ] = (
             {}
         ) # the global langfuse logger uses Environment Variables, there are no dynamic credentials
         globalLangfuseLogger = in_memory_dynamic_logger_cache.get_cache(

View file

@@ -172,11 +172,7 @@ class LangfusePromptManagement(LangFuseLogger, PromptManagementBase, CustomLogge
         prompt_id: str,
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
-    ) -> Tuple[
-        str,
-        List[AllMessageValues],
-        dict,
-    ]:
+    ) -> Tuple[str, List[AllMessageValues], dict,]:
         return self.get_chat_completion_prompt(
             model,
             messages,

View file

@@ -75,7 +75,6 @@ class LangsmithLogger(CustomBatchLogger):
         langsmith_project: Optional[str] = None,
         langsmith_base_url: Optional[str] = None,
     ) -> LangsmithCredentialsObject:
         _credentials_api_key = langsmith_api_key or os.getenv("LANGSMITH_API_KEY")
         if _credentials_api_key is None:
             raise Exception(
@@ -443,9 +442,9 @@ class LangsmithLogger(CustomBatchLogger):
         Otherwise, use the default credentials.
         """
-        standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
-            kwargs.get("standard_callback_dynamic_params", None)
-        )
+        standard_callback_dynamic_params: Optional[
+            StandardCallbackDynamicParams
+        ] = kwargs.get("standard_callback_dynamic_params", None)
         if standard_callback_dynamic_params is not None:
             credentials = self.get_credentials_from_env(
                 langsmith_api_key=standard_callback_dynamic_params.get(
@@ -481,7 +480,6 @@ class LangsmithLogger(CustomBatchLogger):
             asyncio.run(self.async_send_batch())
     def get_run_by_id(self, run_id):
         langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"]
         langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"]

View file

@@ -1,12 +1,12 @@
 import json
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Union
 from litellm.proxy._types import SpanAttributes
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     Span = Any

View file

@@ -20,7 +20,6 @@ def parse_tool_calls(tool_calls):
         return None
     def clean_tool_call(tool_call):
         serialized = {
             "type": tool_call.type,
             "id": tool_call.id,
@@ -36,7 +35,6 @@ def parse_tool_calls(tool_calls):
 def parse_messages(input):
     if input is None:
         return None

View file

@@ -48,14 +48,17 @@ class MlflowLogger(CustomLogger):
     def _extract_and_set_chat_attributes(self, span, kwargs, response_obj):
         try:
-            from mlflow.tracing.utils import set_span_chat_messages, set_span_chat_tools
+            from mlflow.tracing.utils import set_span_chat_messages # type: ignore
+            from mlflow.tracing.utils import set_span_chat_tools # type: ignore
         except ImportError:
             return
         inputs = self._construct_input(kwargs)
         input_messages = inputs.get("messages", [])
-        output_messages = [c.message.model_dump(exclude_none=True)
-                           for c in getattr(response_obj, "choices", [])]
+        output_messages = [
+            c.message.model_dump(exclude_none=True)
+            for c in getattr(response_obj, "choices", [])
+        ]
         if messages := [*input_messages, *output_messages]:
             set_span_chat_messages(span, messages)
         if tools := inputs.get("tools"):

View file

@@ -1,7 +1,7 @@
 import os
 from dataclasses import dataclass
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
 import litellm
 from litellm._logging import verbose_logger
@@ -23,10 +23,10 @@ if TYPE_CHECKING:
     )
     from litellm.proxy.proxy_server import UserAPIKeyAuth as _UserAPIKeyAuth
-    Span = _Span
-    SpanExporter = _SpanExporter
-    UserAPIKeyAuth = _UserAPIKeyAuth
-    ManagementEndpointLoggingPayload = _ManagementEndpointLoggingPayload
+    Span = Union[_Span, Any]
+    SpanExporter = Union[_SpanExporter, Any]
+    UserAPIKeyAuth = Union[_UserAPIKeyAuth, Any]
+    ManagementEndpointLoggingPayload = Union[_ManagementEndpointLoggingPayload, Any]
 else:
     Span = Any
     SpanExporter = Any
@@ -46,7 +46,6 @@ LITELLM_REQUEST_SPAN_NAME = "litellm_request"
 @dataclass
 class OpenTelemetryConfig:
     exporter: Union[str, SpanExporter] = "console"
     endpoint: Optional[str] = None
     headers: Optional[str] = None
@@ -154,7 +153,6 @@ class OpenTelemetry(CustomLogger):
         end_time: Optional[Union[datetime, float]] = None,
         event_metadata: Optional[dict] = None,
     ):
         from opentelemetry import trace
         from opentelemetry.trace import Status, StatusCode
@@ -215,7 +213,6 @@ class OpenTelemetry(CustomLogger):
         end_time: Optional[Union[float, datetime]] = None,
         event_metadata: Optional[dict] = None,
     ):
         from opentelemetry import trace
         from opentelemetry.trace import Status, StatusCode
@@ -353,9 +350,9 @@ class OpenTelemetry(CustomLogger):
         """
         from opentelemetry import trace
-        standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
-            kwargs.get("standard_callback_dynamic_params")
-        )
+        standard_callback_dynamic_params: Optional[
+            StandardCallbackDynamicParams
+        ] = kwargs.get("standard_callback_dynamic_params")
         if not standard_callback_dynamic_params:
             return
@@ -722,7 +719,6 @@ class OpenTelemetry(CustomLogger):
             span.set_attribute(key, primitive_value)
     def set_raw_request_attributes(self, span: Span, kwargs, response_obj):
         kwargs.get("optional_params", {})
         litellm_params = kwargs.get("litellm_params", {}) or {}
         custom_llm_provider = litellm_params.get("custom_llm_provider", "Unknown")
@@ -843,12 +839,14 @@ class OpenTelemetry(CustomLogger):
             headers=dynamic_headers or self.OTEL_HEADERS
         )
-        if isinstance(self.OTEL_EXPORTER, SpanExporter):
+        if hasattr(
+            self.OTEL_EXPORTER, "export"
+        ): # Check if it has the export method that SpanExporter requires
             verbose_logger.debug(
                 "OpenTelemetry: intiializing SpanExporter. Value of OTEL_EXPORTER: %s",
                 self.OTEL_EXPORTER,
             )
-            return SimpleSpanProcessor(self.OTEL_EXPORTER)
+            return SimpleSpanProcessor(cast(SpanExporter, self.OTEL_EXPORTER))
if self.OTEL_EXPORTER == "console": if self.OTEL_EXPORTER == "console":
verbose_logger.debug( verbose_logger.debug(
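The exporter hunk just above is one of the few behavioral edits in this file: because the SpanExporter alias can be Any when the opentelemetry types are unavailable, the isinstance() check is replaced with a duck-typed hasattr() check plus an explicit cast() for mypy. A rough sketch of the idea, assuming the opentelemetry SDK is installed (simplified relative to the real method):

from typing import Any, Union, cast

from opentelemetry.sdk.trace.export import (
    ConsoleSpanExporter,
    SimpleSpanProcessor,
    SpanExporter,
)


def build_span_processor(exporter: Union[str, Any]) -> SimpleSpanProcessor:
    # Anything with an export() method is treated as a SpanExporter; cast()
    # tells mypy so, since the declared type of `exporter` may be Any.
    if hasattr(exporter, "export"):
        return SimpleSpanProcessor(cast(SpanExporter, exporter))
    # Fall back to a console exporter for string configs such as "console".
    return SimpleSpanProcessor(ConsoleSpanExporter())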
@ -907,7 +905,6 @@ class OpenTelemetry(CustomLogger):
logging_payload: ManagementEndpointLoggingPayload, logging_payload: ManagementEndpointLoggingPayload,
parent_otel_span: Optional[Span] = None, parent_otel_span: Optional[Span] = None,
): ):
from opentelemetry import trace from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode from opentelemetry.trace import Status, StatusCode
@ -961,7 +958,6 @@ class OpenTelemetry(CustomLogger):
logging_payload: ManagementEndpointLoggingPayload, logging_payload: ManagementEndpointLoggingPayload,
parent_otel_span: Optional[Span] = None, parent_otel_span: Optional[Span] = None,
): ):
from opentelemetry import trace from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode from opentelemetry.trace import Status, StatusCode


@ -185,7 +185,6 @@ class OpikLogger(CustomBatchLogger):
def _create_opik_payload( # noqa: PLR0915 def _create_opik_payload( # noqa: PLR0915
self, kwargs, response_obj, start_time, end_time self, kwargs, response_obj, start_time, end_time
) -> List[Dict]: ) -> List[Dict]:
# Get metadata # Get metadata
_litellm_params = kwargs.get("litellm_params", {}) or {} _litellm_params = kwargs.get("litellm_params", {}) or {}
litellm_params_metadata = _litellm_params.get("metadata", {}) or {} litellm_params_metadata = _litellm_params.get("metadata", {}) or {}


@ -988,9 +988,9 @@ class PrometheusLogger(CustomLogger):
): ):
try: try:
verbose_logger.debug("setting remaining tokens requests metric") verbose_logger.debug("setting remaining tokens requests metric")
standard_logging_payload: Optional[StandardLoggingPayload] = ( standard_logging_payload: Optional[
request_kwargs.get("standard_logging_object") StandardLoggingPayload
) ] = request_kwargs.get("standard_logging_object")
if standard_logging_payload is None: if standard_logging_payload is None:
return return


@ -14,7 +14,6 @@ class PromptManagementClient(TypedDict):
class PromptManagementBase(ABC): class PromptManagementBase(ABC):
@property @property
@abstractmethod @abstractmethod
def integration_name(self) -> str: def integration_name(self) -> str:
@ -83,11 +82,7 @@ class PromptManagementBase(ABC):
prompt_id: str, prompt_id: str,
prompt_variables: Optional[dict], prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams, dynamic_callback_params: StandardCallbackDynamicParams,
) -> Tuple[ ) -> Tuple[str, List[AllMessageValues], dict,]:
str,
List[AllMessageValues],
dict,
]:
if not self.should_run_prompt_management( if not self.should_run_prompt_management(
prompt_id=prompt_id, dynamic_callback_params=dynamic_callback_params prompt_id=prompt_id, dynamic_callback_params=dynamic_callback_params
): ):

View file

@ -38,7 +38,7 @@ class S3Logger:
if litellm.s3_callback_params is not None: if litellm.s3_callback_params is not None:
# read in .env variables - example os.environ/AWS_BUCKET_NAME # read in .env variables - example os.environ/AWS_BUCKET_NAME
for key, value in litellm.s3_callback_params.items(): for key, value in litellm.s3_callback_params.items():
if type(value) is str and value.startswith("os.environ/"): if isinstance(value, str) and value.startswith("os.environ/"):
litellm.s3_callback_params[key] = litellm.get_secret(value) litellm.s3_callback_params[key] = litellm.get_secret(value)
# now set s3 params from litellm.s3_logger_params # now set s3 params from litellm.s3_logger_params
s3_bucket_name = litellm.s3_callback_params.get("s3_bucket_name") s3_bucket_name = litellm.s3_callback_params.get("s3_bucket_name")

View file

@ -21,11 +21,11 @@ try:
# contains a (known) object attribute # contains a (known) object attribute
object: Literal["chat.completion", "edit", "text_completion"] object: Literal["chat.completion", "edit", "text_completion"]
def __getitem__(self, key: K) -> V: ... # noqa def __getitem__(self, key: K) -> V:
... # noqa
def get( # noqa def get(self, key: K, default: Optional[V] = None) -> Optional[V]: # noqa
self, key: K, default: Optional[V] = None ... # pragma: no cover
) -> Optional[V]: ... # pragma: no cover
class OpenAIRequestResponseResolver: class OpenAIRequestResponseResolver:
def __call__( def __call__(


@ -10,7 +10,7 @@ from litellm.types.llms.openai import AllMessageValues
if TYPE_CHECKING: if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span from opentelemetry.trace import Span as _Span
Span = _Span Span = Union[_Span, Any]
else: else:
Span = Any Span = Any
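This Span alias change is repeated across several modules: under TYPE_CHECKING the alias becomes a Union of the real opentelemetry type and Any, so annotations written against Span also accept the Any fallback used when opentelemetry cannot be imported. A minimal, self-contained sketch of the pattern (assuming opentelemetry may or may not be installed):

from typing import TYPE_CHECKING, Any, Union

if TYPE_CHECKING:
    # Imported only for type checking; opentelemetry may be absent at runtime.
    from opentelemetry.trace import Span as _Span

    Span = Union[_Span, Any]
else:
    # At runtime the alias degrades to Any, so the dependency is never required.
    Span = Any


def record_span(span: Span) -> None:
    # Callers can annotate with Span whether or not opentelemetry is installed.
    print(type(span).__name__)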


@ -11,7 +11,9 @@ except (ImportError, AttributeError):
# Old way to access resources, which setuptools deprecated some time ago # Old way to access resources, which setuptools deprecated some time ago
import pkg_resources # type: ignore import pkg_resources # type: ignore
filename = pkg_resources.resource_filename(__name__, "litellm_core_utils/tokenizers") filename = pkg_resources.resource_filename(
__name__, "litellm_core_utils/tokenizers"
)
os.environ["TIKTOKEN_CACHE_DIR"] = os.getenv( os.environ["TIKTOKEN_CACHE_DIR"] = os.getenv(
"CUSTOM_TIKTOKEN_CACHE_DIR", filename "CUSTOM_TIKTOKEN_CACHE_DIR", filename


@ -239,9 +239,9 @@ class Logging(LiteLLMLoggingBaseClass):
self.litellm_trace_id = litellm_trace_id self.litellm_trace_id = litellm_trace_id
self.function_id = function_id self.function_id = function_id
self.streaming_chunks: List[Any] = [] # for generating complete stream response self.streaming_chunks: List[Any] = [] # for generating complete stream response
self.sync_streaming_chunks: List[Any] = ( self.sync_streaming_chunks: List[
[] Any
) # for generating complete stream response ] = [] # for generating complete stream response
self.log_raw_request_response = log_raw_request_response self.log_raw_request_response = log_raw_request_response
# Initialize dynamic callbacks # Initialize dynamic callbacks
@ -452,11 +452,13 @@ class Logging(LiteLLMLoggingBaseClass):
prompt_id: str, prompt_id: str,
prompt_variables: Optional[dict], prompt_variables: Optional[dict],
) -> Tuple[str, List[AllMessageValues], dict]: ) -> Tuple[str, List[AllMessageValues], dict]:
custom_logger = self.get_custom_logger_for_prompt_management(model) custom_logger = self.get_custom_logger_for_prompt_management(model)
if custom_logger: if custom_logger:
model, messages, non_default_params = ( (
custom_logger.get_chat_completion_prompt( model,
messages,
non_default_params,
) = custom_logger.get_chat_completion_prompt(
model=model, model=model,
messages=messages, messages=messages,
non_default_params=non_default_params, non_default_params=non_default_params,
@ -464,7 +466,6 @@ class Logging(LiteLLMLoggingBaseClass):
prompt_variables=prompt_variables, prompt_variables=prompt_variables,
dynamic_callback_params=self.standard_callback_dynamic_params, dynamic_callback_params=self.standard_callback_dynamic_params,
) )
)
self.messages = messages self.messages = messages
return model, messages, non_default_params return model, messages, non_default_params
@ -541,12 +542,11 @@ class Logging(LiteLLMLoggingBaseClass):
model model
): # if model name was changes pre-call, overwrite the initial model call name with the new one ): # if model name was changes pre-call, overwrite the initial model call name with the new one
self.model_call_details["model"] = model self.model_call_details["model"] = model
self.model_call_details["litellm_params"]["api_base"] = ( self.model_call_details["litellm_params"][
self._get_masked_api_base(additional_args.get("api_base", "")) "api_base"
) ] = self._get_masked_api_base(additional_args.get("api_base", ""))
def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR0915 def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR0915
# Log the exact input to the LLM API # Log the exact input to the LLM API
litellm.error_logs["PRE_CALL"] = locals() litellm.error_logs["PRE_CALL"] = locals()
try: try:
@ -568,19 +568,16 @@ class Logging(LiteLLMLoggingBaseClass):
self.log_raw_request_response is True self.log_raw_request_response is True
or log_raw_request_response is True or log_raw_request_response is True
): ):
_litellm_params = self.model_call_details.get("litellm_params", {}) _litellm_params = self.model_call_details.get("litellm_params", {})
_metadata = _litellm_params.get("metadata", {}) or {} _metadata = _litellm_params.get("metadata", {}) or {}
try: try:
# [Non-blocking Extra Debug Information in metadata] # [Non-blocking Extra Debug Information in metadata]
if turn_off_message_logging is True: if turn_off_message_logging is True:
_metadata[
_metadata["raw_request"] = ( "raw_request"
"redacted by litellm. \ ] = "redacted by litellm. \
'litellm.turn_off_message_logging=True'" 'litellm.turn_off_message_logging=True'"
)
else: else:
curl_command = self._get_request_curl_command( curl_command = self._get_request_curl_command(
api_base=additional_args.get("api_base", ""), api_base=additional_args.get("api_base", ""),
headers=additional_args.get("headers", {}), headers=additional_args.get("headers", {}),
@ -590,8 +587,9 @@ class Logging(LiteLLMLoggingBaseClass):
_metadata["raw_request"] = str(curl_command) _metadata["raw_request"] = str(curl_command)
# split up, so it's easier to parse in the UI # split up, so it's easier to parse in the UI
self.model_call_details["raw_request_typed_dict"] = ( self.model_call_details[
RawRequestTypedDict( "raw_request_typed_dict"
] = RawRequestTypedDict(
raw_request_api_base=str( raw_request_api_base=str(
additional_args.get("api_base") or "" additional_args.get("api_base") or ""
), ),
@ -604,20 +602,19 @@ class Logging(LiteLLMLoggingBaseClass):
), ),
error=None, error=None,
) )
)
except Exception as e: except Exception as e:
self.model_call_details["raw_request_typed_dict"] = ( self.model_call_details[
RawRequestTypedDict( "raw_request_typed_dict"
] = RawRequestTypedDict(
error=str(e), error=str(e),
) )
)
traceback.print_exc() traceback.print_exc()
_metadata["raw_request"] = ( _metadata[
"Unable to Log \ "raw_request"
] = "Unable to Log \
raw request: {}".format( raw request: {}".format(
str(e) str(e)
) )
)
if self.logger_fn and callable(self.logger_fn): if self.logger_fn and callable(self.logger_fn):
try: try:
self.logger_fn( self.logger_fn(
@ -941,9 +938,9 @@ class Logging(LiteLLMLoggingBaseClass):
verbose_logger.debug( verbose_logger.debug(
f"response_cost_failure_debug_information: {debug_info}" f"response_cost_failure_debug_information: {debug_info}"
) )
self.model_call_details["response_cost_failure_debug_information"] = ( self.model_call_details[
debug_info "response_cost_failure_debug_information"
) ] = debug_info
return None return None
try: try:
@ -968,9 +965,9 @@ class Logging(LiteLLMLoggingBaseClass):
verbose_logger.debug( verbose_logger.debug(
f"response_cost_failure_debug_information: {debug_info}" f"response_cost_failure_debug_information: {debug_info}"
) )
self.model_call_details["response_cost_failure_debug_information"] = ( self.model_call_details[
debug_info "response_cost_failure_debug_information"
) ] = debug_info
return None return None
@ -995,7 +992,6 @@ class Logging(LiteLLMLoggingBaseClass):
def should_run_callback( def should_run_callback(
self, callback: litellm.CALLBACK_TYPES, litellm_params: dict, event_hook: str self, callback: litellm.CALLBACK_TYPES, litellm_params: dict, event_hook: str
) -> bool: ) -> bool:
if litellm.global_disable_no_log_param: if litellm.global_disable_no_log_param:
return True return True
@ -1027,9 +1023,9 @@ class Logging(LiteLLMLoggingBaseClass):
end_time = datetime.datetime.now() end_time = datetime.datetime.now()
if self.completion_start_time is None: if self.completion_start_time is None:
self.completion_start_time = end_time self.completion_start_time = end_time
self.model_call_details["completion_start_time"] = ( self.model_call_details[
self.completion_start_time "completion_start_time"
) ] = self.completion_start_time
self.model_call_details["log_event_type"] = "successful_api_call" self.model_call_details["log_event_type"] = "successful_api_call"
self.model_call_details["end_time"] = end_time self.model_call_details["end_time"] = end_time
self.model_call_details["cache_hit"] = cache_hit self.model_call_details["cache_hit"] = cache_hit
@ -1083,13 +1079,14 @@ class Logging(LiteLLMLoggingBaseClass):
"response_cost" "response_cost"
] ]
else: else:
self.model_call_details["response_cost"] = ( self.model_call_details[
self._response_cost_calculator(result=result) "response_cost"
) ] = self._response_cost_calculator(result=result)
## STANDARDIZED LOGGING PAYLOAD ## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = ( self.model_call_details[
get_standard_logging_object_payload( "standard_logging_object"
] = get_standard_logging_object_payload(
kwargs=self.model_call_details, kwargs=self.model_call_details,
init_response_obj=result, init_response_obj=result,
start_time=start_time, start_time=start_time,
@ -1098,11 +1095,11 @@ class Logging(LiteLLMLoggingBaseClass):
status="success", status="success",
standard_built_in_tools_params=self.standard_built_in_tools_params, standard_built_in_tools_params=self.standard_built_in_tools_params,
) )
)
elif isinstance(result, dict): # pass-through endpoints elif isinstance(result, dict): # pass-through endpoints
## STANDARDIZED LOGGING PAYLOAD ## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = ( self.model_call_details[
get_standard_logging_object_payload( "standard_logging_object"
] = get_standard_logging_object_payload(
kwargs=self.model_call_details, kwargs=self.model_call_details,
init_response_obj=result, init_response_obj=result,
start_time=start_time, start_time=start_time,
@ -1111,11 +1108,10 @@ class Logging(LiteLLMLoggingBaseClass):
status="success", status="success",
standard_built_in_tools_params=self.standard_built_in_tools_params, standard_built_in_tools_params=self.standard_built_in_tools_params,
) )
)
elif standard_logging_object is not None: elif standard_logging_object is not None:
self.model_call_details["standard_logging_object"] = ( self.model_call_details[
standard_logging_object "standard_logging_object"
) ] = standard_logging_object
else: # streaming chunks + image gen. else: # streaming chunks + image gen.
self.model_call_details["response_cost"] = None self.model_call_details["response_cost"] = None
@ -1154,7 +1150,6 @@ class Logging(LiteLLMLoggingBaseClass):
standard_logging_object=kwargs.get("standard_logging_object", None), standard_logging_object=kwargs.get("standard_logging_object", None),
) )
try: try:
## BUILD COMPLETE STREAMED RESPONSE ## BUILD COMPLETE STREAMED RESPONSE
complete_streaming_response: Optional[ complete_streaming_response: Optional[
Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse] Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse]
@ -1172,15 +1167,16 @@ class Logging(LiteLLMLoggingBaseClass):
verbose_logger.debug( verbose_logger.debug(
"Logging Details LiteLLM-Success Call streaming complete" "Logging Details LiteLLM-Success Call streaming complete"
) )
self.model_call_details["complete_streaming_response"] = ( self.model_call_details[
complete_streaming_response "complete_streaming_response"
) ] = complete_streaming_response
self.model_call_details["response_cost"] = ( self.model_call_details[
self._response_cost_calculator(result=complete_streaming_response) "response_cost"
) ] = self._response_cost_calculator(result=complete_streaming_response)
## STANDARDIZED LOGGING PAYLOAD ## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = ( self.model_call_details[
get_standard_logging_object_payload( "standard_logging_object"
] = get_standard_logging_object_payload(
kwargs=self.model_call_details, kwargs=self.model_call_details,
init_response_obj=complete_streaming_response, init_response_obj=complete_streaming_response,
start_time=start_time, start_time=start_time,
@ -1189,7 +1185,6 @@ class Logging(LiteLLMLoggingBaseClass):
status="success", status="success",
standard_built_in_tools_params=self.standard_built_in_tools_params, standard_built_in_tools_params=self.standard_built_in_tools_params,
) )
)
callbacks = self.get_combined_callback_list( callbacks = self.get_combined_callback_list(
dynamic_success_callbacks=self.dynamic_success_callbacks, dynamic_success_callbacks=self.dynamic_success_callbacks,
global_callbacks=litellm.success_callback, global_callbacks=litellm.success_callback,
@ -1207,7 +1202,6 @@ class Logging(LiteLLMLoggingBaseClass):
## LOGGING HOOK ## ## LOGGING HOOK ##
for callback in callbacks: for callback in callbacks:
if isinstance(callback, CustomLogger): if isinstance(callback, CustomLogger):
self.model_call_details, result = callback.logging_hook( self.model_call_details, result = callback.logging_hook(
kwargs=self.model_call_details, kwargs=self.model_call_details,
result=result, result=result,
@ -1538,11 +1532,11 @@ class Logging(LiteLLMLoggingBaseClass):
) )
else: else:
if self.stream and complete_streaming_response: if self.stream and complete_streaming_response:
self.model_call_details["complete_response"] = ( self.model_call_details[
self.model_call_details.get( "complete_response"
] = self.model_call_details.get(
"complete_streaming_response", {} "complete_streaming_response", {}
) )
)
result = self.model_call_details["complete_response"] result = self.model_call_details["complete_response"]
openMeterLogger.log_success_event( openMeterLogger.log_success_event(
kwargs=self.model_call_details, kwargs=self.model_call_details,
@ -1581,11 +1575,11 @@ class Logging(LiteLLMLoggingBaseClass):
) )
else: else:
if self.stream and complete_streaming_response: if self.stream and complete_streaming_response:
self.model_call_details["complete_response"] = ( self.model_call_details[
self.model_call_details.get( "complete_response"
] = self.model_call_details.get(
"complete_streaming_response", {} "complete_streaming_response", {}
) )
)
result = self.model_call_details["complete_response"] result = self.model_call_details["complete_response"]
callback.log_success_event( callback.log_success_event(
@ -1659,7 +1653,6 @@ class Logging(LiteLLMLoggingBaseClass):
if self.call_type == CallTypes.aretrieve_batch.value and isinstance( if self.call_type == CallTypes.aretrieve_batch.value and isinstance(
result, LiteLLMBatch result, LiteLLMBatch
): ):
response_cost, batch_usage, batch_models = await _handle_completed_batch( response_cost, batch_usage, batch_models = await _handle_completed_batch(
batch=result, custom_llm_provider=self.custom_llm_provider batch=result, custom_llm_provider=self.custom_llm_provider
) )
@ -1692,9 +1685,9 @@ class Logging(LiteLLMLoggingBaseClass):
if complete_streaming_response is not None: if complete_streaming_response is not None:
print_verbose("Async success callbacks: Got a complete streaming response") print_verbose("Async success callbacks: Got a complete streaming response")
self.model_call_details["async_complete_streaming_response"] = ( self.model_call_details[
complete_streaming_response "async_complete_streaming_response"
) ] = complete_streaming_response
try: try:
if self.model_call_details.get("cache_hit", False) is True: if self.model_call_details.get("cache_hit", False) is True:
self.model_call_details["response_cost"] = 0.0 self.model_call_details["response_cost"] = 0.0
@ -1704,11 +1697,11 @@ class Logging(LiteLLMLoggingBaseClass):
model_call_details=self.model_call_details model_call_details=self.model_call_details
) )
# base_model defaults to None if not set on model_info # base_model defaults to None if not set on model_info
self.model_call_details["response_cost"] = ( self.model_call_details[
self._response_cost_calculator( "response_cost"
] = self._response_cost_calculator(
result=complete_streaming_response result=complete_streaming_response
) )
)
verbose_logger.debug( verbose_logger.debug(
f"Model={self.model}; cost={self.model_call_details['response_cost']}" f"Model={self.model}; cost={self.model_call_details['response_cost']}"
@ -1720,8 +1713,9 @@ class Logging(LiteLLMLoggingBaseClass):
self.model_call_details["response_cost"] = None self.model_call_details["response_cost"] = None
## STANDARDIZED LOGGING PAYLOAD ## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = ( self.model_call_details[
get_standard_logging_object_payload( "standard_logging_object"
] = get_standard_logging_object_payload(
kwargs=self.model_call_details, kwargs=self.model_call_details,
init_response_obj=complete_streaming_response, init_response_obj=complete_streaming_response,
start_time=start_time, start_time=start_time,
@ -1730,7 +1724,6 @@ class Logging(LiteLLMLoggingBaseClass):
status="success", status="success",
standard_built_in_tools_params=self.standard_built_in_tools_params, standard_built_in_tools_params=self.standard_built_in_tools_params,
) )
)
callbacks = self.get_combined_callback_list( callbacks = self.get_combined_callback_list(
dynamic_success_callbacks=self.dynamic_async_success_callbacks, dynamic_success_callbacks=self.dynamic_async_success_callbacks,
global_callbacks=litellm._async_success_callback, global_callbacks=litellm._async_success_callback,
@ -1935,8 +1928,9 @@ class Logging(LiteLLMLoggingBaseClass):
## STANDARDIZED LOGGING PAYLOAD ## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = ( self.model_call_details[
get_standard_logging_object_payload( "standard_logging_object"
] = get_standard_logging_object_payload(
kwargs=self.model_call_details, kwargs=self.model_call_details,
init_response_obj={}, init_response_obj={},
start_time=start_time, start_time=start_time,
@ -1947,7 +1941,6 @@ class Logging(LiteLLMLoggingBaseClass):
original_exception=exception, original_exception=exception,
standard_built_in_tools_params=self.standard_built_in_tools_params, standard_built_in_tools_params=self.standard_built_in_tools_params,
) )
)
return start_time, end_time return start_time, end_time
async def special_failure_handlers(self, exception: Exception): async def special_failure_handlers(self, exception: Exception):
@ -2084,7 +2077,6 @@ class Logging(LiteLLMLoggingBaseClass):
) )
is not True is not True
): # custom logger class ): # custom logger class
callback.log_failure_event( callback.log_failure_event(
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
@ -2713,9 +2705,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
endpoint=arize_config.endpoint, endpoint=arize_config.endpoint,
) )
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = ( os.environ[
f"space_key={arize_config.space_key},api_key={arize_config.api_key}" "OTEL_EXPORTER_OTLP_TRACES_HEADERS"
) ] = f"space_key={arize_config.space_key},api_key={arize_config.api_key}"
for callback in _in_memory_loggers: for callback in _in_memory_loggers:
if ( if (
isinstance(callback, ArizeLogger) isinstance(callback, ArizeLogger)
@ -2739,9 +2731,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
# auth can be disabled on local deployments of arize phoenix # auth can be disabled on local deployments of arize phoenix
if arize_phoenix_config.otlp_auth_headers is not None: if arize_phoenix_config.otlp_auth_headers is not None:
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = ( os.environ[
arize_phoenix_config.otlp_auth_headers "OTEL_EXPORTER_OTLP_TRACES_HEADERS"
) ] = arize_phoenix_config.otlp_auth_headers
for callback in _in_memory_loggers: for callback in _in_memory_loggers:
if ( if (
@ -2832,9 +2824,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
exporter="otlp_http", exporter="otlp_http",
endpoint="https://langtrace.ai/api/trace", endpoint="https://langtrace.ai/api/trace",
) )
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = ( os.environ[
f"api_key={os.getenv('LANGTRACE_API_KEY')}" "OTEL_EXPORTER_OTLP_TRACES_HEADERS"
) ] = f"api_key={os.getenv('LANGTRACE_API_KEY')}"
for callback in _in_memory_loggers: for callback in _in_memory_loggers:
if ( if (
isinstance(callback, OpenTelemetry) isinstance(callback, OpenTelemetry)
@ -3223,7 +3215,6 @@ class StandardLoggingPayloadSetup:
custom_llm_provider: Optional[str], custom_llm_provider: Optional[str],
init_response_obj: Union[Any, BaseModel, dict], init_response_obj: Union[Any, BaseModel, dict],
) -> StandardLoggingModelInformation: ) -> StandardLoggingModelInformation:
model_cost_name = _select_model_name_for_cost_calc( model_cost_name = _select_model_name_for_cost_calc(
model=None, model=None,
completion_response=init_response_obj, # type: ignore completion_response=init_response_obj, # type: ignore
@ -3286,7 +3277,6 @@ class StandardLoggingPayloadSetup:
def get_additional_headers( def get_additional_headers(
additiona_headers: Optional[dict], additiona_headers: Optional[dict],
) -> Optional[StandardLoggingAdditionalHeaders]: ) -> Optional[StandardLoggingAdditionalHeaders]:
if additiona_headers is None: if additiona_headers is None:
return None return None
@ -3322,11 +3312,11 @@ class StandardLoggingPayloadSetup:
for key in StandardLoggingHiddenParams.__annotations__.keys(): for key in StandardLoggingHiddenParams.__annotations__.keys():
if key in hidden_params: if key in hidden_params:
if key == "additional_headers": if key == "additional_headers":
clean_hidden_params["additional_headers"] = ( clean_hidden_params[
StandardLoggingPayloadSetup.get_additional_headers( "additional_headers"
] = StandardLoggingPayloadSetup.get_additional_headers(
hidden_params[key] hidden_params[key]
) )
)
else: else:
clean_hidden_params[key] = hidden_params[key] # type: ignore clean_hidden_params[key] = hidden_params[key] # type: ignore
return clean_hidden_params return clean_hidden_params
@ -3463,13 +3453,15 @@ def get_standard_logging_object_payload(
) )
# cleanup timestamps # cleanup timestamps
start_time_float, end_time_float, completion_start_time_float = ( (
StandardLoggingPayloadSetup.cleanup_timestamps( start_time_float,
end_time_float,
completion_start_time_float,
) = StandardLoggingPayloadSetup.cleanup_timestamps(
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
completion_start_time=completion_start_time, completion_start_time=completion_start_time,
) )
)
response_time = StandardLoggingPayloadSetup.get_response_time( response_time = StandardLoggingPayloadSetup.get_response_time(
start_time_float=start_time_float, start_time_float=start_time_float,
end_time_float=end_time_float, end_time_float=end_time_float,
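The cleanup_timestamps hunk above shows the other rewrap shape used throughout this commit for multi-target unpacking: the parentheses move from the right-hand call to the tuple of assignment targets. A small runnable sketch with a hypothetical helper (not code from this commit):

from typing import Tuple


def split_timestamps(raw: str) -> Tuple[float, float, float]:
    # Hypothetical helper, only here to make the sketch self-contained.
    start, end, completion = (float(part) for part in raw.split(","))
    return start, end, completion


# Old wrapping: parentheses around the right-hand side call.
start_time_float, end_time_float, completion_start_time_float = (
    split_timestamps("1.0,2.0,1.5")
)

# New wrapping: the target tuple is parenthesized and split, one name per line.
(
    start_time_float,
    end_time_float,
    completion_start_time_float,
) = split_timestamps("1.0,2.0,1.5")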
@ -3495,7 +3487,6 @@ def get_standard_logging_object_payload(
saved_cache_cost: float = 0.0 saved_cache_cost: float = 0.0
if cache_hit is True: if cache_hit is True:
id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id
saved_cache_cost = ( saved_cache_cost = (
logging_obj._response_cost_calculator( logging_obj._response_cost_calculator(
@ -3658,9 +3649,9 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
): ):
for k, v in metadata["user_api_key_metadata"].items(): for k, v in metadata["user_api_key_metadata"].items():
if k == "logging": # prevent logging user logging keys if k == "logging": # prevent logging user logging keys
cleaned_user_api_key_metadata[k] = ( cleaned_user_api_key_metadata[
"scrubbed_by_litellm_for_sensitive_keys" k
) ] = "scrubbed_by_litellm_for_sensitive_keys"
else: else:
cleaned_user_api_key_metadata[k] = v cleaned_user_api_key_metadata[k] = v


@ -258,14 +258,12 @@ def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[s
class LiteLLMResponseObjectHandler: class LiteLLMResponseObjectHandler:
@staticmethod @staticmethod
def convert_to_image_response( def convert_to_image_response(
response_object: dict, response_object: dict,
model_response_object: Optional[ImageResponse] = None, model_response_object: Optional[ImageResponse] = None,
hidden_params: Optional[dict] = None, hidden_params: Optional[dict] = None,
) -> ImageResponse: ) -> ImageResponse:
response_object.update({"hidden_params": hidden_params}) response_object.update({"hidden_params": hidden_params})
if model_response_object is None: if model_response_object is None:
@ -481,9 +479,9 @@ def convert_to_model_response_object( # noqa: PLR0915
provider_specific_fields["thinking_blocks"] = thinking_blocks provider_specific_fields["thinking_blocks"] = thinking_blocks
if reasoning_content: if reasoning_content:
provider_specific_fields["reasoning_content"] = ( provider_specific_fields[
reasoning_content "reasoning_content"
) ] = reasoning_content
message = Message( message = Message(
content=content, content=content,


@ -17,7 +17,6 @@ from litellm.types.rerank import RerankRequest
class ModelParamHelper: class ModelParamHelper:
@staticmethod @staticmethod
def get_standard_logging_model_parameters( def get_standard_logging_model_parameters(
model_parameters: dict, model_parameters: dict,


@ -257,7 +257,6 @@ def _insert_assistant_continue_message(
and message.get("role") == "user" # Current is user and message.get("role") == "user" # Current is user
and messages[i + 1].get("role") == "user" and messages[i + 1].get("role") == "user"
): # Next is user ): # Next is user
# Insert assistant message # Insert assistant message
continue_message = ( continue_message = (
assistant_continue_message or DEFAULT_ASSISTANT_CONTINUE_MESSAGE assistant_continue_message or DEFAULT_ASSISTANT_CONTINUE_MESSAGE


@ -1042,11 +1042,11 @@ def convert_to_gemini_tool_call_invoke(
if tool_calls is not None: if tool_calls is not None:
for tool in tool_calls: for tool in tool_calls:
if "function" in tool: if "function" in tool:
gemini_function_call: Optional[VertexFunctionCall] = ( gemini_function_call: Optional[
_gemini_tool_call_invoke_helper( VertexFunctionCall
] = _gemini_tool_call_invoke_helper(
function_call_params=tool["function"] function_call_params=tool["function"]
) )
)
if gemini_function_call is not None: if gemini_function_call is not None:
_parts_list.append( _parts_list.append(
VertexPartType(function_call=gemini_function_call) VertexPartType(function_call=gemini_function_call)
@ -1432,9 +1432,9 @@ def anthropic_messages_pt( # noqa: PLR0915
) )
if "cache_control" in _content_element: if "cache_control" in _content_element:
_anthropic_content_element["cache_control"] = ( _anthropic_content_element[
_content_element["cache_control"] "cache_control"
) ] = _content_element["cache_control"]
user_content.append(_anthropic_content_element) user_content.append(_anthropic_content_element)
elif m.get("type", "") == "text": elif m.get("type", "") == "text":
m = cast(ChatCompletionTextObject, m) m = cast(ChatCompletionTextObject, m)
@ -1466,9 +1466,9 @@ def anthropic_messages_pt( # noqa: PLR0915
) )
if "cache_control" in _content_element: if "cache_control" in _content_element:
_anthropic_content_text_element["cache_control"] = ( _anthropic_content_text_element[
_content_element["cache_control"] "cache_control"
) ] = _content_element["cache_control"]
user_content.append(_anthropic_content_text_element) user_content.append(_anthropic_content_text_element)
@ -1533,7 +1533,6 @@ def anthropic_messages_pt( # noqa: PLR0915
"content" "content"
] # don't pass empty text blocks. anthropic api raises errors. ] # don't pass empty text blocks. anthropic api raises errors.
): ):
_anthropic_text_content_element = AnthropicMessagesTextParam( _anthropic_text_content_element = AnthropicMessagesTextParam(
type="text", type="text",
text=assistant_content_block["content"], text=assistant_content_block["content"],
@ -1569,7 +1568,6 @@ def anthropic_messages_pt( # noqa: PLR0915
msg_i += 1 msg_i += 1
if assistant_content: if assistant_content:
new_messages.append({"role": "assistant", "content": assistant_content}) new_messages.append({"role": "assistant", "content": assistant_content})
if msg_i == init_msg_i: # prevent infinite loops if msg_i == init_msg_i: # prevent infinite loops
@ -2245,7 +2243,6 @@ class BedrockImageProcessor:
@staticmethod @staticmethod
async def get_image_details_async(image_url) -> Tuple[str, str]: async def get_image_details_async(image_url) -> Tuple[str, str]:
try: try:
client = get_async_httpx_client( client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.PromptFactory, llm_provider=httpxSpecialProvider.PromptFactory,
params={"concurrent_limit": 1}, params={"concurrent_limit": 1},
@ -2612,7 +2609,6 @@ def get_user_message_block_or_continue_message(
for item in modified_content_block: for item in modified_content_block:
# Check if the list is empty # Check if the list is empty
if item["type"] == "text": if item["type"] == "text":
if not item["text"].strip(): if not item["text"].strip():
# Replace empty text with continue message # Replace empty text with continue message
_user_continue_message = ChatCompletionUserMessage( _user_continue_message = ChatCompletionUserMessage(
@ -3207,7 +3203,6 @@ def _bedrock_converse_messages_pt( # noqa: PLR0915
assistant_content: List[BedrockContentBlock] = [] assistant_content: List[BedrockContentBlock] = []
## MERGE CONSECUTIVE ASSISTANT CONTENT ## ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant": while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
assistant_message_block = get_assistant_message_block_or_continue_message( assistant_message_block = get_assistant_message_block_or_continue_message(
message=messages[msg_i], message=messages[msg_i],
assistant_continue_message=assistant_continue_message, assistant_continue_message=assistant_continue_message,
@ -3410,7 +3405,6 @@ def response_schema_prompt(model: str, response_schema: dict) -> str:
{"role": "user", "content": "{}".format(response_schema)} {"role": "user", "content": "{}".format(response_schema)}
] ]
if f"{model}/response_schema_prompt" in litellm.custom_prompt_dict: if f"{model}/response_schema_prompt" in litellm.custom_prompt_dict:
custom_prompt_details = litellm.custom_prompt_dict[ custom_prompt_details = litellm.custom_prompt_dict[
f"{model}/response_schema_prompt" f"{model}/response_schema_prompt"
] # allow user to define custom response schema prompt by model ] # allow user to define custom response schema prompt by model


@ -122,7 +122,6 @@ class RealTimeStreaming:
pass pass
async def bidirectional_forward(self): async def bidirectional_forward(self):
forward_task = asyncio.create_task(self.backend_to_client_send_messages()) forward_task = asyncio.create_task(self.backend_to_client_send_messages())
try: try:
await self.client_ack_messages() await self.client_ack_messages()


@ -135,9 +135,9 @@ def _get_turn_off_message_logging_from_dynamic_params(
handles boolean and string values of `turn_off_message_logging` handles boolean and string values of `turn_off_message_logging`
""" """
standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = ( standard_callback_dynamic_params: Optional[
model_call_details.get("standard_callback_dynamic_params", None) StandardCallbackDynamicParams
) ] = model_call_details.get("standard_callback_dynamic_params", None)
if standard_callback_dynamic_params: if standard_callback_dynamic_params:
_turn_off_message_logging = standard_callback_dynamic_params.get( _turn_off_message_logging = standard_callback_dynamic_params.get(
"turn_off_message_logging" "turn_off_message_logging"


@ -1,4 +1,5 @@
from typing import Any, Dict, Optional, Set from typing import Any, Dict, Optional, Set
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
@ -40,7 +41,10 @@ class SensitiveDataMasker:
return result return result
def mask_dict( def mask_dict(
self, data: Dict[str, Any], depth: int = 0, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH self,
data: Dict[str, Any],
depth: int = 0,
max_depth: int = DEFAULT_MAX_RECURSE_DEPTH,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
if depth >= max_depth: if depth >= max_depth:
return data return data


@ -104,7 +104,6 @@ class ChunkProcessor:
def get_combined_tool_content( def get_combined_tool_content(
self, tool_call_chunks: List[Dict[str, Any]] self, tool_call_chunks: List[Dict[str, Any]]
) -> List[ChatCompletionMessageToolCall]: ) -> List[ChatCompletionMessageToolCall]:
argument_list: List[str] = [] argument_list: List[str] = []
delta = tool_call_chunks[0]["choices"][0]["delta"] delta = tool_call_chunks[0]["choices"][0]["delta"]
id = None id = None


@ -84,9 +84,9 @@ class CustomStreamWrapper:
self.system_fingerprint: Optional[str] = None self.system_fingerprint: Optional[str] = None
self.received_finish_reason: Optional[str] = None self.received_finish_reason: Optional[str] = None
self.intermittent_finish_reason: Optional[str] = ( self.intermittent_finish_reason: Optional[
None # finish reasons that show up mid-stream str
) ] = None # finish reasons that show up mid-stream
self.special_tokens = [ self.special_tokens = [
"<|assistant|>", "<|assistant|>",
"<|system|>", "<|system|>",
@ -814,7 +814,6 @@ class CustomStreamWrapper:
model_response: ModelResponseStream, model_response: ModelResponseStream,
response_obj: Dict[str, Any], response_obj: Dict[str, Any],
): ):
print_verbose( print_verbose(
f"completion_obj: {completion_obj}, model_response.choices[0]: {model_response.choices[0]}, response_obj: {response_obj}" f"completion_obj: {completion_obj}, model_response.choices[0]: {model_response.choices[0]}, response_obj: {response_obj}"
) )
@ -1008,7 +1007,6 @@ class CustomStreamWrapper:
self.custom_llm_provider self.custom_llm_provider
and self.custom_llm_provider in litellm._custom_providers and self.custom_llm_provider in litellm._custom_providers
): ):
if self.received_finish_reason is not None: if self.received_finish_reason is not None:
if "provider_specific_fields" not in chunk: if "provider_specific_fields" not in chunk:
raise StopIteration raise StopIteration
@ -1379,9 +1377,9 @@ class CustomStreamWrapper:
_json_delta = delta.model_dump() _json_delta = delta.model_dump()
print_verbose(f"_json_delta: {_json_delta}") print_verbose(f"_json_delta: {_json_delta}")
if "role" not in _json_delta or _json_delta["role"] is None: if "role" not in _json_delta or _json_delta["role"] is None:
_json_delta["role"] = ( _json_delta[
"assistant" # mistral's api returns role as None "role"
) ] = "assistant" # mistral's api returns role as None
if "tool_calls" in _json_delta and isinstance( if "tool_calls" in _json_delta and isinstance(
_json_delta["tool_calls"], list _json_delta["tool_calls"], list
): ):
@ -1758,9 +1756,9 @@ class CustomStreamWrapper:
chunk = next(self.completion_stream) chunk = next(self.completion_stream)
if chunk is not None and chunk != b"": if chunk is not None and chunk != b"":
print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}") print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}")
processed_chunk: Optional[ModelResponseStream] = ( processed_chunk: Optional[
self.chunk_creator(chunk=chunk) ModelResponseStream
) ] = self.chunk_creator(chunk=chunk)
print_verbose( print_verbose(
f"PROCESSED CHUNK POST CHUNK CREATOR: {processed_chunk}" f"PROCESSED CHUNK POST CHUNK CREATOR: {processed_chunk}"
) )


@ -290,7 +290,6 @@ class AnthropicChatCompletion(BaseLLM):
headers={}, headers={},
client=None, client=None,
): ):
optional_params = copy.deepcopy(optional_params) optional_params = copy.deepcopy(optional_params)
stream = optional_params.pop("stream", None) stream = optional_params.pop("stream", None)
json_mode: bool = optional_params.pop("json_mode", False) json_mode: bool = optional_params.pop("json_mode", False)
@ -491,7 +490,6 @@ class ModelResponseIterator:
def _handle_usage( def _handle_usage(
self, anthropic_usage_chunk: Union[dict, UsageDelta] self, anthropic_usage_chunk: Union[dict, UsageDelta]
) -> AnthropicChatCompletionUsageBlock: ) -> AnthropicChatCompletionUsageBlock:
usage_block = AnthropicChatCompletionUsageBlock( usage_block = AnthropicChatCompletionUsageBlock(
prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0), prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0),
completion_tokens=anthropic_usage_chunk.get("output_tokens", 0), completion_tokens=anthropic_usage_chunk.get("output_tokens", 0),
@ -515,7 +513,9 @@ class ModelResponseIterator:
return usage_block return usage_block
def _content_block_delta_helper(self, chunk: dict) -> Tuple[ def _content_block_delta_helper(
self, chunk: dict
) -> Tuple[
str, str,
Optional[ChatCompletionToolCallChunk], Optional[ChatCompletionToolCallChunk],
List[ChatCompletionThinkingBlock], List[ChatCompletionThinkingBlock],
@ -592,9 +592,12 @@ class ModelResponseIterator:
Anthropic content chunk Anthropic content chunk
chunk = {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Hello'}} chunk = {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Hello'}}
""" """
text, tool_use, thinking_blocks, provider_specific_fields = ( (
self._content_block_delta_helper(chunk=chunk) text,
) tool_use,
thinking_blocks,
provider_specific_fields,
) = self._content_block_delta_helper(chunk=chunk)
if thinking_blocks: if thinking_blocks:
reasoning_content = self._handle_reasoning_content( reasoning_content = self._handle_reasoning_content(
thinking_blocks=thinking_blocks thinking_blocks=thinking_blocks
@ -620,7 +623,6 @@ class ModelResponseIterator:
"index": self.tool_index, "index": self.tool_index,
} }
elif type_chunk == "content_block_stop": elif type_chunk == "content_block_stop":
ContentBlockStop(**chunk) # type: ignore ContentBlockStop(**chunk) # type: ignore
# check if tool call content block # check if tool call content block
is_empty = self.check_empty_tool_call_args() is_empty = self.check_empty_tool_call_args()


@ -49,9 +49,9 @@ class AnthropicConfig(BaseConfig):
to pass metadata to anthropic, it's {"user_id": "any-relevant-information"} to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
""" """
max_tokens: Optional[int] = ( max_tokens: Optional[
4096 # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default) int
) ] = 4096 # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default)
stop_sequences: Optional[list] = None stop_sequences: Optional[list] = None
temperature: Optional[int] = None temperature: Optional[int] = None
top_p: Optional[int] = None top_p: Optional[int] = None
@ -104,7 +104,6 @@ class AnthropicConfig(BaseConfig):
def get_json_schema_from_pydantic_object( def get_json_schema_from_pydantic_object(
self, response_format: Union[Any, Dict, None] self, response_format: Union[Any, Dict, None]
) -> Optional[dict]: ) -> Optional[dict]:
return type_to_response_format_param( return type_to_response_format_param(
response_format, ref_template="/$defs/{model}" response_format, ref_template="/$defs/{model}"
) # Relevant issue: https://github.com/BerriAI/litellm/issues/7755 ) # Relevant issue: https://github.com/BerriAI/litellm/issues/7755
@ -125,7 +124,6 @@ class AnthropicConfig(BaseConfig):
is_vertex_request: bool = False, is_vertex_request: bool = False,
user_anthropic_beta_headers: Optional[List[str]] = None, user_anthropic_beta_headers: Optional[List[str]] = None,
) -> dict: ) -> dict:
betas = set() betas = set()
if prompt_caching_set: if prompt_caching_set:
betas.add("prompt-caching-2024-07-31") betas.add("prompt-caching-2024-07-31")
@ -300,7 +298,6 @@ class AnthropicConfig(BaseConfig):
model: str, model: str,
drop_params: bool, drop_params: bool,
) -> dict: ) -> dict:
is_thinking_enabled = self.is_thinking_enabled( is_thinking_enabled = self.is_thinking_enabled(
non_default_params=non_default_params non_default_params=non_default_params
) )
@ -321,12 +318,12 @@ class AnthropicConfig(BaseConfig):
optional_params=optional_params, tools=tool_value optional_params=optional_params, tools=tool_value
) )
if param == "tool_choice" or param == "parallel_tool_calls": if param == "tool_choice" or param == "parallel_tool_calls":
_tool_choice: Optional[AnthropicMessagesToolChoice] = ( _tool_choice: Optional[
self._map_tool_choice( AnthropicMessagesToolChoice
] = self._map_tool_choice(
tool_choice=non_default_params.get("tool_choice"), tool_choice=non_default_params.get("tool_choice"),
parallel_tool_use=non_default_params.get("parallel_tool_calls"), parallel_tool_use=non_default_params.get("parallel_tool_calls"),
) )
)
if _tool_choice is not None: if _tool_choice is not None:
optional_params["tool_choice"] = _tool_choice optional_params["tool_choice"] = _tool_choice
@ -341,7 +338,6 @@ class AnthropicConfig(BaseConfig):
if param == "top_p": if param == "top_p":
optional_params["top_p"] = value optional_params["top_p"] = value
if param == "response_format" and isinstance(value, dict): if param == "response_format" and isinstance(value, dict):
ignore_response_format_types = ["text"] ignore_response_format_types = ["text"]
if value["type"] in ignore_response_format_types: # value is a no-op if value["type"] in ignore_response_format_types: # value is a no-op
continue continue
@ -470,9 +466,9 @@ class AnthropicConfig(BaseConfig):
text=system_message_block["content"], text=system_message_block["content"],
) )
if "cache_control" in system_message_block: if "cache_control" in system_message_block:
anthropic_system_message_content["cache_control"] = ( anthropic_system_message_content[
system_message_block["cache_control"] "cache_control"
) ] = system_message_block["cache_control"]
anthropic_system_message_list.append( anthropic_system_message_list.append(
anthropic_system_message_content anthropic_system_message_content
) )
@ -486,9 +482,9 @@ class AnthropicConfig(BaseConfig):
) )
) )
if "cache_control" in _content: if "cache_control" in _content:
anthropic_system_message_content["cache_control"] = ( anthropic_system_message_content[
_content["cache_control"] "cache_control"
) ] = _content["cache_control"]
anthropic_system_message_list.append( anthropic_system_message_list.append(
anthropic_system_message_content anthropic_system_message_content
@ -597,7 +593,9 @@ class AnthropicConfig(BaseConfig):
) )
return _message return _message
def extract_response_content(self, completion_response: dict) -> Tuple[ def extract_response_content(
self, completion_response: dict
) -> Tuple[
str, str,
Optional[List[Any]], Optional[List[Any]],
Optional[List[ChatCompletionThinkingBlock]], Optional[List[ChatCompletionThinkingBlock]],
@ -693,9 +691,13 @@ class AnthropicConfig(BaseConfig):
reasoning_content: Optional[str] = None reasoning_content: Optional[str] = None
tool_calls: List[ChatCompletionToolCallChunk] = [] tool_calls: List[ChatCompletionToolCallChunk] = []
text_content, citations, thinking_blocks, reasoning_content, tool_calls = ( (
self.extract_response_content(completion_response=completion_response) text_content,
) citations,
thinking_blocks,
reasoning_content,
tool_calls,
) = self.extract_response_content(completion_response=completion_response)
_message = litellm.Message( _message = litellm.Message(
tool_calls=tool_calls, tool_calls=tool_calls,


@ -54,9 +54,9 @@ class AnthropicTextConfig(BaseConfig):
to pass metadata to anthropic, it's {"user_id": "any-relevant-information"} to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
""" """
max_tokens_to_sample: Optional[int] = ( max_tokens_to_sample: Optional[
litellm.max_tokens int
) # anthropic requires a default ] = litellm.max_tokens # anthropic requires a default
stop_sequences: Optional[list] = None stop_sequences: Optional[list] = None
temperature: Optional[int] = None temperature: Optional[int] = None
top_p: Optional[int] = None top_p: Optional[int] = None


@ -25,7 +25,6 @@ from litellm.utils import ProviderConfigManager, client
class AnthropicMessagesHandler: class AnthropicMessagesHandler:
@staticmethod @staticmethod
async def _handle_anthropic_streaming( async def _handle_anthropic_streaming(
response: httpx.Response, response: httpx.Response,
@ -74,20 +73,23 @@ async def anthropic_messages(
""" """
# Use provided client or create a new one # Use provided client or create a new one
optional_params = GenericLiteLLMParams(**kwargs) optional_params = GenericLiteLLMParams(**kwargs)
model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = ( (
litellm.get_llm_provider( model,
_custom_llm_provider,
dynamic_api_key,
dynamic_api_base,
) = litellm.get_llm_provider(
model=model, model=model,
custom_llm_provider=custom_llm_provider, custom_llm_provider=custom_llm_provider,
api_base=optional_params.api_base, api_base=optional_params.api_base,
api_key=optional_params.api_key, api_key=optional_params.api_key,
) )
) anthropic_messages_provider_config: Optional[
anthropic_messages_provider_config: Optional[BaseAnthropicMessagesConfig] = ( BaseAnthropicMessagesConfig
ProviderConfigManager.get_provider_anthropic_messages_config( ] = ProviderConfigManager.get_provider_anthropic_messages_config(
model=model, model=model,
provider=litellm.LlmProviders(_custom_llm_provider), provider=litellm.LlmProviders(_custom_llm_provider),
) )
)
if anthropic_messages_provider_config is None: if anthropic_messages_provider_config is None:
raise ValueError( raise ValueError(
f"Anthropic messages provider config not found for model: {model}" f"Anthropic messages provider config not found for model: {model}"


@ -654,7 +654,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
) -> EmbeddingResponse: ) -> EmbeddingResponse:
response = None response = None
try: try:
openai_aclient = self.get_azure_openai_client( openai_aclient = self.get_azure_openai_client(
api_version=api_version, api_version=api_version,
api_base=api_base, api_base=api_base,
@ -835,7 +834,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
"2023-10-01-preview", "2023-10-01-preview",
] ]
): # CREATE + POLL for azure dall-e-2 calls ): # CREATE + POLL for azure dall-e-2 calls
api_base = modify_url( api_base = modify_url(
original_url=api_base, new_path="/openai/images/generations:submit" original_url=api_base, new_path="/openai/images/generations:submit"
) )
@ -867,7 +865,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
) )
while response.json()["status"] not in ["succeeded", "failed"]: while response.json()["status"] not in ["succeeded", "failed"]:
if time.time() - start_time > timeout_secs: if time.time() - start_time > timeout_secs:
raise AzureOpenAIError( raise AzureOpenAIError(
status_code=408, message="Operation polling timed out." status_code=408, message="Operation polling timed out."
) )
@ -935,7 +932,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
"2023-10-01-preview", "2023-10-01-preview",
] ]
): # CREATE + POLL for azure dall-e-2 calls ): # CREATE + POLL for azure dall-e-2 calls
api_base = modify_url( api_base = modify_url(
original_url=api_base, new_path="/openai/images/generations:submit" original_url=api_base, new_path="/openai/images/generations:submit"
) )
@ -1199,7 +1195,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
client=None, client=None,
litellm_params: Optional[dict] = None, litellm_params: Optional[dict] = None,
) -> HttpxBinaryResponseContent: ) -> HttpxBinaryResponseContent:
max_retries = optional_params.pop("max_retries", 2) max_retries = optional_params.pop("max_retries", 2)
if aspeech is not None and aspeech is True: if aspeech is not None and aspeech is True:
@ -1253,7 +1248,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
client=None, client=None,
litellm_params: Optional[dict] = None, litellm_params: Optional[dict] = None,
) -> HttpxBinaryResponseContent: ) -> HttpxBinaryResponseContent:
azure_client: AsyncAzureOpenAI = self.get_azure_openai_client( azure_client: AsyncAzureOpenAI = self.get_azure_openai_client(
api_base=api_base, api_base=api_base,
api_version=api_version, api_version=api_version,


@@ -50,8 +50,9 @@ class AzureBatchesAPI(BaseAzureLLM):
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
-azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-self.get_azure_openai_client(
+azure_client: Optional[
+Union[AzureOpenAI, AsyncAzureOpenAI]
+] = self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
@@ -59,7 +60,6 @@ class AzureBatchesAPI(BaseAzureLLM):
_is_async=_is_async,
litellm_params=litellm_params or {},
)
-)
if azure_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
@@ -96,8 +96,9 @@ class AzureBatchesAPI(BaseAzureLLM):
client: Optional[AzureOpenAI] = None,
litellm_params: Optional[dict] = None,
):
-azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-self.get_azure_openai_client(
+azure_client: Optional[
+Union[AzureOpenAI, AsyncAzureOpenAI]
+] = self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
@@ -105,7 +106,6 @@ class AzureBatchesAPI(BaseAzureLLM):
_is_async=_is_async,
litellm_params=litellm_params or {},
)
-)
if azure_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
@@ -144,8 +144,9 @@ class AzureBatchesAPI(BaseAzureLLM):
client: Optional[AzureOpenAI] = None,
litellm_params: Optional[dict] = None,
):
-azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-self.get_azure_openai_client(
+azure_client: Optional[
+Union[AzureOpenAI, AsyncAzureOpenAI]
+] = self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
@@ -153,7 +154,6 @@ class AzureBatchesAPI(BaseAzureLLM):
_is_async=_is_async,
litellm_params=litellm_params or {},
)
-)
if azure_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
@@ -183,8 +183,9 @@ class AzureBatchesAPI(BaseAzureLLM):
client: Optional[AzureOpenAI] = None,
litellm_params: Optional[dict] = None,
):
-azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-self.get_azure_openai_client(
+azure_client: Optional[
+Union[AzureOpenAI, AsyncAzureOpenAI]
+] = self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
@@ -192,7 +193,6 @@ class AzureBatchesAPI(BaseAzureLLM):
_is_async=_is_async,
litellm_params=litellm_params or {},
)
-)
if azure_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."


@@ -306,7 +306,6 @@ class BaseAzureLLM(BaseOpenAILLM):
api_version: Optional[str],
is_async: bool,
) -> dict:
-
azure_ad_token_provider: Optional[Callable[[], str]] = None
# If we have api_key, then we have higher priority
azure_ad_token = litellm_params.get("azure_ad_token")


@@ -46,9 +46,9 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
-
-openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-self.get_azure_openai_client(
+openai_client: Optional[
+Union[AzureOpenAI, AsyncAzureOpenAI]
+] = self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
@@ -56,7 +56,6 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
client=client,
_is_async=_is_async,
)
-)
if openai_client is None:
raise ValueError(
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
@@ -95,8 +94,9 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
) -> Union[
HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
]:
-openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-self.get_azure_openai_client(
+openai_client: Optional[
+Union[AzureOpenAI, AsyncAzureOpenAI]
+] = self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
@@ -104,7 +104,6 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
client=client,
_is_async=_is_async,
)
-)
if openai_client is None:
raise ValueError(
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
@@ -145,8 +144,9 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
):
-openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-self.get_azure_openai_client(
+openai_client: Optional[
+Union[AzureOpenAI, AsyncAzureOpenAI]
+] = self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
@@ -154,7 +154,6 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
client=client,
_is_async=_is_async,
)
-)
if openai_client is None:
raise ValueError(
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
@@ -197,8 +196,9 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
):
-openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-self.get_azure_openai_client(
+openai_client: Optional[
+Union[AzureOpenAI, AsyncAzureOpenAI]
+] = self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
@@ -206,7 +206,6 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
client=client,
_is_async=_is_async,
)
-)
if openai_client is None:
raise ValueError(
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
@@ -251,8 +250,9 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
):
-openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-self.get_azure_openai_client(
+openai_client: Optional[
+Union[AzureOpenAI, AsyncAzureOpenAI]
+] = self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
@@ -260,7 +260,6 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
client=client,
_is_async=_is_async,
)
-)
if openai_client is None:
raise ValueError(
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."


@@ -25,14 +25,7 @@ class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI, BaseAzureLLM):
_is_async: bool = False,
api_version: Optional[str] = None,
litellm_params: Optional[dict] = None,
-) -> Optional[
-Union[
-OpenAI,
-AsyncOpenAI,
-AzureOpenAI,
-AsyncAzureOpenAI,
-]
-]:
+) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI,]]:
# Override to use Azure-specific client initialization
if isinstance(client, OpenAI) or isinstance(client, AsyncOpenAI):
client = None


@@ -145,7 +145,6 @@ class AzureAIStudioConfig(OpenAIConfig):
2. If message contains an image or audio, send as is (user-intended)
"""
for message in messages:
-
# Do nothing if the message contains an image or audio
if _audio_or_image_in_message_content(message):
continue


@@ -22,7 +22,6 @@ class AzureAICohereConfig:
pass
def _map_azure_model_group(self, model: str) -> str:
-
if model == "offer-cohere-embed-multili-paygo":
return "Cohere-embed-v3-multilingual"
elif model == "offer-cohere-embed-english-paygo":


@@ -17,7 +17,6 @@ from .cohere_transformation import AzureAICohereConfig
class AzureAIEmbedding(OpenAIChatCompletion):
-
def _process_response(
self,
image_embedding_responses: Optional[List],
@@ -145,7 +144,6 @@ class AzureAIEmbedding(OpenAIChatCompletion):
api_base: Optional[str] = None,
client=None,
) -> EmbeddingResponse:
-
(
image_embeddings_request,
v1_embeddings_request,


@@ -17,6 +17,7 @@ class AzureAIRerankConfig(CohereRerankConfig):
"""
Azure AI Rerank - Follows the same Spec as Cohere Rerank
"""
+
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
if api_base is None:
raise ValueError(


@@ -9,7 +9,6 @@ from litellm.types.utils import ModelResponse, TextCompletionResponse
class BaseLLM:
_client_session: Optional[httpx.Client] = None
-
def process_response(


@@ -218,7 +218,6 @@ class BaseConfig(ABC):
json_schema = value["json_schema"]["schema"]
if json_schema and not is_response_format_supported:
-
_tool_choice = ChatCompletionToolChoiceObjectParam(
type="function",
function=ChatCompletionToolChoiceFunctionParam(


@@ -58,7 +58,6 @@ class BaseResponsesAPIConfig(ABC):
model: str,
drop_params: bool,
) -> Dict:
-
pass
@abstractmethod


@@ -81,7 +81,6 @@ def make_sync_call(
class BedrockConverseLLM(BaseAWSLLM):
-
def __init__(self) -> None:
super().__init__()
@@ -114,7 +113,6 @@ class BedrockConverseLLM(BaseAWSLLM):
fake_stream: bool = False,
json_mode: Optional[bool] = False,
) -> CustomStreamWrapper:
-
request_data = await litellm.AmazonConverseConfig()._async_transform_request(
model=model,
messages=messages,
@@ -179,7 +177,6 @@ class BedrockConverseLLM(BaseAWSLLM):
headers: dict = {},
client: Optional[AsyncHTTPHandler] = None,
) -> Union[ModelResponse, CustomStreamWrapper]:
-
request_data = await litellm.AmazonConverseConfig()._async_transform_request(
model=model,
messages=messages,
@@ -265,7 +262,6 @@ class BedrockConverseLLM(BaseAWSLLM):
extra_headers: Optional[dict] = None,
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
):
-
## SETUP ##
stream = optional_params.pop("stream", None)
unencoded_model_id = optional_params.pop("model_id", None)
@@ -301,9 +297,9 @@ class BedrockConverseLLM(BaseAWSLLM):
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
optional_params.pop("aws_region_name", None)
-litellm_params["aws_region_name"] = (
-aws_region_name # [DO NOT DELETE] important for async calls
-)
+litellm_params[
+"aws_region_name"
+] = aws_region_name # [DO NOT DELETE] important for async calls
credentials: Credentials = self.get_credentials(
aws_access_key_id=aws_access_key_id,


@@ -223,7 +223,6 @@ class AmazonConverseConfig(BaseConfig):
)
for param, value in non_default_params.items():
if param == "response_format" and isinstance(value, dict):
-
ignore_response_format_types = ["text"]
if value["type"] in ignore_response_format_types: # value is a no-op
continue
@@ -715,9 +714,9 @@ class AmazonConverseConfig(BaseConfig):
chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"}
content_str = ""
tools: List[ChatCompletionToolCallChunk] = []
-reasoningContentBlocks: Optional[List[BedrockConverseReasoningContentBlock]] = (
-None
-)
+reasoningContentBlocks: Optional[
+List[BedrockConverseReasoningContentBlock]
+] = None
if message is not None:
for idx, content in enumerate(message["content"]):
@@ -727,7 +726,6 @@ class AmazonConverseConfig(BaseConfig):
if "text" in content:
content_str += content["text"]
if "toolUse" in content:
-
## check tool name was formatted by litellm
_response_tool_name = content["toolUse"]["name"]
response_tool_name = get_bedrock_tool_name(
@@ -754,12 +752,12 @@ class AmazonConverseConfig(BaseConfig):
chat_completion_message["provider_specific_fields"] = {
"reasoningContentBlocks": reasoningContentBlocks,
}
-chat_completion_message["reasoning_content"] = (
-self._transform_reasoning_content(reasoningContentBlocks)
-)
-chat_completion_message["thinking_blocks"] = (
-self._transform_thinking_blocks(reasoningContentBlocks)
-)
+chat_completion_message[
+"reasoning_content"
+] = self._transform_reasoning_content(reasoningContentBlocks)
+chat_completion_message[
+"thinking_blocks"
+] = self._transform_thinking_blocks(reasoningContentBlocks)
chat_completion_message["content"] = content_str
if json_mode is True and tools is not None and len(tools) == 1:
# to support 'json_schema' logic on bedrock models


@@ -496,9 +496,9 @@ class BedrockLLM(BaseAWSLLM):
content=None,
)
model_response.choices[0].message = _message # type: ignore
-model_response._hidden_params["original_response"] = (
-outputText # allow user to access raw anthropic tool calling response
-)
+model_response._hidden_params[
+"original_response"
+] = outputText # allow user to access raw anthropic tool calling response
if (
_is_function_call is True
and stream is not None
@@ -806,9 +806,9 @@ class BedrockLLM(BaseAWSLLM):
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
if stream is True:
-inference_params["stream"] = (
-True # cohere requires stream = True in inference params
-)
+inference_params[
+"stream"
+] = True # cohere requires stream = True in inference params
data = json.dumps({"prompt": prompt, **inference_params})
elif provider == "anthropic":
if model.startswith("anthropic.claude-3"):
@@ -1205,7 +1205,6 @@ class BedrockLLM(BaseAWSLLM):
def get_response_stream_shape():
global _response_stream_shape_cache
if _response_stream_shape_cache is None:
-
from botocore.loaders import Loader
from botocore.model import ServiceModel
@@ -1539,7 +1538,6 @@ class AmazonDeepSeekR1StreamDecoder(AWSEventStreamDecoder):
model: str,
sync_stream: bool,
) -> None:
-
super().__init__(model=model)
from litellm.llms.bedrock.chat.invoke_transformations.amazon_deepseek_transformation import (
AmazonDeepseekR1ResponseIterator,


@@ -225,9 +225,9 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
if stream is True:
-inference_params["stream"] = (
-True # cohere requires stream = True in inference params
-)
+inference_params[
+"stream"
+] = True # cohere requires stream = True in inference params
request_data = {"prompt": prompt, **inference_params}
elif provider == "anthropic":
return litellm.AmazonAnthropicClaude3Config().transform_request(
@@ -311,7 +311,6 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
-
try:
completion_response = raw_response.json()
except Exception:


@@ -314,7 +314,6 @@ def get_bedrock_tool_name(response_tool_name: str) -> str:
class BedrockModelInfo(BaseLLMModelInfo):
-
global_config = AmazonBedrockGlobalConfig()
all_global_regions = global_config.get_all_regions()


@@ -33,9 +33,9 @@ class AmazonTitanMultimodalEmbeddingG1Config:
) -> dict:
for k, v in non_default_params.items():
if k == "dimensions":
-optional_params["embeddingConfig"] = (
-AmazonTitanMultimodalEmbeddingConfig(outputEmbeddingLength=v)
-)
+optional_params[
+"embeddingConfig"
+] = AmazonTitanMultimodalEmbeddingConfig(outputEmbeddingLength=v)
return optional_params
def _transform_request(
@@ -58,7 +58,6 @@ class AmazonTitanMultimodalEmbeddingG1Config:
def _transform_response(
self, response_list: List[dict], model: str
) -> EmbeddingResponse:
-
total_prompt_tokens = 0
transformed_responses: List[Embedding] = []
for index, response in enumerate(response_list):


@@ -1,12 +1,16 @@
import types
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
from openai.types.image import Image
from litellm.types.llms.bedrock import (
-AmazonNovaCanvasTextToImageRequest, AmazonNovaCanvasTextToImageResponse,
-AmazonNovaCanvasTextToImageParams, AmazonNovaCanvasRequestBase, AmazonNovaCanvasColorGuidedGenerationParams,
-AmazonNovaCanvasColorGuidedRequest,
+AmazonNovaCanvasColorGuidedGenerationParams,
+AmazonNovaCanvasColorGuidedRequest,
+AmazonNovaCanvasImageGenerationConfig,
+AmazonNovaCanvasRequestBase,
+AmazonNovaCanvasTextToImageParams,
+AmazonNovaCanvasTextToImageRequest,
+AmazonNovaCanvasTextToImageResponse,
)
from litellm.types.utils import ImageResponse
@@ -37,8 +41,7 @@ class AmazonNovaCanvasConfig:
@classmethod
def get_supported_openai_params(cls, model: Optional[str] = None) -> List:
-"""
-"""
+""" """
return ["n", "size", "quality"]
@classmethod
@@ -65,18 +68,64 @@ class AmazonNovaCanvasConfig:
image_generation_config = optional_params.pop("imageGenerationConfig", {})
image_generation_config = {**image_generation_config, **optional_params}
if task_type == "TEXT_IMAGE":
-text_to_image_params = image_generation_config.pop("textToImageParams", {})
-text_to_image_params = {"text" :text, **text_to_image_params}
-text_to_image_params = AmazonNovaCanvasTextToImageParams(**text_to_image_params)
-return AmazonNovaCanvasTextToImageRequest(textToImageParams=text_to_image_params, taskType=task_type,
-imageGenerationConfig=image_generation_config)
+text_to_image_params: Dict[str, Any] = image_generation_config.pop(
+"textToImageParams", {}
+)
+text_to_image_params = {"text": text, **text_to_image_params}
+try:
+text_to_image_params_typed = AmazonNovaCanvasTextToImageParams(
+**text_to_image_params # type: ignore
+)
+except Exception as e:
+raise ValueError(
+f"Error transforming text to image params: {e}. Got params: {text_to_image_params}, Expected params: {AmazonNovaCanvasTextToImageParams.__annotations__}"
+)
+try:
+image_generation_config_typed = AmazonNovaCanvasImageGenerationConfig(
+**image_generation_config
+)
+except Exception as e:
+raise ValueError(
+f"Error transforming image generation config: {e}. Got params: {image_generation_config}, Expected params: {AmazonNovaCanvasImageGenerationConfig.__annotations__}"
+)
+return AmazonNovaCanvasTextToImageRequest(
+textToImageParams=text_to_image_params_typed,
+taskType=task_type,
+imageGenerationConfig=image_generation_config_typed,
+)
if task_type == "COLOR_GUIDED_GENERATION":
-color_guided_generation_params = image_generation_config.pop("colorGuidedGenerationParams", {})
-color_guided_generation_params = {"text": text, **color_guided_generation_params}
-color_guided_generation_params = AmazonNovaCanvasColorGuidedGenerationParams(**color_guided_generation_params)
-return AmazonNovaCanvasColorGuidedRequest(taskType=task_type,
-colorGuidedGenerationParams=color_guided_generation_params,
-imageGenerationConfig=image_generation_config)
+color_guided_generation_params: Dict[
+str, Any
+] = image_generation_config.pop("colorGuidedGenerationParams", {})
+color_guided_generation_params = {
+"text": text,
+**color_guided_generation_params,
+}
+try:
+color_guided_generation_params_typed = AmazonNovaCanvasColorGuidedGenerationParams(
+**color_guided_generation_params # type: ignore
+)
+except Exception as e:
+raise ValueError(
+f"Error transforming color guided generation params: {e}. Got params: {color_guided_generation_params}, Expected params: {AmazonNovaCanvasColorGuidedGenerationParams.__annotations__}"
+)
+try:
+image_generation_config_typed = AmazonNovaCanvasImageGenerationConfig(
+**image_generation_config
+)
+except Exception as e:
+raise ValueError(
+f"Error transforming image generation config: {e}. Got params: {image_generation_config}, Expected params: {AmazonNovaCanvasImageGenerationConfig.__annotations__}"
+)
+return AmazonNovaCanvasColorGuidedRequest(
+taskType=task_type,
+colorGuidedGenerationParams=color_guided_generation_params_typed,
+imageGenerationConfig=image_generation_config_typed,
+)
raise NotImplementedError(f"Task type {task_type} is not supported")
@classmethod
@@ -87,7 +136,9 @@ class AmazonNovaCanvasConfig:
_size = non_default_params.get("size")
if _size is not None:
width, height = _size.split("x")
-optional_params["width"], optional_params["height"] = int(width), int(height)
+optional_params["width"], optional_params["height"] = int(width), int(
+height
+)
if non_default_params.get("n") is not None:
optional_params["numberOfImages"] = non_default_params.get("n")
if non_default_params.get("quality") is not None:


@@ -267,7 +267,11 @@ class BedrockImageGeneration(BaseAWSLLM):
**inference_params,
}
elif provider == "amazon":
-return dict(litellm.AmazonNovaCanvasConfig.transform_request_body(text=prompt, optional_params=optional_params))
+return dict(
+litellm.AmazonNovaCanvasConfig.transform_request_body(
+text=prompt, optional_params=optional_params
+)
+)
else:
raise BedrockError(
status_code=422, message=f"Unsupported model={model}, passed in"
@@ -303,9 +307,12 @@ class BedrockImageGeneration(BaseAWSLLM):
config_class = (
litellm.AmazonStability3Config
if litellm.AmazonStability3Config._is_stability_3_model(model=model)
-else litellm.AmazonNovaCanvasConfig if litellm.AmazonNovaCanvasConfig._is_nova_model(model=model)
+else (
+litellm.AmazonNovaCanvasConfig
+if litellm.AmazonNovaCanvasConfig._is_nova_model(model=model)
else litellm.AmazonStabilityConfig
)
+)
config_class.transform_response_dict_to_openai_response(
model_response=model_response,
response_dict=response_dict,


@@ -60,7 +60,6 @@ class BedrockRerankHandler(BaseAWSLLM):
extra_headers: Optional[dict] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
) -> RerankResponse:
-
request_data = RerankRequest(
model=model,
query=query,


@@ -29,7 +29,6 @@ from litellm.types.rerank import (
class BedrockRerankConfig:
-
def _transform_sources(
self, documents: List[Union[str, dict]]
) -> List[BedrockRerankSource]:


@@ -314,7 +314,6 @@ class CodestralTextCompletion:
return _response
### SYNC COMPLETION
else:
-
response = litellm.module_level_client.post(
url=completion_url,
headers=headers,
@@ -352,13 +351,11 @@ class CodestralTextCompletion:
logger_fn=None,
headers={},
) -> TextCompletionResponse:
-
async_handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.TEXT_COMPLETION_CODESTRAL,
params={"timeout": timeout},
)
-
try:
response = await async_handler.post(
api_base, headers=headers, data=json.dumps(data)
)


@@ -78,7 +78,6 @@ class CodestralTextCompletionConfig(OpenAITextCompletionConfig):
return optional_params
def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:
-
text = ""
is_finished = False
finish_reason = None


@@ -180,7 +180,6 @@ class CohereChatConfig(BaseConfig):
litellm_params: dict,
headers: dict,
) -> dict:
-
## Load Config
for k, v in litellm.CohereChatConfig.get_config().items():
if (
@@ -222,7 +221,6 @@ class CohereChatConfig(BaseConfig):
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
-
try:
raw_response_json = raw_response.json()
model_response.choices[0].message.content = raw_response_json["text"] # type: ignore


@@ -56,7 +56,6 @@ async def async_embedding(
encoding: Callable,
client: Optional[AsyncHTTPHandler] = None,
):
-
## LOGGING
logging_obj.pre_call(
input=input,


@@ -72,7 +72,6 @@ class CohereEmbeddingConfig:
return transformed_request
def _calculate_usage(self, input: List[str], encoding: Any, meta: dict) -> Usage:
-
input_tokens = 0
text_tokens: Optional[int] = meta.get("billed_units", {}).get("input_tokens")
@@ -111,7 +110,6 @@ class CohereEmbeddingConfig:
encoding: Any,
input: list,
) -> EmbeddingResponse:
-
response_json = response.json()
## LOGGING
logging_obj.post_call(


@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Union
from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
from litellm.types.rerank import OptionalRerankParams, RerankRequest
+
class CohereRerankV2Config(CohereRerankConfig):
"""
Reference: https://docs.cohere.com/v2/reference/rerank


@@ -32,7 +32,6 @@ DEFAULT_TIMEOUT = 600
class BaseLLMAIOHTTPHandler:
-
def __init__(self):
self.client_session: Optional[aiohttp.ClientSession] = None
@@ -110,7 +109,6 @@ class BaseLLMAIOHTTPHandler:
content: Any = None,
params: Optional[dict] = None,
) -> httpx.Response:
-
max_retry_on_unprocessable_entity_error = (
provider_config.max_retry_on_unprocessable_entity_error
)


@@ -114,7 +114,6 @@ class AsyncHTTPHandler:
event_hooks: Optional[Mapping[str, List[Callable[..., Any]]]],
ssl_verify: Optional[VerifyTypes] = None,
) -> httpx.AsyncClient:
-
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
# /path/to/certificate.pem
if ssl_verify is None:
@@ -590,7 +589,6 @@ class HTTPHandler:
timeout: Optional[Union[float, httpx.Timeout]] = None,
):
try:
-
if timeout is not None:
req = self.client.build_request(
"PATCH", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore
@@ -609,7 +607,6 @@ class HTTPHandler:
llm_provider="litellm-httpx-handler",
)
except httpx.HTTPStatusError as e:
-
if stream is True:
setattr(e, "message", mask_sensitive_info(e.response.read()))
setattr(e, "text", mask_sensitive_info(e.response.read()))
@@ -635,7 +632,6 @@ class HTTPHandler:
timeout: Optional[Union[float, httpx.Timeout]] = None,
):
try:
-
if timeout is not None:
req = self.client.build_request(
"PUT", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore


@@ -41,7 +41,6 @@ else:
class BaseLLMHTTPHandler:
-
async def _make_common_async_call(
self,
async_httpx_client: AsyncHTTPHandler,
@@ -109,7 +108,6 @@ class BaseLLMHTTPHandler:
logging_obj: LiteLLMLoggingObj,
stream: bool = False,
) -> httpx.Response:
-
max_retry_on_unprocessable_entity_error = (
provider_config.max_retry_on_unprocessable_entity_error
)
@@ -599,7 +597,6 @@ class BaseLLMHTTPHandler:
aembedding: bool = False,
headers={},
) -> EmbeddingResponse:
-
provider_config = ProviderConfigManager.get_provider_embedding_config(
model=model, provider=litellm.LlmProviders(custom_llm_provider)
)
@@ -742,7 +739,6 @@ class BaseLLMHTTPHandler:
api_base: Optional[str] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
) -> RerankResponse:
-
# get config from model, custom llm provider
headers = provider_config.validate_environment(
api_key=api_key,
@@ -828,7 +824,6 @@ class BaseLLMHTTPHandler:
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
) -> RerankResponse:
-
if client is None or not isinstance(client, AsyncHTTPHandler):
async_httpx_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders(custom_llm_provider)


@@ -16,9 +16,9 @@ class DatabricksBase:
api_base = api_base or f"{databricks_client.config.host}/serving-endpoints"
if api_key is None:
-databricks_auth_headers: dict[str, str] = (
-databricks_client.config.authenticate()
-)
+databricks_auth_headers: dict[
+str, str
+] = databricks_client.config.authenticate()
headers = {**databricks_auth_headers, **headers}
return api_base, headers
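A small, self-contained illustration of the header merge in this hunk; the values are placeholders, and with {**auth, **headers} the caller-supplied headers win on key conflicts:

# Placeholder values; this only demonstrates the merge order used above.
databricks_auth_headers = {"Authorization": "Bearer <token-from-sdk>"}
headers = {"Content-Type": "application/json"}

headers = {**databricks_auth_headers, **headers}
print(headers)
# {'Authorization': 'Bearer <token-from-sdk>', 'Content-Type': 'application/json'}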


@@ -11,9 +11,9 @@ class DatabricksEmbeddingConfig:
Reference: https://learn.microsoft.com/en-us/azure/databricks/machine-learning/foundation-models/api-reference#--embedding-task
"""
-instruction: Optional[str] = (
-None # An optional instruction to pass to the embedding model. BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries
-)
+instruction: Optional[
+str
+] = None # An optional instruction to pass to the embedding model. BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries
def __init__(self, instruction: Optional[str] = None) -> None:
locals_ = locals().copy()


@@ -55,7 +55,6 @@ class ModelResponseIterator:
usage_chunk: Optional[Usage] = getattr(processed_chunk, "usage", None)
if usage_chunk is not None:
-
usage = ChatCompletionUsageBlock(
prompt_tokens=usage_chunk.prompt_tokens,
completion_tokens=usage_chunk.completion_tokens,


@@ -126,9 +126,9 @@ class DeepgramAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
# Add additional metadata matching OpenAI format
response["task"] = "transcribe"
-response["language"] = (
-"english" # Deepgram auto-detects but doesn't return language
-)
+response[
+"language"
+] = "english" # Deepgram auto-detects but doesn't return language
response["duration"] = response_json["metadata"]["duration"]
# Transform words to match OpenAI format
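For orientation, a minimal sketch of the response shaping in this hunk; the Deepgram-style payload below is invented:

# Invented payload; only the metadata.duration field is relevant here.
response_json = {"metadata": {"duration": 4.2}}

response: dict = {}
response["task"] = "transcribe"
response["language"] = "english"  # Deepgram auto-detects but doesn't return language
response["duration"] = response_json["metadata"]["duration"]

print(response)  # {'task': 'transcribe', 'language': 'english', 'duration': 4.2}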


@@ -14,7 +14,6 @@ from ...openai.chat.gpt_transformation import OpenAIGPTConfig
class DeepSeekChatConfig(OpenAIGPTConfig):
-
def _transform_messages(
self, messages: List[AllMessageValues], model: str
) -> List[AllMessageValues]:


@@ -77,9 +77,9 @@ class AlephAlphaConfig:
- `control_log_additive` (boolean; default value: true): Method of applying control to attention scores.
"""
-maximum_tokens: Optional[int] = (
-litellm.max_tokens
-) # aleph alpha requires max tokens
+maximum_tokens: Optional[
+int
+] = litellm.max_tokens # aleph alpha requires max tokens
minimum_tokens: Optional[int] = None
echo: Optional[bool] = None
temperature: Optional[int] = None

Some files were not shown because too many files have changed in this diff.