build(pyproject.toml): add new dev dependencies - for type checking (#9631)

* build(pyproject.toml): add new dev dependencies - for type checking

* build: reformat files to fit black

* ci: reformat to fit black

* ci(test-litellm.yml): make tests run clear

* build(pyproject.toml): add ruff

* fix: fix ruff checks

* build(mypy/): fix mypy linting errors

* fix(hashicorp_secret_manager.py): fix passing cert for tls auth

* build(mypy/): resolve all mypy errors

* test: update test

* fix: fix black formatting

* build(pre-commit-config.yaml): use poetry run black

* fix(proxy_server.py): fix linting error

* fix: fix ruff safe representation error
Krish Dholakia 2025-03-29 11:02:13 -07:00 committed by GitHub
parent 72198737f8
commit d7b294dd0a
214 changed files with 1553 additions and 1433 deletions

.github/workflows/test-linting.yml (new file, 53 lines)

@ -0,0 +1,53 @@
name: LiteLLM Linting
on:
pull_request:
branches: [ main ]
jobs:
lint:
runs-on: ubuntu-latest
timeout-minutes: 5
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.12'
- name: Install Poetry
uses: snok/install-poetry@v1
- name: Install dependencies
run: |
poetry install --with dev
- name: Run Black formatting check
run: |
cd litellm
poetry run black . --check
cd ..
- name: Run Ruff linting
run: |
cd litellm
poetry run ruff check .
cd ..
- name: Run MyPy type checking
run: |
cd litellm
poetry run mypy . --ignore-missing-imports
cd ..
- name: Check for circular imports
run: |
cd litellm
poetry run python ../tests/documentation_tests/test_circular_imports.py
cd ..
- name: Check import safety
run: |
poetry run python -c "from litellm import *" || (echo '🚨 import failed, this means you introduced unprotected imports! 🚨'; exit 1)
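
The "Check import safety" step asserts that "from litellm import *" succeeds even when optional extras are missing. A minimal sketch of the protected-import pattern that keeps such a wildcard import from failing; the module and helper names below are illustrative, not taken from the litellm codebase:

# Illustrative only: guard an optional dependency so importing the package never hard-fails.
try:
    from opentelemetry.trace import Span  # optional dependency
except ImportError:
    Span = None  # type: ignore  # degrade gracefully when the extra is not installed

def tracing_available() -> bool:
    # Callers branch on availability instead of crashing at import time.
    return Span is not None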

View file

@ -1,4 +1,4 @@
name: LiteLLM Tests
name: LiteLLM Mock Tests (folder - tests/litellm)
on:
pull_request:

View file

@ -14,10 +14,12 @@ repos:
types: [python]
files: litellm/.*\.py
exclude: ^litellm/__init__.py$
- repo: https://github.com/psf/black
rev: 24.2.0
hooks:
- id: black
- id: black
name: black
entry: poetry run black
language: system
types: [python]
files: litellm/.*\.py
- repo: https://github.com/pycqa/flake8
rev: 7.0.0 # The version of flake8 to use
hooks:

View file

@ -444,9 +444,7 @@ class _ENTERPRISE_SecretDetection(CustomGuardrail):
detected_secrets = []
for file in secrets.files:
for found_secret in secrets[file]:
if found_secret.secret_value is None:
continue
detected_secrets.append(
@ -471,14 +469,12 @@ class _ENTERPRISE_SecretDetection(CustomGuardrail):
data: dict,
call_type: str, # "completion", "embeddings", "image_generation", "moderation"
):
if await self.should_run_check(user_api_key_dict) is False:
return
if "messages" in data and isinstance(data["messages"], list):
for message in data["messages"]:
if "content" in message and isinstance(message["content"], str):
detected_secrets = self.scan_message_for_secrets(message["content"])
for secret in detected_secrets:

View file

@ -122,19 +122,19 @@ langsmith_batch_size: Optional[int] = None
prometheus_initialize_budget_metrics: Optional[bool] = False
argilla_batch_size: Optional[int] = None
datadog_use_v1: Optional[bool] = False # if you want to use v1 datadog logged payload
gcs_pub_sub_use_v1: Optional[bool] = (
False # if you want to use v1 gcs pubsub logged payload
)
gcs_pub_sub_use_v1: Optional[
bool
] = False # if you want to use v1 gcs pubsub logged payload
argilla_transformation_object: Optional[Dict[str, Any]] = None
_async_input_callback: List[Union[str, Callable, CustomLogger]] = (
[]
) # internal variable - async custom callbacks are routed here.
_async_success_callback: List[Union[str, Callable, CustomLogger]] = (
[]
) # internal variable - async custom callbacks are routed here.
_async_failure_callback: List[Union[str, Callable, CustomLogger]] = (
[]
) # internal variable - async custom callbacks are routed here.
_async_input_callback: List[
Union[str, Callable, CustomLogger]
] = [] # internal variable - async custom callbacks are routed here.
_async_success_callback: List[
Union[str, Callable, CustomLogger]
] = [] # internal variable - async custom callbacks are routed here.
_async_failure_callback: List[
Union[str, Callable, CustomLogger]
] = [] # internal variable - async custom callbacks are routed here.
pre_call_rules: List[Callable] = []
post_call_rules: List[Callable] = []
turn_off_message_logging: Optional[bool] = False
@ -142,18 +142,18 @@ log_raw_request_response: bool = False
redact_messages_in_exceptions: Optional[bool] = False
redact_user_api_key_info: Optional[bool] = False
filter_invalid_headers: Optional[bool] = False
add_user_information_to_llm_headers: Optional[bool] = (
None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
)
add_user_information_to_llm_headers: Optional[
bool
] = None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
store_audit_logs = False # Enterprise feature, allow users to see audit logs
### end of callbacks #############
email: Optional[str] = (
None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
)
token: Optional[str] = (
None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
)
email: Optional[
str
] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
token: Optional[
str
] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
telemetry = True
max_tokens = 256 # OpenAI Defaults
drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False))
@ -229,24 +229,20 @@ enable_loadbalancing_on_batch_endpoints: Optional[bool] = None
enable_caching_on_provider_specific_optional_params: bool = (
False # feature-flag for caching on optional params - e.g. 'top_k'
)
caching: bool = (
False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
)
caching_with_models: bool = (
False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
)
cache: Optional[Cache] = (
None # cache object <- use this - https://docs.litellm.ai/docs/caching
)
caching: bool = False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
caching_with_models: bool = False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
cache: Optional[
Cache
] = None # cache object <- use this - https://docs.litellm.ai/docs/caching
default_in_memory_ttl: Optional[float] = None
default_redis_ttl: Optional[float] = None
default_redis_batch_cache_expiry: Optional[float] = None
model_alias_map: Dict[str, str] = {}
model_group_alias_map: Dict[str, str] = {}
max_budget: float = 0.0 # set the max budget across all providers
budget_duration: Optional[str] = (
None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
)
budget_duration: Optional[
str
] = None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
default_soft_budget: float = (
50.0 # by default all litellm proxy keys have a soft budget of 50.0
)
@ -255,15 +251,11 @@ forward_traceparent_to_llm_provider: bool = False
_current_cost = 0.0 # private variable, used if max budget is set
error_logs: Dict = {}
add_function_to_prompt: bool = (
False # if function calling not supported by api, append function call details to system prompt
)
add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt
client_session: Optional[httpx.Client] = None
aclient_session: Optional[httpx.AsyncClient] = None
model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks'
model_cost_map_url: str = (
"https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
)
model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
suppress_debug_info = False
dynamodb_table_name: Optional[str] = None
s3_callback_params: Optional[Dict] = None
@ -285,9 +277,7 @@ disable_end_user_cost_tracking_prometheus_only: Optional[bool] = None
custom_prometheus_metadata_labels: List[str] = []
#### REQUEST PRIORITIZATION ####
priority_reservation: Optional[Dict[str, float]] = None
force_ipv4: bool = (
False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
)
force_ipv4: bool = False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
module_level_aclient = AsyncHTTPHandler(
timeout=request_timeout, client_alias="module level aclient"
)
@ -301,13 +291,13 @@ fallbacks: Optional[List] = None
context_window_fallbacks: Optional[List] = None
content_policy_fallbacks: Optional[List] = None
allowed_fails: int = 3
num_retries_per_request: Optional[int] = (
None # for the request overall (incl. fallbacks + model retries)
)
num_retries_per_request: Optional[
int
] = None # for the request overall (incl. fallbacks + model retries)
####### SECRET MANAGERS #####################
secret_manager_client: Optional[Any] = (
None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
)
secret_manager_client: Optional[
Any
] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
_google_kms_resource_name: Optional[str] = None
_key_management_system: Optional[KeyManagementSystem] = None
_key_management_settings: KeyManagementSettings = KeyManagementSettings()
@ -1056,10 +1046,10 @@ from .types.llms.custom_llm import CustomLLMItem
from .types.utils import GenericStreamingChunk
custom_provider_map: List[CustomLLMItem] = []
_custom_providers: List[str] = (
[]
) # internal helper util, used to track names of custom providers
disable_hf_tokenizer_download: Optional[bool] = (
None # disable huggingface tokenizer download. Defaults to openai clk100
)
_custom_providers: List[
str
] = [] # internal helper util, used to track names of custom providers
disable_hf_tokenizer_download: Optional[
bool
] = None # disable huggingface tokenizer download. Defaults to openai clk100
global_disable_no_log_param: bool = False

View file

@ -15,7 +15,7 @@ from .types.services import ServiceLoggerPayload, ServiceTypes
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
Span = Union[_Span, Any]
OTELClass = OpenTelemetry
else:
Span = Any
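
Changing the TYPE_CHECKING alias from Span = _Span to Span = Union[_Span, Any] (a change repeated across many files in this commit) makes annotations written against Span also accept untyped values, which is presumably what lets the stricter mypy run pass while the runtime fallback stays Span = Any. A minimal, self-contained sketch of the pattern, assuming opentelemetry may not be installed at runtime:

from typing import TYPE_CHECKING, Any, Union

if TYPE_CHECKING:
    # Seen only by the type checker; opentelemetry does not need to be importable at runtime.
    from opentelemetry.trace import Span as _Span

    Span = Union[_Span, Any]
else:
    Span = Any  # runtime fallback when tracing is unavailable

def record(span: Span) -> None:
    # Accepts a real OpenTelemetry span or any placeholder object without a mypy error.
    pass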

View file

@ -153,7 +153,6 @@ def create_batch(
)
api_base: Optional[str] = None
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
optional_params.api_base
@ -358,7 +357,6 @@ def retrieve_batch(
_is_async = kwargs.pop("aretrieve_batch", False) is True
api_base: Optional[str] = None
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
optional_params.api_base

View file

@ -9,12 +9,12 @@ Has 4 methods:
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, Union
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
Span = Union[_Span, Any]
else:
Span = Any

View file

@ -66,9 +66,7 @@ class CachingHandlerResponse(BaseModel):
cached_result: Optional[Any] = None
final_embedding_cached_response: Optional[EmbeddingResponse] = None
embedding_all_elements_cache_hit: bool = (
False # this is set to True when all elements in the list have a cache hit in the embedding cache, if true return the final_embedding_cached_response no need to make an API call
)
embedding_all_elements_cache_hit: bool = False # this is set to True when all elements in the list have a cache hit in the embedding cache, if true return the final_embedding_cached_response no need to make an API call
class LLMCachingHandler:
@ -738,7 +736,6 @@ class LLMCachingHandler:
if self._should_store_result_in_cache(
original_function=self.original_function, kwargs=new_kwargs
):
litellm.cache.add_cache(result, **new_kwargs)
return
@ -865,9 +862,9 @@ class LLMCachingHandler:
}
if litellm.cache is not None:
litellm_params["preset_cache_key"] = (
litellm.cache._get_preset_cache_key_from_kwargs(**kwargs)
)
litellm_params[
"preset_cache_key"
] = litellm.cache._get_preset_cache_key_from_kwargs(**kwargs)
else:
litellm_params["preset_cache_key"] = None

View file

@ -1,12 +1,12 @@
import json
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, Union
from .base_cache import BaseCache
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
Span = Union[_Span, Any]
else:
Span = Any

View file

@ -12,7 +12,7 @@ import asyncio
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, List, Optional
from typing import TYPE_CHECKING, Any, List, Optional, Union
import litellm
from litellm._logging import print_verbose, verbose_logger
@ -24,7 +24,7 @@ from .redis_cache import RedisCache
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
Span = Union[_Span, Any]
else:
Span = Any

View file

@ -8,7 +8,6 @@ from .in_memory_cache import InMemoryCache
class LLMClientCache(InMemoryCache):
def update_cache_key_with_event_loop(self, key):
"""
Add the event loop to the cache key, to prevent event loop closed errors.

View file

@ -34,7 +34,7 @@ if TYPE_CHECKING:
cluster_pipeline = ClusterPipeline
async_redis_client = Redis
async_redis_cluster_client = RedisCluster
Span = _Span
Span = Union[_Span, Any]
else:
pipeline = Any
cluster_pipeline = Any
@ -57,7 +57,6 @@ class RedisCache(BaseCache):
socket_timeout: Optional[float] = 5.0, # default 5 second timeout
**kwargs,
):
from litellm._service_logger import ServiceLogging
from .._redis import get_redis_client, get_redis_connection_pool

View file

@ -5,7 +5,7 @@ Key differences:
- RedisClient NEEDs to be re-used across requests, adds 3000ms latency if it's re-created
"""
from typing import TYPE_CHECKING, Any, List, Optional
from typing import TYPE_CHECKING, Any, List, Optional, Union
from litellm.caching.redis_cache import RedisCache
@ -16,7 +16,7 @@ if TYPE_CHECKING:
pipeline = Pipeline
async_redis_client = Redis
Span = _Span
Span = Union[_Span, Any]
else:
pipeline = Any
async_redis_client = Any

View file

@ -13,11 +13,15 @@ import ast
import asyncio
import json
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, cast
import litellm
from litellm._logging import print_verbose
from litellm.litellm_core_utils.prompt_templates.common_utils import get_str_from_messages
from litellm.litellm_core_utils.prompt_templates.common_utils import (
get_str_from_messages,
)
from litellm.types.utils import EmbeddingResponse
from .base_cache import BaseCache
@ -87,14 +91,16 @@ class RedisSemanticCache(BaseCache):
if redis_url is None:
try:
# Attempt to use provided parameters or fallback to environment variables
host = host or os.environ['REDIS_HOST']
port = port or os.environ['REDIS_PORT']
password = password or os.environ['REDIS_PASSWORD']
host = host or os.environ["REDIS_HOST"]
port = port or os.environ["REDIS_PORT"]
password = password or os.environ["REDIS_PASSWORD"]
except KeyError as e:
# Raise a more informative exception if any of the required keys are missing
missing_var = e.args[0]
raise ValueError(f"Missing required Redis configuration: {missing_var}. "
f"Provide {missing_var} or redis_url.") from e
raise ValueError(
f"Missing required Redis configuration: {missing_var}. "
f"Provide {missing_var} or redis_url."
) from e
redis_url = f"redis://:{password}@{host}:{port}"
@ -137,10 +143,13 @@ class RedisSemanticCache(BaseCache):
List[float]: The embedding vector
"""
# Create an embedding from prompt
embedding_response = litellm.embedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
embedding_response = cast(
EmbeddingResponse,
litellm.embedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
),
)
embedding = embedding_response["data"][0]["embedding"]
return embedding
@ -186,6 +195,7 @@ class RedisSemanticCache(BaseCache):
"""
print_verbose(f"Redis semantic-cache set_cache, kwargs: {kwargs}")
value_str: Optional[str] = None
try:
# Extract the prompt from messages
messages = kwargs.get("messages", [])
@ -203,7 +213,9 @@ class RedisSemanticCache(BaseCache):
else:
self.llmcache.store(prompt, value_str)
except Exception as e:
print_verbose(f"Error setting {value_str} in the Redis semantic cache: {str(e)}")
print_verbose(
f"Error setting {value_str or value} in the Redis semantic cache: {str(e)}"
)
def get_cache(self, key: str, **kwargs) -> Any:
"""
@ -336,13 +348,13 @@ class RedisSemanticCache(BaseCache):
prompt,
value_str,
vector=prompt_embedding, # Pass through custom embedding
ttl=ttl
ttl=ttl,
)
else:
await self.llmcache.astore(
prompt,
value_str,
vector=prompt_embedding # Pass through custom embedding
vector=prompt_embedding, # Pass through custom embedding
)
except Exception as e:
print_verbose(f"Error in async_set_cache: {str(e)}")
@ -374,14 +386,13 @@ class RedisSemanticCache(BaseCache):
prompt_embedding = await self._get_async_embedding(prompt, **kwargs)
# Check the cache for semantically similar prompts
results = await self.llmcache.acheck(
prompt=prompt,
vector=prompt_embedding
)
results = await self.llmcache.acheck(prompt=prompt, vector=prompt_embedding)
# handle results / cache hit
if not results:
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 # TODO why here but not above??
kwargs.setdefault("metadata", {})[
"semantic-similarity"
] = 0.0 # TODO why here but not above??
return None
cache_hit = results[0]
@ -420,7 +431,9 @@ class RedisSemanticCache(BaseCache):
aindex = await self.llmcache._get_async_index()
return await aindex.info()
async def async_set_cache_pipeline(self, cache_list: List[Tuple[str, Any]], **kwargs) -> None:
async def async_set_cache_pipeline(
self, cache_list: List[Tuple[str, Any]], **kwargs
) -> None:
"""
Asynchronously store multiple values in the semantic cache.
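
The cast(EmbeddingResponse, litellm.embedding(...)) wrapper added above is the standard way to narrow a loosely typed return value for mypy; cast() has no runtime effect. A self-contained sketch of the idea, using stand-in names rather than litellm's real signatures:

from typing import Any, cast

class EmbeddingResponse(dict):
    """Stand-in for the real response type; illustration only."""

def embedding(**kwargs: Any) -> Any:
    # Stand-in for a call whose declared return type is broader than what the caller knows.
    return EmbeddingResponse(data=[{"embedding": [0.1, 0.2, 0.3]}])

response = cast(EmbeddingResponse, embedding(model="example", input="prompt"))
vector = response["data"][0]["embedding"]  # mypy now sees response as EmbeddingResponse, not Any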

View file

@ -123,7 +123,7 @@ class S3Cache(BaseCache):
) # Convert string to dictionary
except Exception:
cached_response = ast.literal_eval(cached_response)
if type(cached_response) is not dict:
if not isinstance(cached_response, dict):
cached_response = dict(cached_response)
verbose_logger.debug(
f"Got S3 Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
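
Replacing type(cached_response) is not dict with not isinstance(cached_response, dict) is the idiomatic check and, unlike an exact type comparison, also accepts dict subclasses. A tiny illustration of the difference:

from collections import OrderedDict

cached_response = OrderedDict(role="assistant")

print(type(cached_response) is dict)      # False - exact-type check rejects subclasses
print(isinstance(cached_response, dict))  # True  - a dict subclass still counts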

View file

@ -580,7 +580,6 @@ def completion_cost( # noqa: PLR0915
- For un-mapped Replicate models, the cost is calculated based on the total time used for the request.
"""
try:
call_type = _infer_call_type(call_type, completion_response) or "completion"
if (

View file

@ -138,7 +138,6 @@ def create_fine_tuning_job(
# OpenAI
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
optional_params.api_base
@ -360,7 +359,6 @@ def cancel_fine_tuning_job(
# OpenAI
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
optional_params.api_base
@ -522,7 +520,6 @@ def list_fine_tuning_jobs(
# OpenAI
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
optional_params.api_base

View file

@ -19,7 +19,6 @@ else:
def squash_payloads(queue):
squashed = {}
if len(queue) == 0:
return squashed

View file

@ -195,12 +195,15 @@ class SlackAlerting(CustomBatchLogger):
if self.alerting is None or self.alert_types is None:
return
time_difference_float, model, api_base, messages = (
self._response_taking_too_long_callback_helper(
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
(
time_difference_float,
model,
api_base,
messages,
) = self._response_taking_too_long_callback_helper(
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
if litellm.turn_off_message_logging or litellm.redact_messages_in_exceptions:
messages = "Message not logged. litellm.redact_messages_in_exceptions=True"
@ -819,9 +822,9 @@ class SlackAlerting(CustomBatchLogger):
### UNIQUE CACHE KEY ###
cache_key = provider + region_name
outage_value: Optional[ProviderRegionOutageModel] = (
await self.internal_usage_cache.async_get_cache(key=cache_key)
)
outage_value: Optional[
ProviderRegionOutageModel
] = await self.internal_usage_cache.async_get_cache(key=cache_key)
if (
getattr(exception, "status_code", None) is None
@ -1402,9 +1405,9 @@ Model Info:
self.alert_to_webhook_url is not None
and alert_type in self.alert_to_webhook_url
):
slack_webhook_url: Optional[Union[str, List[str]]] = (
self.alert_to_webhook_url[alert_type]
)
slack_webhook_url: Optional[
Union[str, List[str]]
] = self.alert_to_webhook_url[alert_type]
elif self.default_webhook_url is not None:
slack_webhook_url = self.default_webhook_url
else:
@ -1768,7 +1771,6 @@ Model Info:
- Team Created, Updated, Deleted
"""
try:
message = f"`{event_name}`\n"
key_event_dict = key_event.model_dump()

View file

@ -98,7 +98,6 @@ class ArgillaLogger(CustomBatchLogger):
argilla_dataset_name: Optional[str],
argilla_base_url: Optional[str],
) -> ArgillaCredentialsObject:
_credentials_api_key = argilla_api_key or os.getenv("ARGILLA_API_KEY")
if _credentials_api_key is None:
raise Exception("Invalid Argilla API Key given. _credentials_api_key=None.")

View file

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, Union
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
@ -7,7 +7,7 @@ from litellm.types.utils import StandardLoggingPayload
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
Span = Union[_Span, Any]
else:
Span = Any

View file

@ -19,14 +19,13 @@ if TYPE_CHECKING:
from litellm.types.integrations.arize import Protocol as _Protocol
Protocol = _Protocol
Span = _Span
Span = Union[_Span, Any]
else:
Protocol = Any
Span = Any
class ArizeLogger(OpenTelemetry):
def set_attributes(self, span: Span, kwargs, response_obj: Optional[Any]):
ArizeLogger.set_arize_attributes(span, kwargs, response_obj)
return

View file

@ -1,17 +1,20 @@
import os
from typing import TYPE_CHECKING, Any
from litellm.integrations.arize import _utils
from typing import TYPE_CHECKING, Any, Union
from litellm._logging import verbose_logger
from litellm.integrations.arize import _utils
from litellm.types.integrations.arize_phoenix import ArizePhoenixConfig
if TYPE_CHECKING:
from .opentelemetry import OpenTelemetryConfig as _OpenTelemetryConfig
from litellm.types.integrations.arize import Protocol as _Protocol
from opentelemetry.trace import Span as _Span
from litellm.types.integrations.arize import Protocol as _Protocol
from .opentelemetry import OpenTelemetryConfig as _OpenTelemetryConfig
Protocol = _Protocol
OpenTelemetryConfig = _OpenTelemetryConfig
Span = _Span
Span = Union[_Span, Any]
else:
Protocol = Any
OpenTelemetryConfig = Any
@ -20,6 +23,7 @@ else:
ARIZE_HOSTED_PHOENIX_ENDPOINT = "https://app.phoenix.arize.com/v1/traces"
class ArizePhoenixLogger:
@staticmethod
def set_arize_phoenix_attributes(span: Span, kwargs, response_obj):
@ -59,15 +63,14 @@ class ArizePhoenixLogger:
# a slightly different auth header format than self hosted phoenix
if endpoint == ARIZE_HOSTED_PHOENIX_ENDPOINT:
if api_key is None:
raise ValueError("PHOENIX_API_KEY must be set when the Arize hosted Phoenix endpoint is used.")
raise ValueError(
"PHOENIX_API_KEY must be set when the Arize hosted Phoenix endpoint is used."
)
otlp_auth_headers = f"api_key={api_key}"
elif api_key is not None:
# api_key/auth is optional for self hosted phoenix
otlp_auth_headers = f"Authorization=Bearer {api_key}"
return ArizePhoenixConfig(
otlp_auth_headers=otlp_auth_headers,
protocol=protocol,
endpoint=endpoint
otlp_auth_headers=otlp_auth_headers, protocol=protocol, endpoint=endpoint
)

View file

@ -12,7 +12,10 @@ class AthinaLogger:
"athina-api-key": self.athina_api_key,
"Content-Type": "application/json",
}
self.athina_logging_url = os.getenv("ATHINA_BASE_URL", "https://log.athina.ai") + "/api/v1/log/inference"
self.athina_logging_url = (
os.getenv("ATHINA_BASE_URL", "https://log.athina.ai")
+ "/api/v1/log/inference"
)
self.additional_keys = [
"environment",
"prompt_slug",

View file

@ -50,12 +50,12 @@ class AzureBlobStorageLogger(CustomBatchLogger):
self.azure_storage_file_system: str = _azure_storage_file_system
# Internal variables used for Token based authentication
self.azure_auth_token: Optional[str] = (
None # the Azure AD token to use for Azure Storage API requests
)
self.token_expiry: Optional[datetime] = (
None # the expiry time of the currentAzure AD token
)
self.azure_auth_token: Optional[
str
] = None # the Azure AD token to use for Azure Storage API requests
self.token_expiry: Optional[
datetime
] = None # the expiry time of the currentAzure AD token
asyncio.create_task(self.periodic_flush())
self.flush_lock = asyncio.Lock()
@ -153,7 +153,6 @@ class AzureBlobStorageLogger(CustomBatchLogger):
3. Flush the data
"""
try:
if self.azure_storage_account_key:
await self.upload_to_azure_data_lake_with_azure_account_key(
payload=payload

View file

@ -4,7 +4,7 @@
import copy
import os
from datetime import datetime
from typing import Optional, Dict
from typing import Dict, Optional
import httpx
from pydantic import BaseModel
@ -19,7 +19,9 @@ from litellm.llms.custom_httpx.http_handler import (
)
from litellm.utils import print_verbose
global_braintrust_http_handler = get_async_httpx_client(llm_provider=httpxSpecialProvider.LoggingCallback)
global_braintrust_http_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback
)
global_braintrust_sync_http_handler = HTTPHandler()
API_BASE = "https://api.braintrustdata.com/v1"
@ -35,7 +37,9 @@ def get_utc_datetime():
class BraintrustLogger(CustomLogger):
def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None) -> None:
def __init__(
self, api_key: Optional[str] = None, api_base: Optional[str] = None
) -> None:
super().__init__()
self.validate_environment(api_key=api_key)
self.api_base = api_base or API_BASE
@ -45,7 +49,9 @@ class BraintrustLogger(CustomLogger):
"Authorization": "Bearer " + self.api_key,
"Content-Type": "application/json",
}
self._project_id_cache: Dict[str, str] = {} # Cache mapping project names to IDs
self._project_id_cache: Dict[
str, str
] = {} # Cache mapping project names to IDs
def validate_environment(self, api_key: Optional[str]):
"""
@ -71,7 +77,9 @@ class BraintrustLogger(CustomLogger):
try:
response = global_braintrust_sync_http_handler.post(
f"{self.api_base}/project", headers=self.headers, json={"name": project_name}
f"{self.api_base}/project",
headers=self.headers,
json={"name": project_name},
)
project_dict = response.json()
project_id = project_dict["id"]
@ -89,7 +97,9 @@ class BraintrustLogger(CustomLogger):
try:
response = await global_braintrust_http_handler.post(
f"{self.api_base}/project/register", headers=self.headers, json={"name": project_name}
f"{self.api_base}/project/register",
headers=self.headers,
json={"name": project_name},
)
project_dict = response.json()
project_id = project_dict["id"]
@ -116,15 +126,21 @@ class BraintrustLogger(CustomLogger):
if metadata is None:
metadata = {}
proxy_headers = litellm_params.get("proxy_server_request", {}).get("headers", {}) or {}
proxy_headers = (
litellm_params.get("proxy_server_request", {}).get("headers", {}) or {}
)
for metadata_param_key in proxy_headers:
if metadata_param_key.startswith("braintrust"):
trace_param_key = metadata_param_key.replace("braintrust", "", 1)
if trace_param_key in metadata:
verbose_logger.warning(f"Overwriting Braintrust `{trace_param_key}` from request header")
verbose_logger.warning(
f"Overwriting Braintrust `{trace_param_key}` from request header"
)
else:
verbose_logger.debug(f"Found Braintrust `{trace_param_key}` in request header")
verbose_logger.debug(
f"Found Braintrust `{trace_param_key}` in request header"
)
metadata[trace_param_key] = proxy_headers.get(metadata_param_key)
return metadata
@ -157,24 +173,35 @@ class BraintrustLogger(CustomLogger):
output = None
choices = []
if response_obj is not None and (
kwargs.get("call_type", None) == "embedding" or isinstance(response_obj, litellm.EmbeddingResponse)
kwargs.get("call_type", None) == "embedding"
or isinstance(response_obj, litellm.EmbeddingResponse)
):
output = None
elif response_obj is not None and isinstance(response_obj, litellm.ModelResponse):
elif response_obj is not None and isinstance(
response_obj, litellm.ModelResponse
):
output = response_obj["choices"][0]["message"].json()
choices = response_obj["choices"]
elif response_obj is not None and isinstance(response_obj, litellm.TextCompletionResponse):
elif response_obj is not None and isinstance(
response_obj, litellm.TextCompletionResponse
):
output = response_obj.choices[0].text
choices = response_obj.choices
elif response_obj is not None and isinstance(response_obj, litellm.ImageResponse):
elif response_obj is not None and isinstance(
response_obj, litellm.ImageResponse
):
output = response_obj["data"]
litellm_params = kwargs.get("litellm_params", {})
metadata = litellm_params.get("metadata", {}) or {} # if litellm_params['metadata'] == None
metadata = (
litellm_params.get("metadata", {}) or {}
) # if litellm_params['metadata'] == None
metadata = self.add_metadata_from_header(litellm_params, metadata)
clean_metadata = {}
try:
metadata = copy.deepcopy(metadata) # Avoid modifying the original metadata
metadata = copy.deepcopy(
metadata
) # Avoid modifying the original metadata
except Exception:
new_metadata = {}
for key, value in metadata.items():
@ -192,7 +219,9 @@ class BraintrustLogger(CustomLogger):
project_id = metadata.get("project_id")
if project_id is None:
project_name = metadata.get("project_name")
project_id = self.get_project_id_sync(project_name) if project_name else None
project_id = (
self.get_project_id_sync(project_name) if project_name else None
)
if project_id is None:
if self.default_project_id is None:
@ -234,7 +263,8 @@ class BraintrustLogger(CustomLogger):
"completion_tokens": usage_obj.completion_tokens,
"total_tokens": usage_obj.total_tokens,
"total_cost": cost,
"time_to_first_token": end_time.timestamp() - start_time.timestamp(),
"time_to_first_token": end_time.timestamp()
- start_time.timestamp(),
"start": start_time.timestamp(),
"end": end_time.timestamp(),
}
@ -255,7 +285,9 @@ class BraintrustLogger(CustomLogger):
request_data["metrics"] = metrics
try:
print_verbose(f"global_braintrust_sync_http_handler.post: {global_braintrust_sync_http_handler.post}")
print_verbose(
f"global_braintrust_sync_http_handler.post: {global_braintrust_sync_http_handler.post}"
)
global_braintrust_sync_http_handler.post(
url=f"{self.api_base}/project_logs/{project_id}/insert",
json={"events": [request_data]},
@ -276,20 +308,29 @@ class BraintrustLogger(CustomLogger):
output = None
choices = []
if response_obj is not None and (
kwargs.get("call_type", None) == "embedding" or isinstance(response_obj, litellm.EmbeddingResponse)
kwargs.get("call_type", None) == "embedding"
or isinstance(response_obj, litellm.EmbeddingResponse)
):
output = None
elif response_obj is not None and isinstance(response_obj, litellm.ModelResponse):
elif response_obj is not None and isinstance(
response_obj, litellm.ModelResponse
):
output = response_obj["choices"][0]["message"].json()
choices = response_obj["choices"]
elif response_obj is not None and isinstance(response_obj, litellm.TextCompletionResponse):
elif response_obj is not None and isinstance(
response_obj, litellm.TextCompletionResponse
):
output = response_obj.choices[0].text
choices = response_obj.choices
elif response_obj is not None and isinstance(response_obj, litellm.ImageResponse):
elif response_obj is not None and isinstance(
response_obj, litellm.ImageResponse
):
output = response_obj["data"]
litellm_params = kwargs.get("litellm_params", {})
metadata = litellm_params.get("metadata", {}) or {} # if litellm_params['metadata'] == None
metadata = (
litellm_params.get("metadata", {}) or {}
) # if litellm_params['metadata'] == None
metadata = self.add_metadata_from_header(litellm_params, metadata)
clean_metadata = {}
new_metadata = {}
@ -313,7 +354,11 @@ class BraintrustLogger(CustomLogger):
project_id = metadata.get("project_id")
if project_id is None:
project_name = metadata.get("project_name")
project_id = await self.get_project_id_async(project_name) if project_name else None
project_id = (
await self.get_project_id_async(project_name)
if project_name
else None
)
if project_id is None:
if self.default_project_id is None:
@ -362,8 +407,14 @@ class BraintrustLogger(CustomLogger):
api_call_start_time = kwargs.get("api_call_start_time")
completion_start_time = kwargs.get("completion_start_time")
if api_call_start_time is not None and completion_start_time is not None:
metrics["time_to_first_token"] = completion_start_time.timestamp() - api_call_start_time.timestamp()
if (
api_call_start_time is not None
and completion_start_time is not None
):
metrics["time_to_first_token"] = (
completion_start_time.timestamp()
- api_call_start_time.timestamp()
)
request_data = {
"id": litellm_call_id,

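
The metrics block above derives time_to_first_token by subtracting two datetime timestamps; a tiny illustration of that arithmetic with made-up values:

from datetime import datetime, timedelta

api_call_start_time = datetime.now()
completion_start_time = api_call_start_time + timedelta(milliseconds=420)

# timestamp() is seconds since the epoch as a float, so the difference is in seconds
time_to_first_token = completion_start_time.timestamp() - api_call_start_time.timestamp()
print(round(time_to_first_token, 3))  # ~0.42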
View file

@ -14,7 +14,6 @@ from litellm.integrations.custom_logger import CustomLogger
class CustomBatchLogger(CustomLogger):
def __init__(
self,
flush_lock: Optional[asyncio.Lock] = None,

View file

@ -7,7 +7,6 @@ from litellm.types.utils import StandardLoggingGuardrailInformation
class CustomGuardrail(CustomLogger):
def __init__(
self,
guardrail_name: Optional[str] = None,

View file

@ -31,7 +31,7 @@ from litellm.types.utils import (
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
Span = Union[_Span, Any]
else:
Span = Any

View file

@ -233,7 +233,6 @@ class DataDogLogger(
pass
async def _log_async_event(self, kwargs, response_obj, start_time, end_time):
dd_payload = self.create_datadog_logging_payload(
kwargs=kwargs,
response_obj=response_obj,

View file

@ -125,9 +125,9 @@ class GCSBucketBase(CustomBatchLogger):
if kwargs is None:
kwargs = {}
standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
kwargs.get("standard_callback_dynamic_params", None)
)
standard_callback_dynamic_params: Optional[
StandardCallbackDynamicParams
] = kwargs.get("standard_callback_dynamic_params", None)
bucket_name: str
path_service_account: Optional[str]

View file

@ -70,12 +70,13 @@ class GcsPubSubLogger(CustomBatchLogger):
"""Construct authorization headers using Vertex AI auth"""
from litellm import vertex_chat_completion
_auth_header, vertex_project = (
await vertex_chat_completion._ensure_access_token_async(
credentials=self.path_service_account_json,
project_id=None,
custom_llm_provider="vertex_ai",
)
(
_auth_header,
vertex_project,
) = await vertex_chat_completion._ensure_access_token_async(
credentials=self.path_service_account_json,
project_id=None,
custom_llm_provider="vertex_ai",
)
auth_header, _ = vertex_chat_completion._get_token_and_url(

View file

@ -155,11 +155,7 @@ class HumanloopLogger(CustomLogger):
prompt_id: str,
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
) -> Tuple[
str,
List[AllMessageValues],
dict,
]:
) -> Tuple[str, List[AllMessageValues], dict,]:
humanloop_api_key = dynamic_callback_params.get(
"humanloop_api_key"
) or get_secret_str("HUMANLOOP_API_KEY")

View file

@ -471,9 +471,9 @@ class LangFuseLogger:
# we clean out all extra litellm metadata params before logging
clean_metadata: Dict[str, Any] = {}
if prompt_management_metadata is not None:
clean_metadata["prompt_management_metadata"] = (
prompt_management_metadata
)
clean_metadata[
"prompt_management_metadata"
] = prompt_management_metadata
if isinstance(metadata, dict):
for key, value in metadata.items():
# generate langfuse tags - Default Tags sent to Langfuse from LiteLLM Proxy

View file

@ -19,7 +19,6 @@ else:
class LangFuseHandler:
@staticmethod
def get_langfuse_logger_for_request(
standard_callback_dynamic_params: StandardCallbackDynamicParams,
@ -87,7 +86,9 @@ class LangFuseHandler:
if globalLangfuseLogger is not None:
return globalLangfuseLogger
credentials_dict: Dict[str, Any] = (
credentials_dict: Dict[
str, Any
] = (
{}
) # the global langfuse logger uses Environment Variables, there are no dynamic credentials
globalLangfuseLogger = in_memory_dynamic_logger_cache.get_cache(

View file

@ -172,11 +172,7 @@ class LangfusePromptManagement(LangFuseLogger, PromptManagementBase, CustomLogge
prompt_id: str,
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
) -> Tuple[
str,
List[AllMessageValues],
dict,
]:
) -> Tuple[str, List[AllMessageValues], dict,]:
return self.get_chat_completion_prompt(
model,
messages,

View file

@ -75,7 +75,6 @@ class LangsmithLogger(CustomBatchLogger):
langsmith_project: Optional[str] = None,
langsmith_base_url: Optional[str] = None,
) -> LangsmithCredentialsObject:
_credentials_api_key = langsmith_api_key or os.getenv("LANGSMITH_API_KEY")
if _credentials_api_key is None:
raise Exception(
@ -443,9 +442,9 @@ class LangsmithLogger(CustomBatchLogger):
Otherwise, use the default credentials.
"""
standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
kwargs.get("standard_callback_dynamic_params", None)
)
standard_callback_dynamic_params: Optional[
StandardCallbackDynamicParams
] = kwargs.get("standard_callback_dynamic_params", None)
if standard_callback_dynamic_params is not None:
credentials = self.get_credentials_from_env(
langsmith_api_key=standard_callback_dynamic_params.get(
@ -481,7 +480,6 @@ class LangsmithLogger(CustomBatchLogger):
asyncio.run(self.async_send_batch())
def get_run_by_id(self, run_id):
langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"]
langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"]

View file

@ -1,12 +1,12 @@
import json
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Union
from litellm.proxy._types import SpanAttributes
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
Span = Union[_Span, Any]
else:
Span = Any

View file

@ -20,7 +20,6 @@ def parse_tool_calls(tool_calls):
return None
def clean_tool_call(tool_call):
serialized = {
"type": tool_call.type,
"id": tool_call.id,
@ -36,7 +35,6 @@ def parse_tool_calls(tool_calls):
def parse_messages(input):
if input is None:
return None

View file

@ -48,14 +48,17 @@ class MlflowLogger(CustomLogger):
def _extract_and_set_chat_attributes(self, span, kwargs, response_obj):
try:
from mlflow.tracing.utils import set_span_chat_messages, set_span_chat_tools
from mlflow.tracing.utils import set_span_chat_messages # type: ignore
from mlflow.tracing.utils import set_span_chat_tools # type: ignore
except ImportError:
return
inputs = self._construct_input(kwargs)
input_messages = inputs.get("messages", [])
output_messages = [c.message.model_dump(exclude_none=True)
for c in getattr(response_obj, "choices", [])]
output_messages = [
c.message.model_dump(exclude_none=True)
for c in getattr(response_obj, "choices", [])
]
if messages := [*input_messages, *output_messages]:
set_span_chat_messages(span, messages)
if tools := inputs.get("tools"):

View file

@ -1,7 +1,7 @@
import os
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
import litellm
from litellm._logging import verbose_logger
@ -23,10 +23,10 @@ if TYPE_CHECKING:
)
from litellm.proxy.proxy_server import UserAPIKeyAuth as _UserAPIKeyAuth
Span = _Span
SpanExporter = _SpanExporter
UserAPIKeyAuth = _UserAPIKeyAuth
ManagementEndpointLoggingPayload = _ManagementEndpointLoggingPayload
Span = Union[_Span, Any]
SpanExporter = Union[_SpanExporter, Any]
UserAPIKeyAuth = Union[_UserAPIKeyAuth, Any]
ManagementEndpointLoggingPayload = Union[_ManagementEndpointLoggingPayload, Any]
else:
Span = Any
SpanExporter = Any
@ -46,7 +46,6 @@ LITELLM_REQUEST_SPAN_NAME = "litellm_request"
@dataclass
class OpenTelemetryConfig:
exporter: Union[str, SpanExporter] = "console"
endpoint: Optional[str] = None
headers: Optional[str] = None
@ -154,7 +153,6 @@ class OpenTelemetry(CustomLogger):
end_time: Optional[Union[datetime, float]] = None,
event_metadata: Optional[dict] = None,
):
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode
@ -215,7 +213,6 @@ class OpenTelemetry(CustomLogger):
end_time: Optional[Union[float, datetime]] = None,
event_metadata: Optional[dict] = None,
):
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode
@ -353,9 +350,9 @@ class OpenTelemetry(CustomLogger):
"""
from opentelemetry import trace
standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
kwargs.get("standard_callback_dynamic_params")
)
standard_callback_dynamic_params: Optional[
StandardCallbackDynamicParams
] = kwargs.get("standard_callback_dynamic_params")
if not standard_callback_dynamic_params:
return
@ -722,7 +719,6 @@ class OpenTelemetry(CustomLogger):
span.set_attribute(key, primitive_value)
def set_raw_request_attributes(self, span: Span, kwargs, response_obj):
kwargs.get("optional_params", {})
litellm_params = kwargs.get("litellm_params", {}) or {}
custom_llm_provider = litellm_params.get("custom_llm_provider", "Unknown")
@ -843,12 +839,14 @@ class OpenTelemetry(CustomLogger):
headers=dynamic_headers or self.OTEL_HEADERS
)
if isinstance(self.OTEL_EXPORTER, SpanExporter):
if hasattr(
self.OTEL_EXPORTER, "export"
): # Check if it has the export method that SpanExporter requires
verbose_logger.debug(
"OpenTelemetry: intiializing SpanExporter. Value of OTEL_EXPORTER: %s",
self.OTEL_EXPORTER,
)
return SimpleSpanProcessor(self.OTEL_EXPORTER)
return SimpleSpanProcessor(cast(SpanExporter, self.OTEL_EXPORTER))
if self.OTEL_EXPORTER == "console":
verbose_logger.debug(
@ -907,7 +905,6 @@ class OpenTelemetry(CustomLogger):
logging_payload: ManagementEndpointLoggingPayload,
parent_otel_span: Optional[Span] = None,
):
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode
@ -961,7 +958,6 @@ class OpenTelemetry(CustomLogger):
logging_payload: ManagementEndpointLoggingPayload,
parent_otel_span: Optional[Span] = None,
):
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode
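
Earlier in this file's diff, isinstance(self.OTEL_EXPORTER, SpanExporter) is replaced with a hasattr(self.OTEL_EXPORTER, "export") check plus cast(SpanExporter, ...), presumably because SpanExporter is only a real class at type-checking time here and falls back to Any at runtime. A minimal sketch of that duck-typing-plus-cast combination, with illustrative names:

from typing import Any, cast

class SpanExporter:
    """Stand-in for the real exporter interface; illustration only."""
    def export(self, spans: list) -> None:
        print(f"exporting {len(spans)} spans")

def build_processor(exporter: Any) -> SpanExporter:
    # Duck-typed runtime check instead of isinstance against a type that may be Any.
    if hasattr(exporter, "export"):
        return cast(SpanExporter, exporter)  # cast() only narrows the type for mypy
    raise ValueError("exporter must provide an export() method")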

View file

@ -185,7 +185,6 @@ class OpikLogger(CustomBatchLogger):
def _create_opik_payload( # noqa: PLR0915
self, kwargs, response_obj, start_time, end_time
) -> List[Dict]:
# Get metadata
_litellm_params = kwargs.get("litellm_params", {}) or {}
litellm_params_metadata = _litellm_params.get("metadata", {}) or {}

View file

@ -988,9 +988,9 @@ class PrometheusLogger(CustomLogger):
):
try:
verbose_logger.debug("setting remaining tokens requests metric")
standard_logging_payload: Optional[StandardLoggingPayload] = (
request_kwargs.get("standard_logging_object")
)
standard_logging_payload: Optional[
StandardLoggingPayload
] = request_kwargs.get("standard_logging_object")
if standard_logging_payload is None:
return

View file

@ -14,7 +14,6 @@ class PromptManagementClient(TypedDict):
class PromptManagementBase(ABC):
@property
@abstractmethod
def integration_name(self) -> str:
@ -83,11 +82,7 @@ class PromptManagementBase(ABC):
prompt_id: str,
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
) -> Tuple[
str,
List[AllMessageValues],
dict,
]:
) -> Tuple[str, List[AllMessageValues], dict,]:
if not self.should_run_prompt_management(
prompt_id=prompt_id, dynamic_callback_params=dynamic_callback_params
):

View file

@ -38,7 +38,7 @@ class S3Logger:
if litellm.s3_callback_params is not None:
# read in .env variables - example os.environ/AWS_BUCKET_NAME
for key, value in litellm.s3_callback_params.items():
if type(value) is str and value.startswith("os.environ/"):
if isinstance(value, str) and value.startswith("os.environ/"):
litellm.s3_callback_params[key] = litellm.get_secret(value)
# now set s3 params from litellm.s3_logger_params
s3_bucket_name = litellm.s3_callback_params.get("s3_bucket_name")

View file

@ -21,11 +21,11 @@ try:
# contains a (known) object attribute
object: Literal["chat.completion", "edit", "text_completion"]
def __getitem__(self, key: K) -> V: ... # noqa
def __getitem__(self, key: K) -> V:
... # noqa
def get( # noqa
self, key: K, default: Optional[V] = None
) -> Optional[V]: ... # pragma: no cover
def get(self, key: K, default: Optional[V] = None) -> Optional[V]: # noqa
... # pragma: no cover
class OpenAIRequestResponseResolver:
def __call__(

View file

@ -10,7 +10,7 @@ from litellm.types.llms.openai import AllMessageValues
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
Span = Union[_Span, Any]
else:
Span = Any

View file

@ -11,7 +11,9 @@ except (ImportError, AttributeError):
# Old way to access resources, which setuptools deprecated some time ago
import pkg_resources # type: ignore
filename = pkg_resources.resource_filename(__name__, "litellm_core_utils/tokenizers")
filename = pkg_resources.resource_filename(
__name__, "litellm_core_utils/tokenizers"
)
os.environ["TIKTOKEN_CACHE_DIR"] = os.getenv(
"CUSTOM_TIKTOKEN_CACHE_DIR", filename

View file

@ -239,9 +239,9 @@ class Logging(LiteLLMLoggingBaseClass):
self.litellm_trace_id = litellm_trace_id
self.function_id = function_id
self.streaming_chunks: List[Any] = [] # for generating complete stream response
self.sync_streaming_chunks: List[Any] = (
[]
) # for generating complete stream response
self.sync_streaming_chunks: List[
Any
] = [] # for generating complete stream response
self.log_raw_request_response = log_raw_request_response
# Initialize dynamic callbacks
@ -452,18 +452,19 @@ class Logging(LiteLLMLoggingBaseClass):
prompt_id: str,
prompt_variables: Optional[dict],
) -> Tuple[str, List[AllMessageValues], dict]:
custom_logger = self.get_custom_logger_for_prompt_management(model)
if custom_logger:
model, messages, non_default_params = (
custom_logger.get_chat_completion_prompt(
model=model,
messages=messages,
non_default_params=non_default_params,
prompt_id=prompt_id,
prompt_variables=prompt_variables,
dynamic_callback_params=self.standard_callback_dynamic_params,
)
(
model,
messages,
non_default_params,
) = custom_logger.get_chat_completion_prompt(
model=model,
messages=messages,
non_default_params=non_default_params,
prompt_id=prompt_id,
prompt_variables=prompt_variables,
dynamic_callback_params=self.standard_callback_dynamic_params,
)
self.messages = messages
return model, messages, non_default_params
@ -541,12 +542,11 @@ class Logging(LiteLLMLoggingBaseClass):
model
): # if model name was changes pre-call, overwrite the initial model call name with the new one
self.model_call_details["model"] = model
self.model_call_details["litellm_params"]["api_base"] = (
self._get_masked_api_base(additional_args.get("api_base", ""))
)
self.model_call_details["litellm_params"][
"api_base"
] = self._get_masked_api_base(additional_args.get("api_base", ""))
def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR0915
# Log the exact input to the LLM API
litellm.error_logs["PRE_CALL"] = locals()
try:
@ -568,19 +568,16 @@ class Logging(LiteLLMLoggingBaseClass):
self.log_raw_request_response is True
or log_raw_request_response is True
):
_litellm_params = self.model_call_details.get("litellm_params", {})
_metadata = _litellm_params.get("metadata", {}) or {}
try:
# [Non-blocking Extra Debug Information in metadata]
if turn_off_message_logging is True:
_metadata["raw_request"] = (
"redacted by litellm. \
_metadata[
"raw_request"
] = "redacted by litellm. \
'litellm.turn_off_message_logging=True'"
)
else:
curl_command = self._get_request_curl_command(
api_base=additional_args.get("api_base", ""),
headers=additional_args.get("headers", {}),
@ -590,33 +587,33 @@ class Logging(LiteLLMLoggingBaseClass):
_metadata["raw_request"] = str(curl_command)
# split up, so it's easier to parse in the UI
self.model_call_details["raw_request_typed_dict"] = (
RawRequestTypedDict(
raw_request_api_base=str(
additional_args.get("api_base") or ""
),
raw_request_body=self._get_raw_request_body(
additional_args.get("complete_input_dict", {})
),
raw_request_headers=self._get_masked_headers(
additional_args.get("headers", {}) or {},
ignore_sensitive_headers=True,
),
error=None,
)
self.model_call_details[
"raw_request_typed_dict"
] = RawRequestTypedDict(
raw_request_api_base=str(
additional_args.get("api_base") or ""
),
raw_request_body=self._get_raw_request_body(
additional_args.get("complete_input_dict", {})
),
raw_request_headers=self._get_masked_headers(
additional_args.get("headers", {}) or {},
ignore_sensitive_headers=True,
),
error=None,
)
except Exception as e:
self.model_call_details["raw_request_typed_dict"] = (
RawRequestTypedDict(
error=str(e),
)
self.model_call_details[
"raw_request_typed_dict"
] = RawRequestTypedDict(
error=str(e),
)
traceback.print_exc()
_metadata["raw_request"] = (
"Unable to Log \
_metadata[
"raw_request"
] = "Unable to Log \
raw request: {}".format(
str(e)
)
str(e)
)
if self.logger_fn and callable(self.logger_fn):
try:
@ -941,9 +938,9 @@ class Logging(LiteLLMLoggingBaseClass):
verbose_logger.debug(
f"response_cost_failure_debug_information: {debug_info}"
)
self.model_call_details["response_cost_failure_debug_information"] = (
debug_info
)
self.model_call_details[
"response_cost_failure_debug_information"
] = debug_info
return None
try:
@ -968,9 +965,9 @@ class Logging(LiteLLMLoggingBaseClass):
verbose_logger.debug(
f"response_cost_failure_debug_information: {debug_info}"
)
self.model_call_details["response_cost_failure_debug_information"] = (
debug_info
)
self.model_call_details[
"response_cost_failure_debug_information"
] = debug_info
return None
@ -995,7 +992,6 @@ class Logging(LiteLLMLoggingBaseClass):
def should_run_callback(
self, callback: litellm.CALLBACK_TYPES, litellm_params: dict, event_hook: str
) -> bool:
if litellm.global_disable_no_log_param:
return True
@ -1027,9 +1023,9 @@ class Logging(LiteLLMLoggingBaseClass):
end_time = datetime.datetime.now()
if self.completion_start_time is None:
self.completion_start_time = end_time
self.model_call_details["completion_start_time"] = (
self.completion_start_time
)
self.model_call_details[
"completion_start_time"
] = self.completion_start_time
self.model_call_details["log_event_type"] = "successful_api_call"
self.model_call_details["end_time"] = end_time
self.model_call_details["cache_hit"] = cache_hit
@ -1083,39 +1079,39 @@ class Logging(LiteLLMLoggingBaseClass):
"response_cost"
]
else:
self.model_call_details["response_cost"] = (
self._response_cost_calculator(result=result)
)
self.model_call_details[
"response_cost"
] = self._response_cost_calculator(result=result)
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=result,
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="success",
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
self.model_call_details[
"standard_logging_object"
] = get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=result,
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="success",
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
elif isinstance(result, dict): # pass-through endpoints
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=result,
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="success",
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
self.model_call_details[
"standard_logging_object"
] = get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=result,
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="success",
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
elif standard_logging_object is not None:
self.model_call_details["standard_logging_object"] = (
standard_logging_object
)
self.model_call_details[
"standard_logging_object"
] = standard_logging_object
else: # streaming chunks + image gen.
self.model_call_details["response_cost"] = None
@ -1154,7 +1150,6 @@ class Logging(LiteLLMLoggingBaseClass):
standard_logging_object=kwargs.get("standard_logging_object", None),
)
try:
## BUILD COMPLETE STREAMED RESPONSE
complete_streaming_response: Optional[
Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse]
@ -1172,23 +1167,23 @@ class Logging(LiteLLMLoggingBaseClass):
verbose_logger.debug(
"Logging Details LiteLLM-Success Call streaming complete"
)
self.model_call_details["complete_streaming_response"] = (
complete_streaming_response
)
self.model_call_details["response_cost"] = (
self._response_cost_calculator(result=complete_streaming_response)
)
self.model_call_details[
"complete_streaming_response"
] = complete_streaming_response
self.model_call_details[
"response_cost"
] = self._response_cost_calculator(result=complete_streaming_response)
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=complete_streaming_response,
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="success",
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
self.model_call_details[
"standard_logging_object"
] = get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=complete_streaming_response,
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="success",
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
callbacks = self.get_combined_callback_list(
dynamic_success_callbacks=self.dynamic_success_callbacks,
@ -1207,7 +1202,6 @@ class Logging(LiteLLMLoggingBaseClass):
## LOGGING HOOK ##
for callback in callbacks:
if isinstance(callback, CustomLogger):
self.model_call_details, result = callback.logging_hook(
kwargs=self.model_call_details,
result=result,
@ -1538,10 +1532,10 @@ class Logging(LiteLLMLoggingBaseClass):
)
else:
if self.stream and complete_streaming_response:
self.model_call_details["complete_response"] = (
self.model_call_details.get(
"complete_streaming_response", {}
)
self.model_call_details[
"complete_response"
] = self.model_call_details.get(
"complete_streaming_response", {}
)
result = self.model_call_details["complete_response"]
openMeterLogger.log_success_event(
@ -1581,10 +1575,10 @@ class Logging(LiteLLMLoggingBaseClass):
)
else:
if self.stream and complete_streaming_response:
self.model_call_details["complete_response"] = (
self.model_call_details.get(
"complete_streaming_response", {}
)
self.model_call_details[
"complete_response"
] = self.model_call_details.get(
"complete_streaming_response", {}
)
result = self.model_call_details["complete_response"]
@ -1659,7 +1653,6 @@ class Logging(LiteLLMLoggingBaseClass):
if self.call_type == CallTypes.aretrieve_batch.value and isinstance(
result, LiteLLMBatch
):
response_cost, batch_usage, batch_models = await _handle_completed_batch(
batch=result, custom_llm_provider=self.custom_llm_provider
)
@ -1692,9 +1685,9 @@ class Logging(LiteLLMLoggingBaseClass):
if complete_streaming_response is not None:
print_verbose("Async success callbacks: Got a complete streaming response")
self.model_call_details["async_complete_streaming_response"] = (
complete_streaming_response
)
self.model_call_details[
"async_complete_streaming_response"
] = complete_streaming_response
try:
if self.model_call_details.get("cache_hit", False) is True:
self.model_call_details["response_cost"] = 0.0
@ -1704,10 +1697,10 @@ class Logging(LiteLLMLoggingBaseClass):
model_call_details=self.model_call_details
)
# base_model defaults to None if not set on model_info
self.model_call_details["response_cost"] = (
self._response_cost_calculator(
result=complete_streaming_response
)
self.model_call_details[
"response_cost"
] = self._response_cost_calculator(
result=complete_streaming_response
)
verbose_logger.debug(
@ -1720,16 +1713,16 @@ class Logging(LiteLLMLoggingBaseClass):
self.model_call_details["response_cost"] = None
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=complete_streaming_response,
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="success",
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
self.model_call_details[
"standard_logging_object"
] = get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=complete_streaming_response,
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="success",
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
callbacks = self.get_combined_callback_list(
dynamic_success_callbacks=self.dynamic_async_success_callbacks,
@ -1935,18 +1928,18 @@ class Logging(LiteLLMLoggingBaseClass):
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj={},
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="failure",
error_str=str(exception),
original_exception=exception,
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
self.model_call_details[
"standard_logging_object"
] = get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj={},
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="failure",
error_str=str(exception),
original_exception=exception,
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
return start_time, end_time
@ -2084,7 +2077,6 @@ class Logging(LiteLLMLoggingBaseClass):
)
is not True
): # custom logger class
callback.log_failure_event(
start_time=start_time,
end_time=end_time,
@ -2713,9 +2705,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
endpoint=arize_config.endpoint,
)
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
f"space_key={arize_config.space_key},api_key={arize_config.api_key}"
)
os.environ[
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
] = f"space_key={arize_config.space_key},api_key={arize_config.api_key}"
for callback in _in_memory_loggers:
if (
isinstance(callback, ArizeLogger)
@ -2739,9 +2731,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
# auth can be disabled on local deployments of arize phoenix
if arize_phoenix_config.otlp_auth_headers is not None:
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
arize_phoenix_config.otlp_auth_headers
)
os.environ[
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
] = arize_phoenix_config.otlp_auth_headers
for callback in _in_memory_loggers:
if (
@ -2832,9 +2824,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
exporter="otlp_http",
endpoint="https://langtrace.ai/api/trace",
)
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
f"api_key={os.getenv('LANGTRACE_API_KEY')}"
)
os.environ[
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
] = f"api_key={os.getenv('LANGTRACE_API_KEY')}"
for callback in _in_memory_loggers:
if (
isinstance(callback, OpenTelemetry)
@ -3223,7 +3215,6 @@ class StandardLoggingPayloadSetup:
custom_llm_provider: Optional[str],
init_response_obj: Union[Any, BaseModel, dict],
) -> StandardLoggingModelInformation:
model_cost_name = _select_model_name_for_cost_calc(
model=None,
completion_response=init_response_obj, # type: ignore
@ -3286,7 +3277,6 @@ class StandardLoggingPayloadSetup:
def get_additional_headers(
additiona_headers: Optional[dict],
) -> Optional[StandardLoggingAdditionalHeaders]:
if additiona_headers is None:
return None
@ -3322,10 +3312,10 @@ class StandardLoggingPayloadSetup:
for key in StandardLoggingHiddenParams.__annotations__.keys():
if key in hidden_params:
if key == "additional_headers":
clean_hidden_params["additional_headers"] = (
StandardLoggingPayloadSetup.get_additional_headers(
hidden_params[key]
)
clean_hidden_params[
"additional_headers"
] = StandardLoggingPayloadSetup.get_additional_headers(
hidden_params[key]
)
else:
clean_hidden_params[key] = hidden_params[key] # type: ignore
@ -3463,12 +3453,14 @@ def get_standard_logging_object_payload(
)
# cleanup timestamps
start_time_float, end_time_float, completion_start_time_float = (
StandardLoggingPayloadSetup.cleanup_timestamps(
start_time=start_time,
end_time=end_time,
completion_start_time=completion_start_time,
)
(
start_time_float,
end_time_float,
completion_start_time_float,
) = StandardLoggingPayloadSetup.cleanup_timestamps(
start_time=start_time,
end_time=end_time,
completion_start_time=completion_start_time,
)
response_time = StandardLoggingPayloadSetup.get_response_time(
start_time_float=start_time_float,
@ -3495,7 +3487,6 @@ def get_standard_logging_object_payload(
saved_cache_cost: float = 0.0
if cache_hit is True:
id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id
saved_cache_cost = (
logging_obj._response_cost_calculator(
@ -3658,9 +3649,9 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
):
for k, v in metadata["user_api_key_metadata"].items():
if k == "logging": # prevent logging user logging keys
cleaned_user_api_key_metadata[k] = (
"scrubbed_by_litellm_for_sensitive_keys"
)
cleaned_user_api_key_metadata[
k
] = "scrubbed_by_litellm_for_sensitive_keys"
else:
cleaned_user_api_key_metadata[k] = v
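The hunk above only reflows the assignment; the scrubbing behaviour itself is unchanged: user-supplied `logging` settings in key metadata are replaced with a sentinel string before anything is logged. A minimal standalone sketch of that pattern (function and key names here are illustrative, not the exact litellm implementation):

```python
def scrub_user_api_key_metadata(metadata: dict) -> dict:
    """Return a copy of metadata with user-defined logging config redacted."""
    cleaned: dict = {}
    for key, value in metadata.items():
        if key == "logging":
            # never forward user logging settings (they may embed callback secrets)
            cleaned[key] = "scrubbed_by_litellm_for_sensitive_keys"
        else:
            cleaned[key] = value
    return cleaned


print(scrub_user_api_key_metadata({"logging": {"callback": "langfuse"}, "team": "eng"}))
# {'logging': 'scrubbed_by_litellm_for_sensitive_keys', 'team': 'eng'}
```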

View file

@ -258,14 +258,12 @@ def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[s
class LiteLLMResponseObjectHandler:
@staticmethod
def convert_to_image_response(
response_object: dict,
model_response_object: Optional[ImageResponse] = None,
hidden_params: Optional[dict] = None,
) -> ImageResponse:
response_object.update({"hidden_params": hidden_params})
if model_response_object is None:
@ -481,9 +479,9 @@ def convert_to_model_response_object( # noqa: PLR0915
provider_specific_fields["thinking_blocks"] = thinking_blocks
if reasoning_content:
provider_specific_fields["reasoning_content"] = (
reasoning_content
)
provider_specific_fields[
"reasoning_content"
] = reasoning_content
message = Message(
content=content,

View file

@ -17,7 +17,6 @@ from litellm.types.rerank import RerankRequest
class ModelParamHelper:
@staticmethod
def get_standard_logging_model_parameters(
model_parameters: dict,

View file

@ -257,7 +257,6 @@ def _insert_assistant_continue_message(
and message.get("role") == "user" # Current is user
and messages[i + 1].get("role") == "user"
): # Next is user
# Insert assistant message
continue_message = (
assistant_continue_message or DEFAULT_ASSISTANT_CONTINUE_MESSAGE
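For context on the hunk above: Anthropic-style message APIs reject two consecutive user turns, so an assistant "continue" message is spliced in between them. A rough standalone sketch of that idea; the message text and helper name below are assumptions, not litellm's exact values:

```python
DEFAULT_ASSISTANT_CONTINUE = {"role": "assistant", "content": "Please continue."}


def insert_assistant_continue_messages(messages: list[dict]) -> list[dict]:
    """Insert an assistant turn between any two consecutive user turns."""
    fixed: list[dict] = []
    for i, message in enumerate(messages):
        fixed.append(message)
        next_is_user = i + 1 < len(messages) and messages[i + 1].get("role") == "user"
        if message.get("role") == "user" and next_is_user:
            # Insert assistant message so the provider sees alternating roles
            fixed.append(DEFAULT_ASSISTANT_CONTINUE)
    return fixed


msgs = [{"role": "user", "content": "hi"}, {"role": "user", "content": "are you there?"}]
assert [m["role"] for m in insert_assistant_continue_messages(msgs)] == ["user", "assistant", "user"]
```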

View file

@ -1042,10 +1042,10 @@ def convert_to_gemini_tool_call_invoke(
if tool_calls is not None:
for tool in tool_calls:
if "function" in tool:
gemini_function_call: Optional[VertexFunctionCall] = (
_gemini_tool_call_invoke_helper(
function_call_params=tool["function"]
)
gemini_function_call: Optional[
VertexFunctionCall
] = _gemini_tool_call_invoke_helper(
function_call_params=tool["function"]
)
if gemini_function_call is not None:
_parts_list.append(
@ -1432,9 +1432,9 @@ def anthropic_messages_pt( # noqa: PLR0915
)
if "cache_control" in _content_element:
_anthropic_content_element["cache_control"] = (
_content_element["cache_control"]
)
_anthropic_content_element[
"cache_control"
] = _content_element["cache_control"]
user_content.append(_anthropic_content_element)
elif m.get("type", "") == "text":
m = cast(ChatCompletionTextObject, m)
@ -1466,9 +1466,9 @@ def anthropic_messages_pt( # noqa: PLR0915
)
if "cache_control" in _content_element:
_anthropic_content_text_element["cache_control"] = (
_content_element["cache_control"]
)
_anthropic_content_text_element[
"cache_control"
] = _content_element["cache_control"]
user_content.append(_anthropic_content_text_element)
@ -1533,7 +1533,6 @@ def anthropic_messages_pt( # noqa: PLR0915
"content"
] # don't pass empty text blocks. anthropic api raises errors.
):
_anthropic_text_content_element = AnthropicMessagesTextParam(
type="text",
text=assistant_content_block["content"],
@ -1569,7 +1568,6 @@ def anthropic_messages_pt( # noqa: PLR0915
msg_i += 1
if assistant_content:
new_messages.append({"role": "assistant", "content": assistant_content})
if msg_i == init_msg_i: # prevent infinite loops
@ -2245,7 +2243,6 @@ class BedrockImageProcessor:
@staticmethod
async def get_image_details_async(image_url) -> Tuple[str, str]:
try:
client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.PromptFactory,
params={"concurrent_limit": 1},
@ -2612,7 +2609,6 @@ def get_user_message_block_or_continue_message(
for item in modified_content_block:
# Check if the list is empty
if item["type"] == "text":
if not item["text"].strip():
# Replace empty text with continue message
_user_continue_message = ChatCompletionUserMessage(
@ -3207,7 +3203,6 @@ def _bedrock_converse_messages_pt( # noqa: PLR0915
assistant_content: List[BedrockContentBlock] = []
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
assistant_message_block = get_assistant_message_block_or_continue_message(
message=messages[msg_i],
assistant_continue_message=assistant_continue_message,
@ -3410,7 +3405,6 @@ def response_schema_prompt(model: str, response_schema: dict) -> str:
{"role": "user", "content": "{}".format(response_schema)}
]
if f"{model}/response_schema_prompt" in litellm.custom_prompt_dict:
custom_prompt_details = litellm.custom_prompt_dict[
f"{model}/response_schema_prompt"
] # allow user to define custom response schema prompt by model

View file

@ -122,7 +122,6 @@ class RealTimeStreaming:
pass
async def bidirectional_forward(self):
forward_task = asyncio.create_task(self.backend_to_client_send_messages())
try:
await self.client_ack_messages()

View file

@ -135,9 +135,9 @@ def _get_turn_off_message_logging_from_dynamic_params(
handles boolean and string values of `turn_off_message_logging`
"""
standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
model_call_details.get("standard_callback_dynamic_params", None)
)
standard_callback_dynamic_params: Optional[
StandardCallbackDynamicParams
] = model_call_details.get("standard_callback_dynamic_params", None)
if standard_callback_dynamic_params:
_turn_off_message_logging = standard_callback_dynamic_params.get(
"turn_off_message_logging"

View file

@ -1,4 +1,5 @@
from typing import Any, Dict, Optional, Set
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
@ -40,7 +41,10 @@ class SensitiveDataMasker:
return result
def mask_dict(
self, data: Dict[str, Any], depth: int = 0, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH
self,
data: Dict[str, Any],
depth: int = 0,
max_depth: int = DEFAULT_MAX_RECURSE_DEPTH,
) -> Dict[str, Any]:
if depth >= max_depth:
return data
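The signature reflow above keeps the same recursion guard: once `depth >= max_depth`, the remaining nested data is returned unmasked instead of being recursed into. A small illustrative sketch of that behaviour, independent of the real `SensitiveDataMasker` (the masking rule and key list are assumptions):

```python
from typing import Any, Dict

SENSITIVE_KEYS = {"api_key", "password", "token"}


def mask_dict(data: Dict[str, Any], depth: int = 0, max_depth: int = 10) -> Dict[str, Any]:
    if depth >= max_depth:
        # stop recursing; values deeper than the limit pass through untouched
        return data
    result: Dict[str, Any] = {}
    for key, value in data.items():
        if isinstance(value, dict):
            result[key] = mask_dict(value, depth + 1, max_depth)
        elif key in SENSITIVE_KEYS and isinstance(value, str):
            result[key] = value[:2] + "****"  # crude redaction for the sketch
        else:
            result[key] = value
    return result


print(mask_dict({"api_key": "sk-12345", "nested": {"password": "hunter2"}}))
```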

View file

@ -104,7 +104,6 @@ class ChunkProcessor:
def get_combined_tool_content(
self, tool_call_chunks: List[Dict[str, Any]]
) -> List[ChatCompletionMessageToolCall]:
argument_list: List[str] = []
delta = tool_call_chunks[0]["choices"][0]["delta"]
id = None
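The method above stitches streamed tool-call deltas back into complete tool calls; each chunk usually carries only a fragment of the JSON `arguments` string, so the fragments are accumulated and joined. A simplified sketch of that accumulation step (the chunk shape is modelled on OpenAI-style deltas and is an assumption here, not litellm's exact types):

```python
import json


def combine_tool_call_arguments(tool_call_chunks: list[dict]) -> dict:
    """Accumulate streamed argument fragments into one complete tool call."""
    name = None
    call_id = None
    argument_list: list[str] = []
    for chunk in tool_call_chunks:
        delta = chunk["choices"][0]["delta"]
        for tool_call in delta.get("tool_calls") or []:
            call_id = call_id or tool_call.get("id")
            fn = tool_call.get("function", {})
            name = name or fn.get("name")
            argument_list.append(fn.get("arguments") or "")
    return {"id": call_id, "name": name, "arguments": json.loads("".join(argument_list))}


chunks = [
    {"choices": [{"delta": {"tool_calls": [{"id": "call_1", "function": {"name": "get_weather", "arguments": '{"city": '}}]}}]},
    {"choices": [{"delta": {"tool_calls": [{"function": {"arguments": '"Paris"}'}}]}}]},
]
print(combine_tool_call_arguments(chunks))
# {'id': 'call_1', 'name': 'get_weather', 'arguments': {'city': 'Paris'}}
```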

View file

@ -84,9 +84,9 @@ class CustomStreamWrapper:
self.system_fingerprint: Optional[str] = None
self.received_finish_reason: Optional[str] = None
self.intermittent_finish_reason: Optional[str] = (
None # finish reasons that show up mid-stream
)
self.intermittent_finish_reason: Optional[
str
] = None # finish reasons that show up mid-stream
self.special_tokens = [
"<|assistant|>",
"<|system|>",
@ -814,7 +814,6 @@ class CustomStreamWrapper:
model_response: ModelResponseStream,
response_obj: Dict[str, Any],
):
print_verbose(
f"completion_obj: {completion_obj}, model_response.choices[0]: {model_response.choices[0]}, response_obj: {response_obj}"
)
@ -1008,7 +1007,6 @@ class CustomStreamWrapper:
self.custom_llm_provider
and self.custom_llm_provider in litellm._custom_providers
):
if self.received_finish_reason is not None:
if "provider_specific_fields" not in chunk:
raise StopIteration
@ -1379,9 +1377,9 @@ class CustomStreamWrapper:
_json_delta = delta.model_dump()
print_verbose(f"_json_delta: {_json_delta}")
if "role" not in _json_delta or _json_delta["role"] is None:
_json_delta["role"] = (
"assistant" # mistral's api returns role as None
)
_json_delta[
"role"
] = "assistant" # mistral's api returns role as None
if "tool_calls" in _json_delta and isinstance(
_json_delta["tool_calls"], list
):
@ -1758,9 +1756,9 @@ class CustomStreamWrapper:
chunk = next(self.completion_stream)
if chunk is not None and chunk != b"":
print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}")
processed_chunk: Optional[ModelResponseStream] = (
self.chunk_creator(chunk=chunk)
)
processed_chunk: Optional[
ModelResponseStream
] = self.chunk_creator(chunk=chunk)
print_verbose(
f"PROCESSED CHUNK POST CHUNK CREATOR: {processed_chunk}"
)

View file

@ -290,7 +290,6 @@ class AnthropicChatCompletion(BaseLLM):
headers={},
client=None,
):
optional_params = copy.deepcopy(optional_params)
stream = optional_params.pop("stream", None)
json_mode: bool = optional_params.pop("json_mode", False)
@ -491,7 +490,6 @@ class ModelResponseIterator:
def _handle_usage(
self, anthropic_usage_chunk: Union[dict, UsageDelta]
) -> AnthropicChatCompletionUsageBlock:
usage_block = AnthropicChatCompletionUsageBlock(
prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0),
completion_tokens=anthropic_usage_chunk.get("output_tokens", 0),
@ -515,7 +513,9 @@ class ModelResponseIterator:
return usage_block
def _content_block_delta_helper(self, chunk: dict) -> Tuple[
def _content_block_delta_helper(
self, chunk: dict
) -> Tuple[
str,
Optional[ChatCompletionToolCallChunk],
List[ChatCompletionThinkingBlock],
@ -592,9 +592,12 @@ class ModelResponseIterator:
Anthropic content chunk
chunk = {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Hello'}}
"""
text, tool_use, thinking_blocks, provider_specific_fields = (
self._content_block_delta_helper(chunk=chunk)
)
(
text,
tool_use,
thinking_blocks,
provider_specific_fields,
) = self._content_block_delta_helper(chunk=chunk)
if thinking_blocks:
reasoning_content = self._handle_reasoning_content(
thinking_blocks=thinking_blocks
@ -620,7 +623,6 @@ class ModelResponseIterator:
"index": self.tool_index,
}
elif type_chunk == "content_block_stop":
ContentBlockStop(**chunk) # type: ignore
# check if tool call content block
is_empty = self.check_empty_tool_call_args()

View file

@ -49,9 +49,9 @@ class AnthropicConfig(BaseConfig):
to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
"""
max_tokens: Optional[int] = (
4096 # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default)
)
max_tokens: Optional[
int
] = 4096 # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default)
stop_sequences: Optional[list] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
@ -104,7 +104,6 @@ class AnthropicConfig(BaseConfig):
def get_json_schema_from_pydantic_object(
self, response_format: Union[Any, Dict, None]
) -> Optional[dict]:
return type_to_response_format_param(
response_format, ref_template="/$defs/{model}"
) # Relevant issue: https://github.com/BerriAI/litellm/issues/7755
@ -125,7 +124,6 @@ class AnthropicConfig(BaseConfig):
is_vertex_request: bool = False,
user_anthropic_beta_headers: Optional[List[str]] = None,
) -> dict:
betas = set()
if prompt_caching_set:
betas.add("prompt-caching-2024-07-31")
@ -300,7 +298,6 @@ class AnthropicConfig(BaseConfig):
model: str,
drop_params: bool,
) -> dict:
is_thinking_enabled = self.is_thinking_enabled(
non_default_params=non_default_params
)
@ -321,11 +318,11 @@ class AnthropicConfig(BaseConfig):
optional_params=optional_params, tools=tool_value
)
if param == "tool_choice" or param == "parallel_tool_calls":
_tool_choice: Optional[AnthropicMessagesToolChoice] = (
self._map_tool_choice(
tool_choice=non_default_params.get("tool_choice"),
parallel_tool_use=non_default_params.get("parallel_tool_calls"),
)
_tool_choice: Optional[
AnthropicMessagesToolChoice
] = self._map_tool_choice(
tool_choice=non_default_params.get("tool_choice"),
parallel_tool_use=non_default_params.get("parallel_tool_calls"),
)
if _tool_choice is not None:
@ -341,7 +338,6 @@ class AnthropicConfig(BaseConfig):
if param == "top_p":
optional_params["top_p"] = value
if param == "response_format" and isinstance(value, dict):
ignore_response_format_types = ["text"]
if value["type"] in ignore_response_format_types: # value is a no-op
continue
@ -470,9 +466,9 @@ class AnthropicConfig(BaseConfig):
text=system_message_block["content"],
)
if "cache_control" in system_message_block:
anthropic_system_message_content["cache_control"] = (
system_message_block["cache_control"]
)
anthropic_system_message_content[
"cache_control"
] = system_message_block["cache_control"]
anthropic_system_message_list.append(
anthropic_system_message_content
)
@ -486,9 +482,9 @@ class AnthropicConfig(BaseConfig):
)
)
if "cache_control" in _content:
anthropic_system_message_content["cache_control"] = (
_content["cache_control"]
)
anthropic_system_message_content[
"cache_control"
] = _content["cache_control"]
anthropic_system_message_list.append(
anthropic_system_message_content
@ -597,7 +593,9 @@ class AnthropicConfig(BaseConfig):
)
return _message
def extract_response_content(self, completion_response: dict) -> Tuple[
def extract_response_content(
self, completion_response: dict
) -> Tuple[
str,
Optional[List[Any]],
Optional[List[ChatCompletionThinkingBlock]],
@ -693,9 +691,13 @@ class AnthropicConfig(BaseConfig):
reasoning_content: Optional[str] = None
tool_calls: List[ChatCompletionToolCallChunk] = []
text_content, citations, thinking_blocks, reasoning_content, tool_calls = (
self.extract_response_content(completion_response=completion_response)
)
(
text_content,
citations,
thinking_blocks,
reasoning_content,
tool_calls,
) = self.extract_response_content(completion_response=completion_response)
_message = litellm.Message(
tool_calls=tool_calls,

View file

@ -54,9 +54,9 @@ class AnthropicTextConfig(BaseConfig):
to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
"""
max_tokens_to_sample: Optional[int] = (
litellm.max_tokens
) # anthropic requires a default
max_tokens_to_sample: Optional[
int
] = litellm.max_tokens # anthropic requires a default
stop_sequences: Optional[list] = None
temperature: Optional[int] = None
top_p: Optional[int] = None

View file

@ -25,7 +25,6 @@ from litellm.utils import ProviderConfigManager, client
class AnthropicMessagesHandler:
@staticmethod
async def _handle_anthropic_streaming(
response: httpx.Response,
@ -74,19 +73,22 @@ async def anthropic_messages(
"""
# Use provided client or create a new one
optional_params = GenericLiteLLMParams(**kwargs)
model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = (
litellm.get_llm_provider(
model=model,
custom_llm_provider=custom_llm_provider,
api_base=optional_params.api_base,
api_key=optional_params.api_key,
)
(
model,
_custom_llm_provider,
dynamic_api_key,
dynamic_api_base,
) = litellm.get_llm_provider(
model=model,
custom_llm_provider=custom_llm_provider,
api_base=optional_params.api_base,
api_key=optional_params.api_key,
)
anthropic_messages_provider_config: Optional[BaseAnthropicMessagesConfig] = (
ProviderConfigManager.get_provider_anthropic_messages_config(
model=model,
provider=litellm.LlmProviders(_custom_llm_provider),
)
anthropic_messages_provider_config: Optional[
BaseAnthropicMessagesConfig
] = ProviderConfigManager.get_provider_anthropic_messages_config(
model=model,
provider=litellm.LlmProviders(_custom_llm_provider),
)
if anthropic_messages_provider_config is None:
raise ValueError(

View file

@ -654,7 +654,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
) -> EmbeddingResponse:
response = None
try:
openai_aclient = self.get_azure_openai_client(
api_version=api_version,
api_base=api_base,
@ -835,7 +834,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
"2023-10-01-preview",
]
): # CREATE + POLL for azure dall-e-2 calls
api_base = modify_url(
original_url=api_base, new_path="/openai/images/generations:submit"
)
@ -867,7 +865,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
)
while response.json()["status"] not in ["succeeded", "failed"]:
if time.time() - start_time > timeout_secs:
raise AzureOpenAIError(
status_code=408, message="Operation polling timed out."
)
@ -935,7 +932,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
"2023-10-01-preview",
]
): # CREATE + POLL for azure dall-e-2 calls
api_base = modify_url(
original_url=api_base, new_path="/openai/images/generations:submit"
)
@ -1199,7 +1195,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
client=None,
litellm_params: Optional[dict] = None,
) -> HttpxBinaryResponseContent:
max_retries = optional_params.pop("max_retries", 2)
if aspeech is not None and aspeech is True:
@ -1253,7 +1248,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
client=None,
litellm_params: Optional[dict] = None,
) -> HttpxBinaryResponseContent:
azure_client: AsyncAzureOpenAI = self.get_azure_openai_client(
api_base=api_base,
api_version=api_version,

View file

@ -50,15 +50,15 @@ class AzureBatchesAPI(BaseAzureLLM):
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
litellm_params=litellm_params or {},
)
azure_client: Optional[
Union[AzureOpenAI, AsyncAzureOpenAI]
] = self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
litellm_params=litellm_params or {},
)
if azure_client is None:
raise ValueError(
@ -96,15 +96,15 @@ class AzureBatchesAPI(BaseAzureLLM):
client: Optional[AzureOpenAI] = None,
litellm_params: Optional[dict] = None,
):
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
litellm_params=litellm_params or {},
)
azure_client: Optional[
Union[AzureOpenAI, AsyncAzureOpenAI]
] = self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
litellm_params=litellm_params or {},
)
if azure_client is None:
raise ValueError(
@ -144,15 +144,15 @@ class AzureBatchesAPI(BaseAzureLLM):
client: Optional[AzureOpenAI] = None,
litellm_params: Optional[dict] = None,
):
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
litellm_params=litellm_params or {},
)
azure_client: Optional[
Union[AzureOpenAI, AsyncAzureOpenAI]
] = self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
litellm_params=litellm_params or {},
)
if azure_client is None:
raise ValueError(
@ -183,15 +183,15 @@ class AzureBatchesAPI(BaseAzureLLM):
client: Optional[AzureOpenAI] = None,
litellm_params: Optional[dict] = None,
):
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
litellm_params=litellm_params or {},
)
azure_client: Optional[
Union[AzureOpenAI, AsyncAzureOpenAI]
] = self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
litellm_params=litellm_params or {},
)
if azure_client is None:
raise ValueError(

View file

@ -306,7 +306,6 @@ class BaseAzureLLM(BaseOpenAILLM):
api_version: Optional[str],
is_async: bool,
) -> dict:
azure_ad_token_provider: Optional[Callable[[], str]] = None
# If we have api_key, then we have higher priority
azure_ad_token = litellm_params.get("azure_ad_token")

View file

@ -46,16 +46,15 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
)
openai_client: Optional[
Union[AzureOpenAI, AsyncAzureOpenAI]
] = self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
@ -95,15 +94,15 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
) -> Union[
HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
]:
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
)
openai_client: Optional[
Union[AzureOpenAI, AsyncAzureOpenAI]
] = self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
@ -145,15 +144,15 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
):
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
)
openai_client: Optional[
Union[AzureOpenAI, AsyncAzureOpenAI]
] = self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
@ -197,15 +196,15 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
):
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
)
openai_client: Optional[
Union[AzureOpenAI, AsyncAzureOpenAI]
] = self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
@ -251,15 +250,15 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
):
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
)
openai_client: Optional[
Union[AzureOpenAI, AsyncAzureOpenAI]
] = self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(

View file

@ -25,14 +25,7 @@ class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI, BaseAzureLLM):
_is_async: bool = False,
api_version: Optional[str] = None,
litellm_params: Optional[dict] = None,
) -> Optional[
Union[
OpenAI,
AsyncOpenAI,
AzureOpenAI,
AsyncAzureOpenAI,
]
]:
) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI,]]:
# Override to use Azure-specific client initialization
if isinstance(client, OpenAI) or isinstance(client, AsyncOpenAI):
client = None

View file

@ -145,7 +145,6 @@ class AzureAIStudioConfig(OpenAIConfig):
2. If message contains an image or audio, send as is (user-intended)
"""
for message in messages:
# Do nothing if the message contains an image or audio
if _audio_or_image_in_message_content(message):
continue

View file

@ -22,7 +22,6 @@ class AzureAICohereConfig:
pass
def _map_azure_model_group(self, model: str) -> str:
if model == "offer-cohere-embed-multili-paygo":
return "Cohere-embed-v3-multilingual"
elif model == "offer-cohere-embed-english-paygo":

View file

@ -17,7 +17,6 @@ from .cohere_transformation import AzureAICohereConfig
class AzureAIEmbedding(OpenAIChatCompletion):
def _process_response(
self,
image_embedding_responses: Optional[List],
@ -145,7 +144,6 @@ class AzureAIEmbedding(OpenAIChatCompletion):
api_base: Optional[str] = None,
client=None,
) -> EmbeddingResponse:
(
image_embeddings_request,
v1_embeddings_request,

View file

@ -17,6 +17,7 @@ class AzureAIRerankConfig(CohereRerankConfig):
"""
Azure AI Rerank - Follows the same Spec as Cohere Rerank
"""
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
if api_base is None:
raise ValueError(

View file

@ -9,7 +9,6 @@ from litellm.types.utils import ModelResponse, TextCompletionResponse
class BaseLLM:
_client_session: Optional[httpx.Client] = None
def process_response(

View file

@ -218,7 +218,6 @@ class BaseConfig(ABC):
json_schema = value["json_schema"]["schema"]
if json_schema and not is_response_format_supported:
_tool_choice = ChatCompletionToolChoiceObjectParam(
type="function",
function=ChatCompletionToolChoiceFunctionParam(

View file

@ -58,7 +58,6 @@ class BaseResponsesAPIConfig(ABC):
model: str,
drop_params: bool,
) -> Dict:
pass
@abstractmethod

View file

@ -81,7 +81,6 @@ def make_sync_call(
class BedrockConverseLLM(BaseAWSLLM):
def __init__(self) -> None:
super().__init__()
@ -114,7 +113,6 @@ class BedrockConverseLLM(BaseAWSLLM):
fake_stream: bool = False,
json_mode: Optional[bool] = False,
) -> CustomStreamWrapper:
request_data = await litellm.AmazonConverseConfig()._async_transform_request(
model=model,
messages=messages,
@ -179,7 +177,6 @@ class BedrockConverseLLM(BaseAWSLLM):
headers: dict = {},
client: Optional[AsyncHTTPHandler] = None,
) -> Union[ModelResponse, CustomStreamWrapper]:
request_data = await litellm.AmazonConverseConfig()._async_transform_request(
model=model,
messages=messages,
@ -265,7 +262,6 @@ class BedrockConverseLLM(BaseAWSLLM):
extra_headers: Optional[dict] = None,
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
):
## SETUP ##
stream = optional_params.pop("stream", None)
unencoded_model_id = optional_params.pop("model_id", None)
@ -301,9 +297,9 @@ class BedrockConverseLLM(BaseAWSLLM):
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
optional_params.pop("aws_region_name", None)
litellm_params["aws_region_name"] = (
aws_region_name # [DO NOT DELETE] important for async calls
)
litellm_params[
"aws_region_name"
] = aws_region_name # [DO NOT DELETE] important for async calls
credentials: Credentials = self.get_credentials(
aws_access_key_id=aws_access_key_id,

View file

@ -223,7 +223,6 @@ class AmazonConverseConfig(BaseConfig):
)
for param, value in non_default_params.items():
if param == "response_format" and isinstance(value, dict):
ignore_response_format_types = ["text"]
if value["type"] in ignore_response_format_types: # value is a no-op
continue
@ -715,9 +714,9 @@ class AmazonConverseConfig(BaseConfig):
chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"}
content_str = ""
tools: List[ChatCompletionToolCallChunk] = []
reasoningContentBlocks: Optional[List[BedrockConverseReasoningContentBlock]] = (
None
)
reasoningContentBlocks: Optional[
List[BedrockConverseReasoningContentBlock]
] = None
if message is not None:
for idx, content in enumerate(message["content"]):
@ -727,7 +726,6 @@ class AmazonConverseConfig(BaseConfig):
if "text" in content:
content_str += content["text"]
if "toolUse" in content:
## check tool name was formatted by litellm
_response_tool_name = content["toolUse"]["name"]
response_tool_name = get_bedrock_tool_name(
@ -754,12 +752,12 @@ class AmazonConverseConfig(BaseConfig):
chat_completion_message["provider_specific_fields"] = {
"reasoningContentBlocks": reasoningContentBlocks,
}
chat_completion_message["reasoning_content"] = (
self._transform_reasoning_content(reasoningContentBlocks)
)
chat_completion_message["thinking_blocks"] = (
self._transform_thinking_blocks(reasoningContentBlocks)
)
chat_completion_message[
"reasoning_content"
] = self._transform_reasoning_content(reasoningContentBlocks)
chat_completion_message[
"thinking_blocks"
] = self._transform_thinking_blocks(reasoningContentBlocks)
chat_completion_message["content"] = content_str
if json_mode is True and tools is not None and len(tools) == 1:
# to support 'json_schema' logic on bedrock models

View file

@ -496,9 +496,9 @@ class BedrockLLM(BaseAWSLLM):
content=None,
)
model_response.choices[0].message = _message # type: ignore
model_response._hidden_params["original_response"] = (
outputText # allow user to access raw anthropic tool calling response
)
model_response._hidden_params[
"original_response"
] = outputText # allow user to access raw anthropic tool calling response
if (
_is_function_call is True
and stream is not None
@ -806,9 +806,9 @@ class BedrockLLM(BaseAWSLLM):
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
if stream is True:
inference_params["stream"] = (
True # cohere requires stream = True in inference params
)
inference_params[
"stream"
] = True # cohere requires stream = True in inference params
data = json.dumps({"prompt": prompt, **inference_params})
elif provider == "anthropic":
if model.startswith("anthropic.claude-3"):
@ -1205,7 +1205,6 @@ class BedrockLLM(BaseAWSLLM):
def get_response_stream_shape():
global _response_stream_shape_cache
if _response_stream_shape_cache is None:
from botocore.loaders import Loader
from botocore.model import ServiceModel
@ -1539,7 +1538,6 @@ class AmazonDeepSeekR1StreamDecoder(AWSEventStreamDecoder):
model: str,
sync_stream: bool,
) -> None:
super().__init__(model=model)
from litellm.llms.bedrock.chat.invoke_transformations.amazon_deepseek_transformation import (
AmazonDeepseekR1ResponseIterator,

View file

@ -225,9 +225,9 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
if stream is True:
inference_params["stream"] = (
True # cohere requires stream = True in inference params
)
inference_params[
"stream"
] = True # cohere requires stream = True in inference params
request_data = {"prompt": prompt, **inference_params}
elif provider == "anthropic":
return litellm.AmazonAnthropicClaude3Config().transform_request(
@ -311,7 +311,6 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
try:
completion_response = raw_response.json()
except Exception:

View file

@ -314,7 +314,6 @@ def get_bedrock_tool_name(response_tool_name: str) -> str:
class BedrockModelInfo(BaseLLMModelInfo):
global_config = AmazonBedrockGlobalConfig()
all_global_regions = global_config.get_all_regions()

View file

@ -33,9 +33,9 @@ class AmazonTitanMultimodalEmbeddingG1Config:
) -> dict:
for k, v in non_default_params.items():
if k == "dimensions":
optional_params["embeddingConfig"] = (
AmazonTitanMultimodalEmbeddingConfig(outputEmbeddingLength=v)
)
optional_params[
"embeddingConfig"
] = AmazonTitanMultimodalEmbeddingConfig(outputEmbeddingLength=v)
return optional_params
def _transform_request(
@ -58,7 +58,6 @@ class AmazonTitanMultimodalEmbeddingG1Config:
def _transform_response(
self, response_list: List[dict], model: str
) -> EmbeddingResponse:
total_prompt_tokens = 0
transformed_responses: List[Embedding] = []
for index, response in enumerate(response_list):

View file

@ -1,12 +1,16 @@
import types
from typing import List, Optional
from typing import Any, Dict, List, Optional
from openai.types.image import Image
from litellm.types.llms.bedrock import (
AmazonNovaCanvasTextToImageRequest, AmazonNovaCanvasTextToImageResponse,
AmazonNovaCanvasTextToImageParams, AmazonNovaCanvasRequestBase, AmazonNovaCanvasColorGuidedGenerationParams,
AmazonNovaCanvasColorGuidedGenerationParams,
AmazonNovaCanvasColorGuidedRequest,
AmazonNovaCanvasImageGenerationConfig,
AmazonNovaCanvasRequestBase,
AmazonNovaCanvasTextToImageParams,
AmazonNovaCanvasTextToImageRequest,
AmazonNovaCanvasTextToImageResponse,
)
from litellm.types.utils import ImageResponse
@ -23,7 +27,7 @@ class AmazonNovaCanvasConfig:
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
and not isinstance(
v,
(
types.FunctionType,
@ -32,13 +36,12 @@ class AmazonNovaCanvasConfig:
staticmethod,
),
)
and v is not None
and v is not None
}
@classmethod
def get_supported_openai_params(cls, model: Optional[str] = None) -> List:
"""
"""
""" """
return ["n", "size", "quality"]
@classmethod
@ -56,7 +59,7 @@ class AmazonNovaCanvasConfig:
@classmethod
def transform_request_body(
cls, text: str, optional_params: dict
cls, text: str, optional_params: dict
) -> AmazonNovaCanvasRequestBase:
"""
Transform the request body for Amazon Nova Canvas model
@ -65,18 +68,64 @@ class AmazonNovaCanvasConfig:
image_generation_config = optional_params.pop("imageGenerationConfig", {})
image_generation_config = {**image_generation_config, **optional_params}
if task_type == "TEXT_IMAGE":
text_to_image_params = image_generation_config.pop("textToImageParams", {})
text_to_image_params = {"text" :text, **text_to_image_params}
text_to_image_params = AmazonNovaCanvasTextToImageParams(**text_to_image_params)
return AmazonNovaCanvasTextToImageRequest(textToImageParams=text_to_image_params, taskType=task_type,
imageGenerationConfig=image_generation_config)
text_to_image_params: Dict[str, Any] = image_generation_config.pop(
"textToImageParams", {}
)
text_to_image_params = {"text": text, **text_to_image_params}
try:
text_to_image_params_typed = AmazonNovaCanvasTextToImageParams(
**text_to_image_params # type: ignore
)
except Exception as e:
raise ValueError(
f"Error transforming text to image params: {e}. Got params: {text_to_image_params}, Expected params: {AmazonNovaCanvasTextToImageParams.__annotations__}"
)
try:
image_generation_config_typed = AmazonNovaCanvasImageGenerationConfig(
**image_generation_config
)
except Exception as e:
raise ValueError(
f"Error transforming image generation config: {e}. Got params: {image_generation_config}, Expected params: {AmazonNovaCanvasImageGenerationConfig.__annotations__}"
)
return AmazonNovaCanvasTextToImageRequest(
textToImageParams=text_to_image_params_typed,
taskType=task_type,
imageGenerationConfig=image_generation_config_typed,
)
if task_type == "COLOR_GUIDED_GENERATION":
color_guided_generation_params = image_generation_config.pop("colorGuidedGenerationParams", {})
color_guided_generation_params = {"text": text, **color_guided_generation_params}
color_guided_generation_params = AmazonNovaCanvasColorGuidedGenerationParams(**color_guided_generation_params)
return AmazonNovaCanvasColorGuidedRequest(taskType=task_type,
colorGuidedGenerationParams=color_guided_generation_params,
imageGenerationConfig=image_generation_config)
color_guided_generation_params: Dict[
str, Any
] = image_generation_config.pop("colorGuidedGenerationParams", {})
color_guided_generation_params = {
"text": text,
**color_guided_generation_params,
}
try:
color_guided_generation_params_typed = AmazonNovaCanvasColorGuidedGenerationParams(
**color_guided_generation_params # type: ignore
)
except Exception as e:
raise ValueError(
f"Error transforming color guided generation params: {e}. Got params: {color_guided_generation_params}, Expected params: {AmazonNovaCanvasColorGuidedGenerationParams.__annotations__}"
)
try:
image_generation_config_typed = AmazonNovaCanvasImageGenerationConfig(
**image_generation_config
)
except Exception as e:
raise ValueError(
f"Error transforming image generation config: {e}. Got params: {image_generation_config}, Expected params: {AmazonNovaCanvasImageGenerationConfig.__annotations__}"
)
return AmazonNovaCanvasColorGuidedRequest(
taskType=task_type,
colorGuidedGenerationParams=color_guided_generation_params_typed,
imageGenerationConfig=image_generation_config_typed,
)
raise NotImplementedError(f"Task type {task_type} is not supported")
@classmethod
@ -87,7 +136,9 @@ class AmazonNovaCanvasConfig:
_size = non_default_params.get("size")
if _size is not None:
width, height = _size.split("x")
optional_params["width"], optional_params["height"] = int(width), int(height)
optional_params["width"], optional_params["height"] = int(width), int(
height
)
if non_default_params.get("n") is not None:
optional_params["numberOfImages"] = non_default_params.get("n")
if non_default_params.get("quality") is not None:
@ -99,7 +150,7 @@ class AmazonNovaCanvasConfig:
@classmethod
def transform_response_dict_to_openai_response(
cls, model_response: ImageResponse, response_dict: dict
cls, model_response: ImageResponse, response_dict: dict
) -> ImageResponse:
"""
Transform the response dict to the OpenAI response
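The Nova Canvas changes above replace loosely typed dict reuse with explicitly typed request objects, wrapping construction in try/except so a bad parameter surfaces as a `ValueError` naming the expected annotations. A generic sketch of that validate-then-construct pattern; the TypedDict and field names below are invented for illustration (not Bedrock's schema), and the sketch checks keys explicitly because plain TypedDicts do not validate at runtime:

```python
from typing import Any, Dict, TypedDict


class TextToImageParams(TypedDict, total=False):
    text: str
    negativeText: str


def build_text_to_image_params(raw: Dict[str, Any]) -> TextToImageParams:
    allowed = TextToImageParams.__annotations__
    unknown = set(raw) - set(allowed)
    if unknown:
        raise ValueError(
            f"Error transforming text to image params: unexpected keys {unknown}. "
            f"Expected params: {allowed}"
        )
    return TextToImageParams(**raw)  # type: ignore[typeddict-item]


print(build_text_to_image_params({"text": "a red barn at dusk"}))
```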

View file

@ -267,7 +267,11 @@ class BedrockImageGeneration(BaseAWSLLM):
**inference_params,
}
elif provider == "amazon":
return dict(litellm.AmazonNovaCanvasConfig.transform_request_body(text=prompt, optional_params=optional_params))
return dict(
litellm.AmazonNovaCanvasConfig.transform_request_body(
text=prompt, optional_params=optional_params
)
)
else:
raise BedrockError(
status_code=422, message=f"Unsupported model={model}, passed in"
@ -303,8 +307,11 @@ class BedrockImageGeneration(BaseAWSLLM):
config_class = (
litellm.AmazonStability3Config
if litellm.AmazonStability3Config._is_stability_3_model(model=model)
else litellm.AmazonNovaCanvasConfig if litellm.AmazonNovaCanvasConfig._is_nova_model(model=model)
else litellm.AmazonStabilityConfig
else (
litellm.AmazonNovaCanvasConfig
if litellm.AmazonNovaCanvasConfig._is_nova_model(model=model)
else litellm.AmazonStabilityConfig
)
)
config_class.transform_response_dict_to_openai_response(
model_response=model_response,

View file

@ -60,7 +60,6 @@ class BedrockRerankHandler(BaseAWSLLM):
extra_headers: Optional[dict] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
) -> RerankResponse:
request_data = RerankRequest(
model=model,
query=query,

View file

@ -29,7 +29,6 @@ from litellm.types.rerank import (
class BedrockRerankConfig:
def _transform_sources(
self, documents: List[Union[str, dict]]
) -> List[BedrockRerankSource]:

View file

@ -314,7 +314,6 @@ class CodestralTextCompletion:
return _response
### SYNC COMPLETION
else:
response = litellm.module_level_client.post(
url=completion_url,
headers=headers,
@ -352,13 +351,11 @@ class CodestralTextCompletion:
logger_fn=None,
headers={},
) -> TextCompletionResponse:
async_handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.TEXT_COMPLETION_CODESTRAL,
params={"timeout": timeout},
)
try:
response = await async_handler.post(
api_base, headers=headers, data=json.dumps(data)
)

View file

@ -78,7 +78,6 @@ class CodestralTextCompletionConfig(OpenAITextCompletionConfig):
return optional_params
def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:
text = ""
is_finished = False
finish_reason = None

View file

@ -180,7 +180,6 @@ class CohereChatConfig(BaseConfig):
litellm_params: dict,
headers: dict,
) -> dict:
## Load Config
for k, v in litellm.CohereChatConfig.get_config().items():
if (
@ -222,7 +221,6 @@ class CohereChatConfig(BaseConfig):
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
try:
raw_response_json = raw_response.json()
model_response.choices[0].message.content = raw_response_json["text"] # type: ignore

View file

@ -56,7 +56,6 @@ async def async_embedding(
encoding: Callable,
client: Optional[AsyncHTTPHandler] = None,
):
## LOGGING
logging_obj.pre_call(
input=input,

View file

@ -72,7 +72,6 @@ class CohereEmbeddingConfig:
return transformed_request
def _calculate_usage(self, input: List[str], encoding: Any, meta: dict) -> Usage:
input_tokens = 0
text_tokens: Optional[int] = meta.get("billed_units", {}).get("input_tokens")
@ -111,7 +110,6 @@ class CohereEmbeddingConfig:
encoding: Any,
input: list,
) -> EmbeddingResponse:
response_json = response.json()
## LOGGING
logging_obj.post_call(

View file

@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Union
from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
from litellm.types.rerank import OptionalRerankParams, RerankRequest
class CohereRerankV2Config(CohereRerankConfig):
"""
Reference: https://docs.cohere.com/v2/reference/rerank

View file

@ -32,7 +32,6 @@ DEFAULT_TIMEOUT = 600
class BaseLLMAIOHTTPHandler:
def __init__(self):
self.client_session: Optional[aiohttp.ClientSession] = None
@ -110,7 +109,6 @@ class BaseLLMAIOHTTPHandler:
content: Any = None,
params: Optional[dict] = None,
) -> httpx.Response:
max_retry_on_unprocessable_entity_error = (
provider_config.max_retry_on_unprocessable_entity_error
)

View file

@ -114,7 +114,6 @@ class AsyncHTTPHandler:
event_hooks: Optional[Mapping[str, List[Callable[..., Any]]]],
ssl_verify: Optional[VerifyTypes] = None,
) -> httpx.AsyncClient:
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
# /path/to/certificate.pem
if ssl_verify is None:
@ -590,7 +589,6 @@ class HTTPHandler:
timeout: Optional[Union[float, httpx.Timeout]] = None,
):
try:
if timeout is not None:
req = self.client.build_request(
"PATCH", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore
@ -609,7 +607,6 @@ class HTTPHandler:
llm_provider="litellm-httpx-handler",
)
except httpx.HTTPStatusError as e:
if stream is True:
setattr(e, "message", mask_sensitive_info(e.response.read()))
setattr(e, "text", mask_sensitive_info(e.response.read()))
@ -635,7 +632,6 @@ class HTTPHandler:
timeout: Optional[Union[float, httpx.Timeout]] = None,
):
try:
if timeout is not None:
req = self.client.build_request(
"PUT", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore

View file

@ -41,7 +41,6 @@ else:
class BaseLLMHTTPHandler:
async def _make_common_async_call(
self,
async_httpx_client: AsyncHTTPHandler,
@ -109,7 +108,6 @@ class BaseLLMHTTPHandler:
logging_obj: LiteLLMLoggingObj,
stream: bool = False,
) -> httpx.Response:
max_retry_on_unprocessable_entity_error = (
provider_config.max_retry_on_unprocessable_entity_error
)
@ -599,7 +597,6 @@ class BaseLLMHTTPHandler:
aembedding: bool = False,
headers={},
) -> EmbeddingResponse:
provider_config = ProviderConfigManager.get_provider_embedding_config(
model=model, provider=litellm.LlmProviders(custom_llm_provider)
)
@ -742,7 +739,6 @@ class BaseLLMHTTPHandler:
api_base: Optional[str] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
) -> RerankResponse:
# get config from model, custom llm provider
headers = provider_config.validate_environment(
api_key=api_key,
@ -828,7 +824,6 @@ class BaseLLMHTTPHandler:
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
) -> RerankResponse:
if client is None or not isinstance(client, AsyncHTTPHandler):
async_httpx_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders(custom_llm_provider)

View file

@ -16,9 +16,9 @@ class DatabricksBase:
api_base = api_base or f"{databricks_client.config.host}/serving-endpoints"
if api_key is None:
databricks_auth_headers: dict[str, str] = (
databricks_client.config.authenticate()
)
databricks_auth_headers: dict[
str, str
] = databricks_client.config.authenticate()
headers = {**databricks_auth_headers, **headers}
return api_base, headers

View file

@ -11,9 +11,9 @@ class DatabricksEmbeddingConfig:
Reference: https://learn.microsoft.com/en-us/azure/databricks/machine-learning/foundation-models/api-reference#--embedding-task
"""
instruction: Optional[str] = (
None # An optional instruction to pass to the embedding model. BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries
)
instruction: Optional[
str
] = None # An optional instruction to pass to the embedding model. BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries
def __init__(self, instruction: Optional[str] = None) -> None:
locals_ = locals().copy()

View file

@ -55,7 +55,6 @@ class ModelResponseIterator:
usage_chunk: Optional[Usage] = getattr(processed_chunk, "usage", None)
if usage_chunk is not None:
usage = ChatCompletionUsageBlock(
prompt_tokens=usage_chunk.prompt_tokens,
completion_tokens=usage_chunk.completion_tokens,

View file

@ -126,9 +126,9 @@ class DeepgramAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
# Add additional metadata matching OpenAI format
response["task"] = "transcribe"
response["language"] = (
"english" # Deepgram auto-detects but doesn't return language
)
response[
"language"
] = "english" # Deepgram auto-detects but doesn't return language
response["duration"] = response_json["metadata"]["duration"]
# Transform words to match OpenAI format

View file

@ -14,7 +14,6 @@ from ...openai.chat.gpt_transformation import OpenAIGPTConfig
class DeepSeekChatConfig(OpenAIGPTConfig):
def _transform_messages(
self, messages: List[AllMessageValues], model: str
) -> List[AllMessageValues]:

View file

@ -77,9 +77,9 @@ class AlephAlphaConfig:
- `control_log_additive` (boolean; default value: true): Method of applying control to attention scores.
"""
maximum_tokens: Optional[int] = (
litellm.max_tokens
) # aleph alpha requires max tokens
maximum_tokens: Optional[
int
] = litellm.max_tokens # aleph alpha requires max tokens
minimum_tokens: Optional[int] = None
echo: Optional[bool] = None
temperature: Optional[int] = None

Some files were not shown because too many files have changed in this diff.