build(pyproject.toml): add new dev dependencies - for type checking (#9631)

* build(pyproject.toml): add new dev dependencies - for type checking

* build: reformat files to fit black

* ci: reformat to fit black

* ci(test-litellm.yml): make tests run clear

* build(pyproject.toml): add ruff

* fix: fix ruff checks

* build(mypy/): fix mypy linting errors

* fix(hashicorp_secret_manager.py): fix passing cert for tls auth

* build(mypy/): resolve all mypy errors

* test: update test

* fix: fix black formatting

* build(pre-commit-config.yaml): use poetry run black

* fix(proxy_server.py): fix linting error

* fix: fix ruff safe representation error
Krish Dholakia authored on 2025-03-29 11:02:13 -07:00; committed by GitHub
parent 95e5dfae5a
commit 9b7ebb6a7d
214 changed files with 1553 additions and 1433 deletions

.github/workflows/test-linting.yml (new file, 53 lines)
View file

@ -0,0 +1,53 @@
name: LiteLLM Linting
on:
pull_request:
branches: [ main ]
jobs:
lint:
runs-on: ubuntu-latest
timeout-minutes: 5
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.12'
- name: Install Poetry
uses: snok/install-poetry@v1
- name: Install dependencies
run: |
poetry install --with dev
- name: Run Black formatting check
run: |
cd litellm
poetry run black . --check
cd ..
- name: Run Ruff linting
run: |
cd litellm
poetry run ruff check .
cd ..
- name: Run MyPy type checking
run: |
cd litellm
poetry run mypy . --ignore-missing-imports
cd ..
- name: Check for circular imports
run: |
cd litellm
poetry run python ../tests/documentation_tests/test_circular_imports.py
cd ..
- name: Check import safety
run: |
poetry run python -c "from litellm import *" || (echo '🚨 import failed, this means you introduced unprotected imports! 🚨'; exit 1)
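
The final "Check import safety" step smoke-tests "from litellm import *" and fails the job if any module does an unguarded import of an optional dependency. Below is a minimal sketch of the guarded-import pattern such a check implies; some_optional_sdk and its send() method are illustrative stand-ins, not names from this repository.

# Hypothetical guarded-import sketch: optional integrations must not break
# "from litellm import *" when their dependency is not installed.
from typing import Any, Optional

try:
    import some_optional_sdk  # illustrative optional dependency, not a real package
except ImportError:
    some_optional_sdk = None  # type: ignore[assignment]


def log_to_optional_backend(payload: dict) -> Optional[Any]:
    """Use the optional SDK only if it imported cleanly; never raise at import time."""
    if some_optional_sdk is None:
        return None  # degrade gracefully instead of failing the import-safety check
    return some_optional_sdk.send(payload)  # hypothetical API on the stand-in SDK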

View file

@ -1,4 +1,4 @@
name: LiteLLM Tests
name: LiteLLM Mock Tests (folder - tests/litellm)
on:
pull_request:

View file

@ -14,10 +14,12 @@ repos:
types: [python]
files: litellm/.*\.py
exclude: ^litellm/__init__.py$
- repo: https://github.com/psf/black
rev: 24.2.0
hooks:
- id: black
- id: black
name: black
entry: poetry run black
language: system
types: [python]
files: litellm/.*\.py
- repo: https://github.com/pycqa/flake8
rev: 7.0.0 # The version of flake8 to use
hooks:

View file

@ -444,9 +444,7 @@ class _ENTERPRISE_SecretDetection(CustomGuardrail):
detected_secrets = []
for file in secrets.files:
for found_secret in secrets[file]:
if found_secret.secret_value is None:
continue
detected_secrets.append(
@ -471,14 +469,12 @@ class _ENTERPRISE_SecretDetection(CustomGuardrail):
data: dict,
call_type: str, # "completion", "embeddings", "image_generation", "moderation"
):
if await self.should_run_check(user_api_key_dict) is False:
return
if "messages" in data and isinstance(data["messages"], list):
for message in data["messages"]:
if "content" in message and isinstance(message["content"], str):
detected_secrets = self.scan_message_for_secrets(message["content"])
for secret in detected_secrets:

View file

@ -122,19 +122,19 @@ langsmith_batch_size: Optional[int] = None
prometheus_initialize_budget_metrics: Optional[bool] = False
argilla_batch_size: Optional[int] = None
datadog_use_v1: Optional[bool] = False # if you want to use v1 datadog logged payload
gcs_pub_sub_use_v1: Optional[bool] = (
False # if you want to use v1 gcs pubsub logged payload
)
gcs_pub_sub_use_v1: Optional[
bool
] = False # if you want to use v1 gcs pubsub logged payload
argilla_transformation_object: Optional[Dict[str, Any]] = None
_async_input_callback: List[Union[str, Callable, CustomLogger]] = (
[]
) # internal variable - async custom callbacks are routed here.
_async_success_callback: List[Union[str, Callable, CustomLogger]] = (
[]
) # internal variable - async custom callbacks are routed here.
_async_failure_callback: List[Union[str, Callable, CustomLogger]] = (
[]
) # internal variable - async custom callbacks are routed here.
_async_input_callback: List[
Union[str, Callable, CustomLogger]
] = [] # internal variable - async custom callbacks are routed here.
_async_success_callback: List[
Union[str, Callable, CustomLogger]
] = [] # internal variable - async custom callbacks are routed here.
_async_failure_callback: List[
Union[str, Callable, CustomLogger]
] = [] # internal variable - async custom callbacks are routed here.
pre_call_rules: List[Callable] = []
post_call_rules: List[Callable] = []
turn_off_message_logging: Optional[bool] = False
@ -142,18 +142,18 @@ log_raw_request_response: bool = False
redact_messages_in_exceptions: Optional[bool] = False
redact_user_api_key_info: Optional[bool] = False
filter_invalid_headers: Optional[bool] = False
add_user_information_to_llm_headers: Optional[bool] = (
None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
)
add_user_information_to_llm_headers: Optional[
bool
] = None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
store_audit_logs = False # Enterprise feature, allow users to see audit logs
### end of callbacks #############
email: Optional[str] = (
None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
)
token: Optional[str] = (
None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
)
email: Optional[
str
] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
token: Optional[
str
] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
telemetry = True
max_tokens = 256 # OpenAI Defaults
drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False))
@ -229,24 +229,20 @@ enable_loadbalancing_on_batch_endpoints: Optional[bool] = None
enable_caching_on_provider_specific_optional_params: bool = (
False # feature-flag for caching on optional params - e.g. 'top_k'
)
caching: bool = (
False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
)
caching_with_models: bool = (
False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
)
cache: Optional[Cache] = (
None # cache object <- use this - https://docs.litellm.ai/docs/caching
)
caching: bool = False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
caching_with_models: bool = False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
cache: Optional[
Cache
] = None # cache object <- use this - https://docs.litellm.ai/docs/caching
default_in_memory_ttl: Optional[float] = None
default_redis_ttl: Optional[float] = None
default_redis_batch_cache_expiry: Optional[float] = None
model_alias_map: Dict[str, str] = {}
model_group_alias_map: Dict[str, str] = {}
max_budget: float = 0.0 # set the max budget across all providers
budget_duration: Optional[str] = (
None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
)
budget_duration: Optional[
str
] = None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
default_soft_budget: float = (
50.0 # by default all litellm proxy keys have a soft budget of 50.0
)
@ -255,15 +251,11 @@ forward_traceparent_to_llm_provider: bool = False
_current_cost = 0.0 # private variable, used if max budget is set
error_logs: Dict = {}
add_function_to_prompt: bool = (
False # if function calling not supported by api, append function call details to system prompt
)
add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt
client_session: Optional[httpx.Client] = None
aclient_session: Optional[httpx.AsyncClient] = None
model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks'
model_cost_map_url: str = (
"https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
)
model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
suppress_debug_info = False
dynamodb_table_name: Optional[str] = None
s3_callback_params: Optional[Dict] = None
@ -285,9 +277,7 @@ disable_end_user_cost_tracking_prometheus_only: Optional[bool] = None
custom_prometheus_metadata_labels: List[str] = []
#### REQUEST PRIORITIZATION ####
priority_reservation: Optional[Dict[str, float]] = None
force_ipv4: bool = (
False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
)
force_ipv4: bool = False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
module_level_aclient = AsyncHTTPHandler(
timeout=request_timeout, client_alias="module level aclient"
)
@ -301,13 +291,13 @@ fallbacks: Optional[List] = None
context_window_fallbacks: Optional[List] = None
content_policy_fallbacks: Optional[List] = None
allowed_fails: int = 3
num_retries_per_request: Optional[int] = (
None # for the request overall (incl. fallbacks + model retries)
)
num_retries_per_request: Optional[
int
] = None # for the request overall (incl. fallbacks + model retries)
####### SECRET MANAGERS #####################
secret_manager_client: Optional[Any] = (
None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
)
secret_manager_client: Optional[
Any
] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
_google_kms_resource_name: Optional[str] = None
_key_management_system: Optional[KeyManagementSystem] = None
_key_management_settings: KeyManagementSettings = KeyManagementSettings()
@ -1056,10 +1046,10 @@ from .types.llms.custom_llm import CustomLLMItem
from .types.utils import GenericStreamingChunk
custom_provider_map: List[CustomLLMItem] = []
_custom_providers: List[str] = (
[]
) # internal helper util, used to track names of custom providers
disable_hf_tokenizer_download: Optional[bool] = (
None # disable huggingface tokenizer download. Defaults to openai clk100
)
_custom_providers: List[
str
] = [] # internal helper util, used to track names of custom providers
disable_hf_tokenizer_download: Optional[
bool
] = None # disable huggingface tokenizer download. Defaults to openai clk100
global_disable_no_log_param: bool = False

View file

@ -15,7 +15,7 @@ from .types.services import ServiceLoggerPayload, ServiceTypes
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
Span = Union[_Span, Any]
OTELClass = OpenTelemetry
else:
Span = Any
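
This Span alias change (Span = _Span becomes Span = Union[_Span, Any] under TYPE_CHECKING) repeats across many files in this commit. A minimal, self-contained sketch of the pattern, with record_latency() as an illustrative helper that is not part of the diff:

# Sketch of the TYPE_CHECKING alias pattern applied throughout this commit.
# Union[_Span, Any] keeps the alias compatible with both a real OpenTelemetry
# span and the Any fallback used when opentelemetry is not importable.
from typing import TYPE_CHECKING, Any, Optional, Union

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    Span = Union[_Span, Any]
else:
    Span = Any  # no hard runtime dependency on opentelemetry


def record_latency(span: Optional[Span], latency_ms: float) -> None:
    # Illustrative helper, not from the diff.
    if span is not None and hasattr(span, "set_attribute"):
        span.set_attribute("latency_ms", latency_ms)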

View file

@ -153,7 +153,6 @@ def create_batch(
)
api_base: Optional[str] = None
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
optional_params.api_base
@ -358,7 +357,6 @@ def retrieve_batch(
_is_async = kwargs.pop("aretrieve_batch", False) is True
api_base: Optional[str] = None
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
optional_params.api_base

View file

@ -9,12 +9,12 @@ Has 4 methods:
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, Union
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
Span = Union[_Span, Any]
else:
Span = Any

View file

@ -66,9 +66,7 @@ class CachingHandlerResponse(BaseModel):
cached_result: Optional[Any] = None
final_embedding_cached_response: Optional[EmbeddingResponse] = None
embedding_all_elements_cache_hit: bool = (
False # this is set to True when all elements in the list have a cache hit in the embedding cache, if true return the final_embedding_cached_response no need to make an API call
)
embedding_all_elements_cache_hit: bool = False # this is set to True when all elements in the list have a cache hit in the embedding cache, if true return the final_embedding_cached_response no need to make an API call
class LLMCachingHandler:
@ -738,7 +736,6 @@ class LLMCachingHandler:
if self._should_store_result_in_cache(
original_function=self.original_function, kwargs=new_kwargs
):
litellm.cache.add_cache(result, **new_kwargs)
return
@ -865,9 +862,9 @@ class LLMCachingHandler:
}
if litellm.cache is not None:
litellm_params["preset_cache_key"] = (
litellm.cache._get_preset_cache_key_from_kwargs(**kwargs)
)
litellm_params[
"preset_cache_key"
] = litellm.cache._get_preset_cache_key_from_kwargs(**kwargs)
else:
litellm_params["preset_cache_key"] = None

View file

@ -1,12 +1,12 @@
import json
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, Union
from .base_cache import BaseCache
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
Span = Union[_Span, Any]
else:
Span = Any

View file

@ -12,7 +12,7 @@ import asyncio
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, List, Optional
from typing import TYPE_CHECKING, Any, List, Optional, Union
import litellm
from litellm._logging import print_verbose, verbose_logger
@ -24,7 +24,7 @@ from .redis_cache import RedisCache
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
Span = Union[_Span, Any]
else:
Span = Any

View file

@ -8,7 +8,6 @@ from .in_memory_cache import InMemoryCache
class LLMClientCache(InMemoryCache):
def update_cache_key_with_event_loop(self, key):
"""
Add the event loop to the cache key, to prevent event loop closed errors.

View file

@ -34,7 +34,7 @@ if TYPE_CHECKING:
cluster_pipeline = ClusterPipeline
async_redis_client = Redis
async_redis_cluster_client = RedisCluster
Span = _Span
Span = Union[_Span, Any]
else:
pipeline = Any
cluster_pipeline = Any
@ -57,7 +57,6 @@ class RedisCache(BaseCache):
socket_timeout: Optional[float] = 5.0, # default 5 second timeout
**kwargs,
):
from litellm._service_logger import ServiceLogging
from .._redis import get_redis_client, get_redis_connection_pool

View file

@ -5,7 +5,7 @@ Key differences:
- RedisClient NEEDs to be re-used across requests, adds 3000ms latency if it's re-created
"""
from typing import TYPE_CHECKING, Any, List, Optional
from typing import TYPE_CHECKING, Any, List, Optional, Union
from litellm.caching.redis_cache import RedisCache
@ -16,7 +16,7 @@ if TYPE_CHECKING:
pipeline = Pipeline
async_redis_client = Redis
Span = _Span
Span = Union[_Span, Any]
else:
pipeline = Any
async_redis_client = Any

View file

@ -13,23 +13,27 @@ import ast
import asyncio
import json
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, cast
import litellm
from litellm._logging import print_verbose
from litellm.litellm_core_utils.prompt_templates.common_utils import get_str_from_messages
from litellm.litellm_core_utils.prompt_templates.common_utils import (
get_str_from_messages,
)
from litellm.types.utils import EmbeddingResponse
from .base_cache import BaseCache
class RedisSemanticCache(BaseCache):
"""
Redis-backed semantic cache for LLM responses.
This cache uses vector similarity to find semantically similar prompts that have been
Redis-backed semantic cache for LLM responses.
This cache uses vector similarity to find semantically similar prompts that have been
previously sent to the LLM, allowing for cache hits even when prompts are not identical
but carry similar meaning.
"""
DEFAULT_REDIS_INDEX_NAME: str = "litellm_semantic_cache_index"
def __init__(
@ -57,7 +61,7 @@ class RedisSemanticCache(BaseCache):
index_name: Name for the Redis index
ttl: Default time-to-live for cache entries in seconds
**kwargs: Additional arguments passed to the Redis client
Raises:
Exception: If similarity_threshold is not provided or required Redis
connection information is missing
@ -69,14 +73,14 @@ class RedisSemanticCache(BaseCache):
index_name = self.DEFAULT_REDIS_INDEX_NAME
print_verbose(f"Redis semantic-cache initializing index - {index_name}")
# Validate similarity threshold
if similarity_threshold is None:
raise ValueError("similarity_threshold must be provided, passed None")
# Store configuration
self.similarity_threshold = similarity_threshold
# Convert similarity threshold [0,1] to distance threshold [0,2]
# For cosine distance: 0 = most similar, 2 = least similar
# While similarity: 1 = most similar, 0 = least similar
@ -87,14 +91,16 @@ class RedisSemanticCache(BaseCache):
if redis_url is None:
try:
# Attempt to use provided parameters or fallback to environment variables
host = host or os.environ['REDIS_HOST']
port = port or os.environ['REDIS_PORT']
password = password or os.environ['REDIS_PASSWORD']
host = host or os.environ["REDIS_HOST"]
port = port or os.environ["REDIS_PORT"]
password = password or os.environ["REDIS_PASSWORD"]
except KeyError as e:
# Raise a more informative exception if any of the required keys are missing
missing_var = e.args[0]
raise ValueError(f"Missing required Redis configuration: {missing_var}. "
f"Provide {missing_var} or redis_url.") from e
raise ValueError(
f"Missing required Redis configuration: {missing_var}. "
f"Provide {missing_var} or redis_url."
) from e
redis_url = f"redis://:{password}@{host}:{port}"
@ -114,7 +120,7 @@ class RedisSemanticCache(BaseCache):
def _get_ttl(self, **kwargs) -> Optional[int]:
"""
Get the TTL (time-to-live) value for cache entries.
Args:
**kwargs: Keyword arguments that may contain a custom TTL
@ -125,22 +131,25 @@ class RedisSemanticCache(BaseCache):
if ttl is not None:
ttl = int(ttl)
return ttl
def _get_embedding(self, prompt: str) -> List[float]:
"""
Generate an embedding vector for the given prompt using the configured embedding model.
Args:
prompt: The text to generate an embedding for
Returns:
List[float]: The embedding vector
"""
# Create an embedding from prompt
embedding_response = litellm.embedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
embedding_response = cast(
EmbeddingResponse,
litellm.embedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
),
)
embedding = embedding_response["data"][0]["embedding"]
return embedding
@ -148,10 +157,10 @@ class RedisSemanticCache(BaseCache):
def _get_cache_logic(self, cached_response: Any) -> Any:
"""
Process the cached response to prepare it for use.
Args:
cached_response: The raw cached response
Returns:
The processed cache response, or None if input was None
"""
@ -171,13 +180,13 @@ class RedisSemanticCache(BaseCache):
except (ValueError, SyntaxError) as e:
print_verbose(f"Error parsing cached response: {str(e)}")
return None
return cached_response
def set_cache(self, key: str, value: Any, **kwargs) -> None:
"""
Store a value in the semantic cache.
Args:
key: The cache key (not directly used in semantic caching)
value: The response value to cache
@ -186,13 +195,14 @@ class RedisSemanticCache(BaseCache):
"""
print_verbose(f"Redis semantic-cache set_cache, kwargs: {kwargs}")
value_str: Optional[str] = None
try:
# Extract the prompt from messages
messages = kwargs.get("messages", [])
if not messages:
print_verbose("No messages provided for semantic caching")
return
prompt = get_str_from_messages(messages)
value_str = str(value)
@ -203,16 +213,18 @@ class RedisSemanticCache(BaseCache):
else:
self.llmcache.store(prompt, value_str)
except Exception as e:
print_verbose(f"Error setting {value_str} in the Redis semantic cache: {str(e)}")
print_verbose(
f"Error setting {value_str or value} in the Redis semantic cache: {str(e)}"
)
def get_cache(self, key: str, **kwargs) -> Any:
"""
Retrieve a semantically similar cached response.
Args:
key: The cache key (not directly used in semantic caching)
**kwargs: Additional arguments including 'messages' for the prompt
Returns:
The cached response if a semantically similar prompt is found, else None
"""
@ -224,7 +236,7 @@ class RedisSemanticCache(BaseCache):
if not messages:
print_verbose("No messages provided for semantic cache lookup")
return None
prompt = get_str_from_messages(messages)
# Check the cache for semantically similar prompts
results = self.llmcache.check(prompt=prompt)
@ -236,12 +248,12 @@ class RedisSemanticCache(BaseCache):
# Process the best matching result
cache_hit = results[0]
vector_distance = float(cache_hit["vector_distance"])
# Convert vector distance back to similarity score
# For cosine distance: 0 = most similar, 2 = least similar
# While similarity: 1 = most similar, 0 = least similar
similarity = 1 - vector_distance
cached_prompt = cache_hit["prompt"]
cached_response = cache_hit["response"]
@ -251,19 +263,19 @@ class RedisSemanticCache(BaseCache):
f"current prompt: {prompt}, "
f"cached prompt: {cached_prompt}"
)
return self._get_cache_logic(cached_response=cached_response)
except Exception as e:
print_verbose(f"Error retrieving from Redis semantic cache: {str(e)}")
async def _get_async_embedding(self, prompt: str, **kwargs) -> List[float]:
"""
Asynchronously generate an embedding for the given prompt.
Args:
prompt: The text to generate an embedding for
**kwargs: Additional arguments that may contain metadata
Returns:
List[float]: The embedding vector
"""
@ -275,7 +287,7 @@ class RedisSemanticCache(BaseCache):
if llm_model_list is not None
else []
)
try:
if llm_router is not None and self.embedding_model in router_model_names:
# Use the router for embedding generation
@ -307,7 +319,7 @@ class RedisSemanticCache(BaseCache):
async def async_set_cache(self, key: str, value: Any, **kwargs) -> None:
"""
Asynchronously store a value in the semantic cache.
Args:
key: The cache key (not directly used in semantic caching)
value: The response value to cache
@ -322,13 +334,13 @@ class RedisSemanticCache(BaseCache):
if not messages:
print_verbose("No messages provided for semantic caching")
return
prompt = get_str_from_messages(messages)
value_str = str(value)
# Generate embedding for the value (response) to cache
prompt_embedding = await self._get_async_embedding(prompt, **kwargs)
# Get TTL and store in Redis semantic cache
ttl = self._get_ttl(**kwargs)
if ttl is not None:
@ -336,13 +348,13 @@ class RedisSemanticCache(BaseCache):
prompt,
value_str,
vector=prompt_embedding, # Pass through custom embedding
ttl=ttl
ttl=ttl,
)
else:
await self.llmcache.astore(
prompt,
value_str,
vector=prompt_embedding # Pass through custom embedding
vector=prompt_embedding, # Pass through custom embedding
)
except Exception as e:
print_verbose(f"Error in async_set_cache: {str(e)}")
@ -350,11 +362,11 @@ class RedisSemanticCache(BaseCache):
async def async_get_cache(self, key: str, **kwargs) -> Any:
"""
Asynchronously retrieve a semantically similar cached response.
Args:
key: The cache key (not directly used in semantic caching)
**kwargs: Additional arguments including 'messages' for the prompt
Returns:
The cached response if a semantically similar prompt is found, else None
"""
@ -367,21 +379,20 @@ class RedisSemanticCache(BaseCache):
print_verbose("No messages provided for semantic cache lookup")
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None
prompt = get_str_from_messages(messages)
# Generate embedding for the prompt
prompt_embedding = await self._get_async_embedding(prompt, **kwargs)
# Check the cache for semantically similar prompts
results = await self.llmcache.acheck(
prompt=prompt,
vector=prompt_embedding
)
results = await self.llmcache.acheck(prompt=prompt, vector=prompt_embedding)
# handle results / cache hit
if not results:
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 # TODO why here but not above??
kwargs.setdefault("metadata", {})[
"semantic-similarity"
] = 0.0 # TODO why here but not above??
return None
cache_hit = results[0]
@ -404,7 +415,7 @@ class RedisSemanticCache(BaseCache):
f"current prompt: {prompt}, "
f"cached prompt: {cached_prompt}"
)
return self._get_cache_logic(cached_response=cached_response)
except Exception as e:
print_verbose(f"Error in async_get_cache: {str(e)}")
@ -413,17 +424,19 @@ class RedisSemanticCache(BaseCache):
async def _index_info(self) -> Dict[str, Any]:
"""
Get information about the Redis index.
Returns:
Dict[str, Any]: Information about the Redis index
"""
aindex = await self.llmcache._get_async_index()
return await aindex.info()
async def async_set_cache_pipeline(self, cache_list: List[Tuple[str, Any]], **kwargs) -> None:
async def async_set_cache_pipeline(
self, cache_list: List[Tuple[str, Any]], **kwargs
) -> None:
"""
Asynchronously store multiple values in the semantic cache.
Args:
cache_list: List of (key, value) tuples to cache
**kwargs: Additional arguments

View file

@ -123,7 +123,7 @@ class S3Cache(BaseCache):
) # Convert string to dictionary
except Exception:
cached_response = ast.literal_eval(cached_response)
if type(cached_response) is not dict:
if not isinstance(cached_response, dict):
cached_response = dict(cached_response)
verbose_logger.debug(
f"Got S3 Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"

View file

@ -580,7 +580,6 @@ def completion_cost( # noqa: PLR0915
- For un-mapped Replicate models, the cost is calculated based on the total time used for the request.
"""
try:
call_type = _infer_call_type(call_type, completion_response) or "completion"
if (

View file

@ -138,7 +138,6 @@ def create_fine_tuning_job(
# OpenAI
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
optional_params.api_base
@ -360,7 +359,6 @@ def cancel_fine_tuning_job(
# OpenAI
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
optional_params.api_base
@ -522,7 +520,6 @@ def list_fine_tuning_jobs(
# OpenAI
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
optional_params.api_base

View file

@ -19,7 +19,6 @@ else:
def squash_payloads(queue):
squashed = {}
if len(queue) == 0:
return squashed

View file

@ -195,12 +195,15 @@ class SlackAlerting(CustomBatchLogger):
if self.alerting is None or self.alert_types is None:
return
time_difference_float, model, api_base, messages = (
self._response_taking_too_long_callback_helper(
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
(
time_difference_float,
model,
api_base,
messages,
) = self._response_taking_too_long_callback_helper(
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
)
if litellm.turn_off_message_logging or litellm.redact_messages_in_exceptions:
messages = "Message not logged. litellm.redact_messages_in_exceptions=True"
@ -819,9 +822,9 @@ class SlackAlerting(CustomBatchLogger):
### UNIQUE CACHE KEY ###
cache_key = provider + region_name
outage_value: Optional[ProviderRegionOutageModel] = (
await self.internal_usage_cache.async_get_cache(key=cache_key)
)
outage_value: Optional[
ProviderRegionOutageModel
] = await self.internal_usage_cache.async_get_cache(key=cache_key)
if (
getattr(exception, "status_code", None) is None
@ -1402,9 +1405,9 @@ Model Info:
self.alert_to_webhook_url is not None
and alert_type in self.alert_to_webhook_url
):
slack_webhook_url: Optional[Union[str, List[str]]] = (
self.alert_to_webhook_url[alert_type]
)
slack_webhook_url: Optional[
Union[str, List[str]]
] = self.alert_to_webhook_url[alert_type]
elif self.default_webhook_url is not None:
slack_webhook_url = self.default_webhook_url
else:
@ -1768,7 +1771,6 @@ Model Info:
- Team Created, Updated, Deleted
"""
try:
message = f"`{event_name}`\n"
key_event_dict = key_event.model_dump()

View file

@ -283,4 +283,4 @@ class OpenInferenceSpanKindValues(Enum):
class OpenInferenceMimeTypeValues(Enum):
TEXT = "text/plain"
JSON = "application/json"
JSON = "application/json"

View file

@ -98,7 +98,6 @@ class ArgillaLogger(CustomBatchLogger):
argilla_dataset_name: Optional[str],
argilla_base_url: Optional[str],
) -> ArgillaCredentialsObject:
_credentials_api_key = argilla_api_key or os.getenv("ARGILLA_API_KEY")
if _credentials_api_key is None:
raise Exception("Invalid Argilla API Key given. _credentials_api_key=None.")

View file

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, Union
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
@ -7,7 +7,7 @@ from litellm.types.utils import StandardLoggingPayload
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
Span = Union[_Span, Any]
else:
Span = Any

View file

@ -19,14 +19,13 @@ if TYPE_CHECKING:
from litellm.types.integrations.arize import Protocol as _Protocol
Protocol = _Protocol
Span = _Span
Span = Union[_Span, Any]
else:
Protocol = Any
Span = Any
class ArizeLogger(OpenTelemetry):
def set_attributes(self, span: Span, kwargs, response_obj: Optional[Any]):
ArizeLogger.set_arize_attributes(span, kwargs, response_obj)
return

View file

@ -1,17 +1,20 @@
import os
from typing import TYPE_CHECKING, Any
from litellm.integrations.arize import _utils
from typing import TYPE_CHECKING, Any, Union
from litellm._logging import verbose_logger
from litellm.integrations.arize import _utils
from litellm.types.integrations.arize_phoenix import ArizePhoenixConfig
if TYPE_CHECKING:
from .opentelemetry import OpenTelemetryConfig as _OpenTelemetryConfig
from litellm.types.integrations.arize import Protocol as _Protocol
from opentelemetry.trace import Span as _Span
from litellm.types.integrations.arize import Protocol as _Protocol
from .opentelemetry import OpenTelemetryConfig as _OpenTelemetryConfig
Protocol = _Protocol
OpenTelemetryConfig = _OpenTelemetryConfig
Span = _Span
Span = Union[_Span, Any]
else:
Protocol = Any
OpenTelemetryConfig = Any
@ -20,6 +23,7 @@ else:
ARIZE_HOSTED_PHOENIX_ENDPOINT = "https://app.phoenix.arize.com/v1/traces"
class ArizePhoenixLogger:
@staticmethod
def set_arize_phoenix_attributes(span: Span, kwargs, response_obj):
@ -49,7 +53,7 @@ class ArizePhoenixLogger:
protocol = "otlp_grpc"
else:
endpoint = ARIZE_HOSTED_PHOENIX_ENDPOINT
protocol = "otlp_http"
protocol = "otlp_http"
verbose_logger.debug(
f"No PHOENIX_COLLECTOR_ENDPOINT or PHOENIX_COLLECTOR_HTTP_ENDPOINT found, using default endpoint with http: {ARIZE_HOSTED_PHOENIX_ENDPOINT}"
)
@ -57,17 +61,16 @@ class ArizePhoenixLogger:
otlp_auth_headers = None
# If the endpoint is the Arize hosted Phoenix endpoint, use the api_key as the auth header as currently it is uses
# a slightly different auth header format than self hosted phoenix
if endpoint == ARIZE_HOSTED_PHOENIX_ENDPOINT:
if endpoint == ARIZE_HOSTED_PHOENIX_ENDPOINT:
if api_key is None:
raise ValueError("PHOENIX_API_KEY must be set when the Arize hosted Phoenix endpoint is used.")
raise ValueError(
"PHOENIX_API_KEY must be set when the Arize hosted Phoenix endpoint is used."
)
otlp_auth_headers = f"api_key={api_key}"
elif api_key is not None:
# api_key/auth is optional for self hosted phoenix
otlp_auth_headers = f"Authorization=Bearer {api_key}"
return ArizePhoenixConfig(
otlp_auth_headers=otlp_auth_headers,
protocol=protocol,
endpoint=endpoint
otlp_auth_headers=otlp_auth_headers, protocol=protocol, endpoint=endpoint
)

View file

@ -12,7 +12,10 @@ class AthinaLogger:
"athina-api-key": self.athina_api_key,
"Content-Type": "application/json",
}
self.athina_logging_url = os.getenv("ATHINA_BASE_URL", "https://log.athina.ai") + "/api/v1/log/inference"
self.athina_logging_url = (
os.getenv("ATHINA_BASE_URL", "https://log.athina.ai")
+ "/api/v1/log/inference"
)
self.additional_keys = [
"environment",
"prompt_slug",

View file

@ -50,12 +50,12 @@ class AzureBlobStorageLogger(CustomBatchLogger):
self.azure_storage_file_system: str = _azure_storage_file_system
# Internal variables used for Token based authentication
self.azure_auth_token: Optional[str] = (
None # the Azure AD token to use for Azure Storage API requests
)
self.token_expiry: Optional[datetime] = (
None # the expiry time of the currentAzure AD token
)
self.azure_auth_token: Optional[
str
] = None # the Azure AD token to use for Azure Storage API requests
self.token_expiry: Optional[
datetime
] = None # the expiry time of the currentAzure AD token
asyncio.create_task(self.periodic_flush())
self.flush_lock = asyncio.Lock()
@ -153,7 +153,6 @@ class AzureBlobStorageLogger(CustomBatchLogger):
3. Flush the data
"""
try:
if self.azure_storage_account_key:
await self.upload_to_azure_data_lake_with_azure_account_key(
payload=payload

View file

@ -4,7 +4,7 @@
import copy
import os
from datetime import datetime
from typing import Optional, Dict
from typing import Dict, Optional
import httpx
from pydantic import BaseModel
@ -19,7 +19,9 @@ from litellm.llms.custom_httpx.http_handler import (
)
from litellm.utils import print_verbose
global_braintrust_http_handler = get_async_httpx_client(llm_provider=httpxSpecialProvider.LoggingCallback)
global_braintrust_http_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback
)
global_braintrust_sync_http_handler = HTTPHandler()
API_BASE = "https://api.braintrustdata.com/v1"
@ -35,7 +37,9 @@ def get_utc_datetime():
class BraintrustLogger(CustomLogger):
def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None) -> None:
def __init__(
self, api_key: Optional[str] = None, api_base: Optional[str] = None
) -> None:
super().__init__()
self.validate_environment(api_key=api_key)
self.api_base = api_base or API_BASE
@ -45,7 +49,9 @@ class BraintrustLogger(CustomLogger):
"Authorization": "Bearer " + self.api_key,
"Content-Type": "application/json",
}
self._project_id_cache: Dict[str, str] = {} # Cache mapping project names to IDs
self._project_id_cache: Dict[
str, str
] = {} # Cache mapping project names to IDs
def validate_environment(self, api_key: Optional[str]):
"""
@ -71,7 +77,9 @@ class BraintrustLogger(CustomLogger):
try:
response = global_braintrust_sync_http_handler.post(
f"{self.api_base}/project", headers=self.headers, json={"name": project_name}
f"{self.api_base}/project",
headers=self.headers,
json={"name": project_name},
)
project_dict = response.json()
project_id = project_dict["id"]
@ -89,7 +97,9 @@ class BraintrustLogger(CustomLogger):
try:
response = await global_braintrust_http_handler.post(
f"{self.api_base}/project/register", headers=self.headers, json={"name": project_name}
f"{self.api_base}/project/register",
headers=self.headers,
json={"name": project_name},
)
project_dict = response.json()
project_id = project_dict["id"]
@ -116,15 +126,21 @@ class BraintrustLogger(CustomLogger):
if metadata is None:
metadata = {}
proxy_headers = litellm_params.get("proxy_server_request", {}).get("headers", {}) or {}
proxy_headers = (
litellm_params.get("proxy_server_request", {}).get("headers", {}) or {}
)
for metadata_param_key in proxy_headers:
if metadata_param_key.startswith("braintrust"):
trace_param_key = metadata_param_key.replace("braintrust", "", 1)
if trace_param_key in metadata:
verbose_logger.warning(f"Overwriting Braintrust `{trace_param_key}` from request header")
verbose_logger.warning(
f"Overwriting Braintrust `{trace_param_key}` from request header"
)
else:
verbose_logger.debug(f"Found Braintrust `{trace_param_key}` in request header")
verbose_logger.debug(
f"Found Braintrust `{trace_param_key}` in request header"
)
metadata[trace_param_key] = proxy_headers.get(metadata_param_key)
return metadata
@ -157,24 +173,35 @@ class BraintrustLogger(CustomLogger):
output = None
choices = []
if response_obj is not None and (
kwargs.get("call_type", None) == "embedding" or isinstance(response_obj, litellm.EmbeddingResponse)
kwargs.get("call_type", None) == "embedding"
or isinstance(response_obj, litellm.EmbeddingResponse)
):
output = None
elif response_obj is not None and isinstance(response_obj, litellm.ModelResponse):
elif response_obj is not None and isinstance(
response_obj, litellm.ModelResponse
):
output = response_obj["choices"][0]["message"].json()
choices = response_obj["choices"]
elif response_obj is not None and isinstance(response_obj, litellm.TextCompletionResponse):
elif response_obj is not None and isinstance(
response_obj, litellm.TextCompletionResponse
):
output = response_obj.choices[0].text
choices = response_obj.choices
elif response_obj is not None and isinstance(response_obj, litellm.ImageResponse):
elif response_obj is not None and isinstance(
response_obj, litellm.ImageResponse
):
output = response_obj["data"]
litellm_params = kwargs.get("litellm_params", {})
metadata = litellm_params.get("metadata", {}) or {} # if litellm_params['metadata'] == None
metadata = (
litellm_params.get("metadata", {}) or {}
) # if litellm_params['metadata'] == None
metadata = self.add_metadata_from_header(litellm_params, metadata)
clean_metadata = {}
try:
metadata = copy.deepcopy(metadata) # Avoid modifying the original metadata
metadata = copy.deepcopy(
metadata
) # Avoid modifying the original metadata
except Exception:
new_metadata = {}
for key, value in metadata.items():
@ -192,7 +219,9 @@ class BraintrustLogger(CustomLogger):
project_id = metadata.get("project_id")
if project_id is None:
project_name = metadata.get("project_name")
project_id = self.get_project_id_sync(project_name) if project_name else None
project_id = (
self.get_project_id_sync(project_name) if project_name else None
)
if project_id is None:
if self.default_project_id is None:
@ -234,7 +263,8 @@ class BraintrustLogger(CustomLogger):
"completion_tokens": usage_obj.completion_tokens,
"total_tokens": usage_obj.total_tokens,
"total_cost": cost,
"time_to_first_token": end_time.timestamp() - start_time.timestamp(),
"time_to_first_token": end_time.timestamp()
- start_time.timestamp(),
"start": start_time.timestamp(),
"end": end_time.timestamp(),
}
@ -255,7 +285,9 @@ class BraintrustLogger(CustomLogger):
request_data["metrics"] = metrics
try:
print_verbose(f"global_braintrust_sync_http_handler.post: {global_braintrust_sync_http_handler.post}")
print_verbose(
f"global_braintrust_sync_http_handler.post: {global_braintrust_sync_http_handler.post}"
)
global_braintrust_sync_http_handler.post(
url=f"{self.api_base}/project_logs/{project_id}/insert",
json={"events": [request_data]},
@ -276,20 +308,29 @@ class BraintrustLogger(CustomLogger):
output = None
choices = []
if response_obj is not None and (
kwargs.get("call_type", None) == "embedding" or isinstance(response_obj, litellm.EmbeddingResponse)
kwargs.get("call_type", None) == "embedding"
or isinstance(response_obj, litellm.EmbeddingResponse)
):
output = None
elif response_obj is not None and isinstance(response_obj, litellm.ModelResponse):
elif response_obj is not None and isinstance(
response_obj, litellm.ModelResponse
):
output = response_obj["choices"][0]["message"].json()
choices = response_obj["choices"]
elif response_obj is not None and isinstance(response_obj, litellm.TextCompletionResponse):
elif response_obj is not None and isinstance(
response_obj, litellm.TextCompletionResponse
):
output = response_obj.choices[0].text
choices = response_obj.choices
elif response_obj is not None and isinstance(response_obj, litellm.ImageResponse):
elif response_obj is not None and isinstance(
response_obj, litellm.ImageResponse
):
output = response_obj["data"]
litellm_params = kwargs.get("litellm_params", {})
metadata = litellm_params.get("metadata", {}) or {} # if litellm_params['metadata'] == None
metadata = (
litellm_params.get("metadata", {}) or {}
) # if litellm_params['metadata'] == None
metadata = self.add_metadata_from_header(litellm_params, metadata)
clean_metadata = {}
new_metadata = {}
@ -313,7 +354,11 @@ class BraintrustLogger(CustomLogger):
project_id = metadata.get("project_id")
if project_id is None:
project_name = metadata.get("project_name")
project_id = await self.get_project_id_async(project_name) if project_name else None
project_id = (
await self.get_project_id_async(project_name)
if project_name
else None
)
if project_id is None:
if self.default_project_id is None:
@ -362,8 +407,14 @@ class BraintrustLogger(CustomLogger):
api_call_start_time = kwargs.get("api_call_start_time")
completion_start_time = kwargs.get("completion_start_time")
if api_call_start_time is not None and completion_start_time is not None:
metrics["time_to_first_token"] = completion_start_time.timestamp() - api_call_start_time.timestamp()
if (
api_call_start_time is not None
and completion_start_time is not None
):
metrics["time_to_first_token"] = (
completion_start_time.timestamp()
- api_call_start_time.timestamp()
)
request_data = {
"id": litellm_call_id,

View file

@ -14,7 +14,6 @@ from litellm.integrations.custom_logger import CustomLogger
class CustomBatchLogger(CustomLogger):
def __init__(
self,
flush_lock: Optional[asyncio.Lock] = None,

View file

@ -7,7 +7,6 @@ from litellm.types.utils import StandardLoggingGuardrailInformation
class CustomGuardrail(CustomLogger):
def __init__(
self,
guardrail_name: Optional[str] = None,

View file

@ -31,7 +31,7 @@ from litellm.types.utils import (
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
Span = Union[_Span, Any]
else:
Span = Any

View file

@ -233,7 +233,6 @@ class DataDogLogger(
pass
async def _log_async_event(self, kwargs, response_obj, start_time, end_time):
dd_payload = self.create_datadog_logging_payload(
kwargs=kwargs,
response_obj=response_obj,

View file

@ -125,9 +125,9 @@ class GCSBucketBase(CustomBatchLogger):
if kwargs is None:
kwargs = {}
standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
kwargs.get("standard_callback_dynamic_params", None)
)
standard_callback_dynamic_params: Optional[
StandardCallbackDynamicParams
] = kwargs.get("standard_callback_dynamic_params", None)
bucket_name: str
path_service_account: Optional[str]

View file

@ -70,12 +70,13 @@ class GcsPubSubLogger(CustomBatchLogger):
"""Construct authorization headers using Vertex AI auth"""
from litellm import vertex_chat_completion
_auth_header, vertex_project = (
await vertex_chat_completion._ensure_access_token_async(
credentials=self.path_service_account_json,
project_id=None,
custom_llm_provider="vertex_ai",
)
(
_auth_header,
vertex_project,
) = await vertex_chat_completion._ensure_access_token_async(
credentials=self.path_service_account_json,
project_id=None,
custom_llm_provider="vertex_ai",
)
auth_header, _ = vertex_chat_completion._get_token_and_url(

View file

@ -155,11 +155,7 @@ class HumanloopLogger(CustomLogger):
prompt_id: str,
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
) -> Tuple[
str,
List[AllMessageValues],
dict,
]:
) -> Tuple[str, List[AllMessageValues], dict,]:
humanloop_api_key = dynamic_callback_params.get(
"humanloop_api_key"
) or get_secret_str("HUMANLOOP_API_KEY")

View file

@ -471,9 +471,9 @@ class LangFuseLogger:
# we clean out all extra litellm metadata params before logging
clean_metadata: Dict[str, Any] = {}
if prompt_management_metadata is not None:
clean_metadata["prompt_management_metadata"] = (
prompt_management_metadata
)
clean_metadata[
"prompt_management_metadata"
] = prompt_management_metadata
if isinstance(metadata, dict):
for key, value in metadata.items():
# generate langfuse tags - Default Tags sent to Langfuse from LiteLLM Proxy

View file

@ -19,7 +19,6 @@ else:
class LangFuseHandler:
@staticmethod
def get_langfuse_logger_for_request(
standard_callback_dynamic_params: StandardCallbackDynamicParams,
@ -87,7 +86,9 @@ class LangFuseHandler:
if globalLangfuseLogger is not None:
return globalLangfuseLogger
credentials_dict: Dict[str, Any] = (
credentials_dict: Dict[
str, Any
] = (
{}
) # the global langfuse logger uses Environment Variables, there are no dynamic credentials
globalLangfuseLogger = in_memory_dynamic_logger_cache.get_cache(

View file

@ -172,11 +172,7 @@ class LangfusePromptManagement(LangFuseLogger, PromptManagementBase, CustomLogge
prompt_id: str,
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
) -> Tuple[
str,
List[AllMessageValues],
dict,
]:
) -> Tuple[str, List[AllMessageValues], dict,]:
return self.get_chat_completion_prompt(
model,
messages,

View file

@ -75,7 +75,6 @@ class LangsmithLogger(CustomBatchLogger):
langsmith_project: Optional[str] = None,
langsmith_base_url: Optional[str] = None,
) -> LangsmithCredentialsObject:
_credentials_api_key = langsmith_api_key or os.getenv("LANGSMITH_API_KEY")
if _credentials_api_key is None:
raise Exception(
@ -443,9 +442,9 @@ class LangsmithLogger(CustomBatchLogger):
Otherwise, use the default credentials.
"""
standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
kwargs.get("standard_callback_dynamic_params", None)
)
standard_callback_dynamic_params: Optional[
StandardCallbackDynamicParams
] = kwargs.get("standard_callback_dynamic_params", None)
if standard_callback_dynamic_params is not None:
credentials = self.get_credentials_from_env(
langsmith_api_key=standard_callback_dynamic_params.get(
@ -481,7 +480,6 @@ class LangsmithLogger(CustomBatchLogger):
asyncio.run(self.async_send_batch())
def get_run_by_id(self, run_id):
langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"]
langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"]

View file

@ -1,12 +1,12 @@
import json
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Union
from litellm.proxy._types import SpanAttributes
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
Span = Union[_Span, Any]
else:
Span = Any

View file

@ -20,7 +20,6 @@ def parse_tool_calls(tool_calls):
return None
def clean_tool_call(tool_call):
serialized = {
"type": tool_call.type,
"id": tool_call.id,
@ -36,7 +35,6 @@ def parse_tool_calls(tool_calls):
def parse_messages(input):
if input is None:
return None

View file

@ -48,14 +48,17 @@ class MlflowLogger(CustomLogger):
def _extract_and_set_chat_attributes(self, span, kwargs, response_obj):
try:
from mlflow.tracing.utils import set_span_chat_messages, set_span_chat_tools
from mlflow.tracing.utils import set_span_chat_messages # type: ignore
from mlflow.tracing.utils import set_span_chat_tools # type: ignore
except ImportError:
return
inputs = self._construct_input(kwargs)
input_messages = inputs.get("messages", [])
output_messages = [c.message.model_dump(exclude_none=True)
for c in getattr(response_obj, "choices", [])]
output_messages = [
c.message.model_dump(exclude_none=True)
for c in getattr(response_obj, "choices", [])
]
if messages := [*input_messages, *output_messages]:
set_span_chat_messages(span, messages)
if tools := inputs.get("tools"):

View file

@ -1,7 +1,7 @@
import os
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
import litellm
from litellm._logging import verbose_logger
@ -23,10 +23,10 @@ if TYPE_CHECKING:
)
from litellm.proxy.proxy_server import UserAPIKeyAuth as _UserAPIKeyAuth
Span = _Span
SpanExporter = _SpanExporter
UserAPIKeyAuth = _UserAPIKeyAuth
ManagementEndpointLoggingPayload = _ManagementEndpointLoggingPayload
Span = Union[_Span, Any]
SpanExporter = Union[_SpanExporter, Any]
UserAPIKeyAuth = Union[_UserAPIKeyAuth, Any]
ManagementEndpointLoggingPayload = Union[_ManagementEndpointLoggingPayload, Any]
else:
Span = Any
SpanExporter = Any
@ -46,7 +46,6 @@ LITELLM_REQUEST_SPAN_NAME = "litellm_request"
@dataclass
class OpenTelemetryConfig:
exporter: Union[str, SpanExporter] = "console"
endpoint: Optional[str] = None
headers: Optional[str] = None
@ -154,7 +153,6 @@ class OpenTelemetry(CustomLogger):
end_time: Optional[Union[datetime, float]] = None,
event_metadata: Optional[dict] = None,
):
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode
@ -215,7 +213,6 @@ class OpenTelemetry(CustomLogger):
end_time: Optional[Union[float, datetime]] = None,
event_metadata: Optional[dict] = None,
):
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode
@ -353,9 +350,9 @@ class OpenTelemetry(CustomLogger):
"""
from opentelemetry import trace
standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
kwargs.get("standard_callback_dynamic_params")
)
standard_callback_dynamic_params: Optional[
StandardCallbackDynamicParams
] = kwargs.get("standard_callback_dynamic_params")
if not standard_callback_dynamic_params:
return
@ -722,7 +719,6 @@ class OpenTelemetry(CustomLogger):
span.set_attribute(key, primitive_value)
def set_raw_request_attributes(self, span: Span, kwargs, response_obj):
kwargs.get("optional_params", {})
litellm_params = kwargs.get("litellm_params", {}) or {}
custom_llm_provider = litellm_params.get("custom_llm_provider", "Unknown")
@ -843,12 +839,14 @@ class OpenTelemetry(CustomLogger):
headers=dynamic_headers or self.OTEL_HEADERS
)
if isinstance(self.OTEL_EXPORTER, SpanExporter):
if hasattr(
self.OTEL_EXPORTER, "export"
): # Check if it has the export method that SpanExporter requires
verbose_logger.debug(
"OpenTelemetry: intiializing SpanExporter. Value of OTEL_EXPORTER: %s",
self.OTEL_EXPORTER,
)
return SimpleSpanProcessor(self.OTEL_EXPORTER)
return SimpleSpanProcessor(cast(SpanExporter, self.OTEL_EXPORTER))
if self.OTEL_EXPORTER == "console":
verbose_logger.debug(
@ -907,7 +905,6 @@ class OpenTelemetry(CustomLogger):
logging_payload: ManagementEndpointLoggingPayload,
parent_otel_span: Optional[Span] = None,
):
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode
@ -961,7 +958,6 @@ class OpenTelemetry(CustomLogger):
logging_payload: ManagementEndpointLoggingPayload,
parent_otel_span: Optional[Span] = None,
):
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode

View file

@ -185,7 +185,6 @@ class OpikLogger(CustomBatchLogger):
def _create_opik_payload( # noqa: PLR0915
self, kwargs, response_obj, start_time, end_time
) -> List[Dict]:
# Get metadata
_litellm_params = kwargs.get("litellm_params", {}) or {}
litellm_params_metadata = _litellm_params.get("metadata", {}) or {}

View file

@ -988,9 +988,9 @@ class PrometheusLogger(CustomLogger):
):
try:
verbose_logger.debug("setting remaining tokens requests metric")
standard_logging_payload: Optional[StandardLoggingPayload] = (
request_kwargs.get("standard_logging_object")
)
standard_logging_payload: Optional[
StandardLoggingPayload
] = request_kwargs.get("standard_logging_object")
if standard_logging_payload is None:
return

View file

@ -14,7 +14,6 @@ class PromptManagementClient(TypedDict):
class PromptManagementBase(ABC):
@property
@abstractmethod
def integration_name(self) -> str:
@ -83,11 +82,7 @@ class PromptManagementBase(ABC):
prompt_id: str,
prompt_variables: Optional[dict],
dynamic_callback_params: StandardCallbackDynamicParams,
) -> Tuple[
str,
List[AllMessageValues],
dict,
]:
) -> Tuple[str, List[AllMessageValues], dict,]:
if not self.should_run_prompt_management(
prompt_id=prompt_id, dynamic_callback_params=dynamic_callback_params
):

View file

@ -38,7 +38,7 @@ class S3Logger:
if litellm.s3_callback_params is not None:
# read in .env variables - example os.environ/AWS_BUCKET_NAME
for key, value in litellm.s3_callback_params.items():
if type(value) is str and value.startswith("os.environ/"):
if isinstance(value, str) and value.startswith("os.environ/"):
litellm.s3_callback_params[key] = litellm.get_secret(value)
# now set s3 params from litellm.s3_logger_params
s3_bucket_name = litellm.s3_callback_params.get("s3_bucket_name")

View file

@ -21,11 +21,11 @@ try:
# contains a (known) object attribute
object: Literal["chat.completion", "edit", "text_completion"]
def __getitem__(self, key: K) -> V: ... # noqa
def __getitem__(self, key: K) -> V:
... # noqa
def get( # noqa
self, key: K, default: Optional[V] = None
) -> Optional[V]: ... # pragma: no cover
def get(self, key: K, default: Optional[V] = None) -> Optional[V]: # noqa
... # pragma: no cover
class OpenAIRequestResponseResolver:
def __call__(

View file

@ -10,7 +10,7 @@ from litellm.types.llms.openai import AllMessageValues
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
Span = Union[_Span, Any]
else:
Span = Any

View file

@ -11,7 +11,9 @@ except (ImportError, AttributeError):
# Old way to access resources, which setuptools deprecated some time ago
import pkg_resources # type: ignore
filename = pkg_resources.resource_filename(__name__, "litellm_core_utils/tokenizers")
filename = pkg_resources.resource_filename(
__name__, "litellm_core_utils/tokenizers"
)
os.environ["TIKTOKEN_CACHE_DIR"] = os.getenv(
"CUSTOM_TIKTOKEN_CACHE_DIR", filename

View file

@ -239,9 +239,9 @@ class Logging(LiteLLMLoggingBaseClass):
self.litellm_trace_id = litellm_trace_id
self.function_id = function_id
self.streaming_chunks: List[Any] = [] # for generating complete stream response
self.sync_streaming_chunks: List[Any] = (
[]
) # for generating complete stream response
self.sync_streaming_chunks: List[
Any
] = [] # for generating complete stream response
self.log_raw_request_response = log_raw_request_response
# Initialize dynamic callbacks
@ -452,18 +452,19 @@ class Logging(LiteLLMLoggingBaseClass):
prompt_id: str,
prompt_variables: Optional[dict],
) -> Tuple[str, List[AllMessageValues], dict]:
custom_logger = self.get_custom_logger_for_prompt_management(model)
if custom_logger:
model, messages, non_default_params = (
custom_logger.get_chat_completion_prompt(
model=model,
messages=messages,
non_default_params=non_default_params,
prompt_id=prompt_id,
prompt_variables=prompt_variables,
dynamic_callback_params=self.standard_callback_dynamic_params,
)
(
model,
messages,
non_default_params,
) = custom_logger.get_chat_completion_prompt(
model=model,
messages=messages,
non_default_params=non_default_params,
prompt_id=prompt_id,
prompt_variables=prompt_variables,
dynamic_callback_params=self.standard_callback_dynamic_params,
)
self.messages = messages
return model, messages, non_default_params
@ -541,12 +542,11 @@ class Logging(LiteLLMLoggingBaseClass):
model
): # if model name was changes pre-call, overwrite the initial model call name with the new one
self.model_call_details["model"] = model
self.model_call_details["litellm_params"]["api_base"] = (
self._get_masked_api_base(additional_args.get("api_base", ""))
)
self.model_call_details["litellm_params"][
"api_base"
] = self._get_masked_api_base(additional_args.get("api_base", ""))
def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR0915
# Log the exact input to the LLM API
litellm.error_logs["PRE_CALL"] = locals()
try:
@ -568,19 +568,16 @@ class Logging(LiteLLMLoggingBaseClass):
self.log_raw_request_response is True
or log_raw_request_response is True
):
_litellm_params = self.model_call_details.get("litellm_params", {})
_metadata = _litellm_params.get("metadata", {}) or {}
try:
# [Non-blocking Extra Debug Information in metadata]
if turn_off_message_logging is True:
_metadata["raw_request"] = (
"redacted by litellm. \
_metadata[
"raw_request"
] = "redacted by litellm. \
'litellm.turn_off_message_logging=True'"
)
else:
curl_command = self._get_request_curl_command(
api_base=additional_args.get("api_base", ""),
headers=additional_args.get("headers", {}),
@ -590,33 +587,33 @@ class Logging(LiteLLMLoggingBaseClass):
_metadata["raw_request"] = str(curl_command)
# split up, so it's easier to parse in the UI
self.model_call_details["raw_request_typed_dict"] = (
RawRequestTypedDict(
raw_request_api_base=str(
additional_args.get("api_base") or ""
),
raw_request_body=self._get_raw_request_body(
additional_args.get("complete_input_dict", {})
),
raw_request_headers=self._get_masked_headers(
additional_args.get("headers", {}) or {},
ignore_sensitive_headers=True,
),
error=None,
)
self.model_call_details[
"raw_request_typed_dict"
] = RawRequestTypedDict(
raw_request_api_base=str(
additional_args.get("api_base") or ""
),
raw_request_body=self._get_raw_request_body(
additional_args.get("complete_input_dict", {})
),
raw_request_headers=self._get_masked_headers(
additional_args.get("headers", {}) or {},
ignore_sensitive_headers=True,
),
error=None,
)
except Exception as e:
self.model_call_details["raw_request_typed_dict"] = (
RawRequestTypedDict(
error=str(e),
)
self.model_call_details[
"raw_request_typed_dict"
] = RawRequestTypedDict(
error=str(e),
)
traceback.print_exc()
_metadata["raw_request"] = (
"Unable to Log \
_metadata[
"raw_request"
] = "Unable to Log \
raw request: {}".format(
str(e)
)
str(e)
)
if self.logger_fn and callable(self.logger_fn):
try:
@ -941,9 +938,9 @@ class Logging(LiteLLMLoggingBaseClass):
verbose_logger.debug(
f"response_cost_failure_debug_information: {debug_info}"
)
self.model_call_details["response_cost_failure_debug_information"] = (
debug_info
)
self.model_call_details[
"response_cost_failure_debug_information"
] = debug_info
return None
try:
@ -968,9 +965,9 @@ class Logging(LiteLLMLoggingBaseClass):
verbose_logger.debug(
f"response_cost_failure_debug_information: {debug_info}"
)
self.model_call_details["response_cost_failure_debug_information"] = (
debug_info
)
self.model_call_details[
"response_cost_failure_debug_information"
] = debug_info
return None
@ -995,7 +992,6 @@ class Logging(LiteLLMLoggingBaseClass):
def should_run_callback(
self, callback: litellm.CALLBACK_TYPES, litellm_params: dict, event_hook: str
) -> bool:
if litellm.global_disable_no_log_param:
return True
@ -1027,9 +1023,9 @@ class Logging(LiteLLMLoggingBaseClass):
end_time = datetime.datetime.now()
if self.completion_start_time is None:
self.completion_start_time = end_time
self.model_call_details["completion_start_time"] = (
self.completion_start_time
)
self.model_call_details[
"completion_start_time"
] = self.completion_start_time
self.model_call_details["log_event_type"] = "successful_api_call"
self.model_call_details["end_time"] = end_time
self.model_call_details["cache_hit"] = cache_hit
@ -1083,39 +1079,39 @@ class Logging(LiteLLMLoggingBaseClass):
"response_cost"
]
else:
self.model_call_details["response_cost"] = (
self._response_cost_calculator(result=result)
)
self.model_call_details[
"response_cost"
] = self._response_cost_calculator(result=result)
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=result,
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="success",
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
self.model_call_details[
"standard_logging_object"
] = get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=result,
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="success",
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
elif isinstance(result, dict): # pass-through endpoints
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=result,
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="success",
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
self.model_call_details[
"standard_logging_object"
] = get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=result,
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="success",
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
elif standard_logging_object is not None:
self.model_call_details["standard_logging_object"] = (
standard_logging_object
)
self.model_call_details[
"standard_logging_object"
] = standard_logging_object
else: # streaming chunks + image gen.
self.model_call_details["response_cost"] = None
@ -1154,7 +1150,6 @@ class Logging(LiteLLMLoggingBaseClass):
standard_logging_object=kwargs.get("standard_logging_object", None),
)
try:
## BUILD COMPLETE STREAMED RESPONSE
complete_streaming_response: Optional[
Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse]
@ -1172,23 +1167,23 @@ class Logging(LiteLLMLoggingBaseClass):
verbose_logger.debug(
"Logging Details LiteLLM-Success Call streaming complete"
)
self.model_call_details["complete_streaming_response"] = (
complete_streaming_response
)
self.model_call_details["response_cost"] = (
self._response_cost_calculator(result=complete_streaming_response)
)
self.model_call_details[
"complete_streaming_response"
] = complete_streaming_response
self.model_call_details[
"response_cost"
] = self._response_cost_calculator(result=complete_streaming_response)
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=complete_streaming_response,
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="success",
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
self.model_call_details[
"standard_logging_object"
] = get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=complete_streaming_response,
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="success",
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
callbacks = self.get_combined_callback_list(
dynamic_success_callbacks=self.dynamic_success_callbacks,
@ -1207,7 +1202,6 @@ class Logging(LiteLLMLoggingBaseClass):
## LOGGING HOOK ##
for callback in callbacks:
if isinstance(callback, CustomLogger):
self.model_call_details, result = callback.logging_hook(
kwargs=self.model_call_details,
result=result,
@ -1538,10 +1532,10 @@ class Logging(LiteLLMLoggingBaseClass):
)
else:
if self.stream and complete_streaming_response:
self.model_call_details["complete_response"] = (
self.model_call_details.get(
"complete_streaming_response", {}
)
self.model_call_details[
"complete_response"
] = self.model_call_details.get(
"complete_streaming_response", {}
)
result = self.model_call_details["complete_response"]
openMeterLogger.log_success_event(
@ -1581,10 +1575,10 @@ class Logging(LiteLLMLoggingBaseClass):
)
else:
if self.stream and complete_streaming_response:
self.model_call_details["complete_response"] = (
self.model_call_details.get(
"complete_streaming_response", {}
)
self.model_call_details[
"complete_response"
] = self.model_call_details.get(
"complete_streaming_response", {}
)
result = self.model_call_details["complete_response"]
@ -1659,7 +1653,6 @@ class Logging(LiteLLMLoggingBaseClass):
if self.call_type == CallTypes.aretrieve_batch.value and isinstance(
result, LiteLLMBatch
):
response_cost, batch_usage, batch_models = await _handle_completed_batch(
batch=result, custom_llm_provider=self.custom_llm_provider
)
@ -1692,9 +1685,9 @@ class Logging(LiteLLMLoggingBaseClass):
if complete_streaming_response is not None:
print_verbose("Async success callbacks: Got a complete streaming response")
self.model_call_details["async_complete_streaming_response"] = (
complete_streaming_response
)
self.model_call_details[
"async_complete_streaming_response"
] = complete_streaming_response
try:
if self.model_call_details.get("cache_hit", False) is True:
self.model_call_details["response_cost"] = 0.0
@ -1704,10 +1697,10 @@ class Logging(LiteLLMLoggingBaseClass):
model_call_details=self.model_call_details
)
# base_model defaults to None if not set on model_info
self.model_call_details["response_cost"] = (
self._response_cost_calculator(
result=complete_streaming_response
)
self.model_call_details[
"response_cost"
] = self._response_cost_calculator(
result=complete_streaming_response
)
verbose_logger.debug(
@ -1720,16 +1713,16 @@ class Logging(LiteLLMLoggingBaseClass):
self.model_call_details["response_cost"] = None
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=complete_streaming_response,
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="success",
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
self.model_call_details[
"standard_logging_object"
] = get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=complete_streaming_response,
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="success",
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
callbacks = self.get_combined_callback_list(
dynamic_success_callbacks=self.dynamic_async_success_callbacks,
@ -1935,18 +1928,18 @@ class Logging(LiteLLMLoggingBaseClass):
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj={},
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="failure",
error_str=str(exception),
original_exception=exception,
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
self.model_call_details[
"standard_logging_object"
] = get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj={},
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="failure",
error_str=str(exception),
original_exception=exception,
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
return start_time, end_time
@ -2084,7 +2077,6 @@ class Logging(LiteLLMLoggingBaseClass):
)
is not True
): # custom logger class
callback.log_failure_event(
start_time=start_time,
end_time=end_time,
@ -2713,9 +2705,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
endpoint=arize_config.endpoint,
)
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
f"space_key={arize_config.space_key},api_key={arize_config.api_key}"
)
os.environ[
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
] = f"space_key={arize_config.space_key},api_key={arize_config.api_key}"
for callback in _in_memory_loggers:
if (
isinstance(callback, ArizeLogger)
@ -2739,9 +2731,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
# auth can be disabled on local deployments of arize phoenix
if arize_phoenix_config.otlp_auth_headers is not None:
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
arize_phoenix_config.otlp_auth_headers
)
os.environ[
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
] = arize_phoenix_config.otlp_auth_headers
for callback in _in_memory_loggers:
if (
@ -2832,9 +2824,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
exporter="otlp_http",
endpoint="https://langtrace.ai/api/trace",
)
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
f"api_key={os.getenv('LANGTRACE_API_KEY')}"
)
os.environ[
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
] = f"api_key={os.getenv('LANGTRACE_API_KEY')}"
for callback in _in_memory_loggers:
if (
isinstance(callback, OpenTelemetry)
@ -3223,7 +3215,6 @@ class StandardLoggingPayloadSetup:
custom_llm_provider: Optional[str],
init_response_obj: Union[Any, BaseModel, dict],
) -> StandardLoggingModelInformation:
model_cost_name = _select_model_name_for_cost_calc(
model=None,
completion_response=init_response_obj, # type: ignore
@ -3286,7 +3277,6 @@ class StandardLoggingPayloadSetup:
def get_additional_headers(
additiona_headers: Optional[dict],
) -> Optional[StandardLoggingAdditionalHeaders]:
if additiona_headers is None:
return None
@ -3322,10 +3312,10 @@ class StandardLoggingPayloadSetup:
for key in StandardLoggingHiddenParams.__annotations__.keys():
if key in hidden_params:
if key == "additional_headers":
clean_hidden_params["additional_headers"] = (
StandardLoggingPayloadSetup.get_additional_headers(
hidden_params[key]
)
clean_hidden_params[
"additional_headers"
] = StandardLoggingPayloadSetup.get_additional_headers(
hidden_params[key]
)
else:
clean_hidden_params[key] = hidden_params[key] # type: ignore
@ -3463,12 +3453,14 @@ def get_standard_logging_object_payload(
)
# cleanup timestamps
start_time_float, end_time_float, completion_start_time_float = (
StandardLoggingPayloadSetup.cleanup_timestamps(
start_time=start_time,
end_time=end_time,
completion_start_time=completion_start_time,
)
(
start_time_float,
end_time_float,
completion_start_time_float,
) = StandardLoggingPayloadSetup.cleanup_timestamps(
start_time=start_time,
end_time=end_time,
completion_start_time=completion_start_time,
)
response_time = StandardLoggingPayloadSetup.get_response_time(
start_time_float=start_time_float,
@ -3495,7 +3487,6 @@ def get_standard_logging_object_payload(
saved_cache_cost: float = 0.0
if cache_hit is True:
id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id
saved_cache_cost = (
logging_obj._response_cost_calculator(
@ -3658,9 +3649,9 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
):
for k, v in metadata["user_api_key_metadata"].items():
if k == "logging": # prevent logging user logging keys
cleaned_user_api_key_metadata[k] = (
"scrubbed_by_litellm_for_sensitive_keys"
)
cleaned_user_api_key_metadata[
k
] = "scrubbed_by_litellm_for_sensitive_keys"
else:
cleaned_user_api_key_metadata[k] = v


@ -258,14 +258,12 @@ def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[s
class LiteLLMResponseObjectHandler:
@staticmethod
def convert_to_image_response(
response_object: dict,
model_response_object: Optional[ImageResponse] = None,
hidden_params: Optional[dict] = None,
) -> ImageResponse:
response_object.update({"hidden_params": hidden_params})
if model_response_object is None:
@ -481,9 +479,9 @@ def convert_to_model_response_object( # noqa: PLR0915
provider_specific_fields["thinking_blocks"] = thinking_blocks
if reasoning_content:
provider_specific_fields["reasoning_content"] = (
reasoning_content
)
provider_specific_fields[
"reasoning_content"
] = reasoning_content
message = Message(
content=content,


@ -17,7 +17,6 @@ from litellm.types.rerank import RerankRequest
class ModelParamHelper:
@staticmethod
def get_standard_logging_model_parameters(
model_parameters: dict,


@ -257,7 +257,6 @@ def _insert_assistant_continue_message(
and message.get("role") == "user" # Current is user
and messages[i + 1].get("role") == "user"
): # Next is user
# Insert assistant message
continue_message = (
assistant_continue_message or DEFAULT_ASSISTANT_CONTINUE_MESSAGE


@ -1042,10 +1042,10 @@ def convert_to_gemini_tool_call_invoke(
if tool_calls is not None:
for tool in tool_calls:
if "function" in tool:
gemini_function_call: Optional[VertexFunctionCall] = (
_gemini_tool_call_invoke_helper(
function_call_params=tool["function"]
)
gemini_function_call: Optional[
VertexFunctionCall
] = _gemini_tool_call_invoke_helper(
function_call_params=tool["function"]
)
if gemini_function_call is not None:
_parts_list.append(
@ -1432,9 +1432,9 @@ def anthropic_messages_pt( # noqa: PLR0915
)
if "cache_control" in _content_element:
_anthropic_content_element["cache_control"] = (
_content_element["cache_control"]
)
_anthropic_content_element[
"cache_control"
] = _content_element["cache_control"]
user_content.append(_anthropic_content_element)
elif m.get("type", "") == "text":
m = cast(ChatCompletionTextObject, m)
@ -1466,9 +1466,9 @@ def anthropic_messages_pt( # noqa: PLR0915
)
if "cache_control" in _content_element:
_anthropic_content_text_element["cache_control"] = (
_content_element["cache_control"]
)
_anthropic_content_text_element[
"cache_control"
] = _content_element["cache_control"]
user_content.append(_anthropic_content_text_element)
@ -1533,7 +1533,6 @@ def anthropic_messages_pt( # noqa: PLR0915
"content"
] # don't pass empty text blocks. anthropic api raises errors.
):
_anthropic_text_content_element = AnthropicMessagesTextParam(
type="text",
text=assistant_content_block["content"],
@ -1569,7 +1568,6 @@ def anthropic_messages_pt( # noqa: PLR0915
msg_i += 1
if assistant_content:
new_messages.append({"role": "assistant", "content": assistant_content})
if msg_i == init_msg_i: # prevent infinite loops
@ -2245,7 +2243,6 @@ class BedrockImageProcessor:
@staticmethod
async def get_image_details_async(image_url) -> Tuple[str, str]:
try:
client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.PromptFactory,
params={"concurrent_limit": 1},
@ -2612,7 +2609,6 @@ def get_user_message_block_or_continue_message(
for item in modified_content_block:
# Check if the list is empty
if item["type"] == "text":
if not item["text"].strip():
# Replace empty text with continue message
_user_continue_message = ChatCompletionUserMessage(
@ -3207,7 +3203,6 @@ def _bedrock_converse_messages_pt( # noqa: PLR0915
assistant_content: List[BedrockContentBlock] = []
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
assistant_message_block = get_assistant_message_block_or_continue_message(
message=messages[msg_i],
assistant_continue_message=assistant_continue_message,
@ -3410,7 +3405,6 @@ def response_schema_prompt(model: str, response_schema: dict) -> str:
{"role": "user", "content": "{}".format(response_schema)}
]
if f"{model}/response_schema_prompt" in litellm.custom_prompt_dict:
custom_prompt_details = litellm.custom_prompt_dict[
f"{model}/response_schema_prompt"
] # allow user to define custom response schema prompt by model


@ -122,7 +122,6 @@ class RealTimeStreaming:
pass
async def bidirectional_forward(self):
forward_task = asyncio.create_task(self.backend_to_client_send_messages())
try:
await self.client_ack_messages()


@ -135,9 +135,9 @@ def _get_turn_off_message_logging_from_dynamic_params(
handles boolean and string values of `turn_off_message_logging`
"""
standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
model_call_details.get("standard_callback_dynamic_params", None)
)
standard_callback_dynamic_params: Optional[
StandardCallbackDynamicParams
] = model_call_details.get("standard_callback_dynamic_params", None)
if standard_callback_dynamic_params:
_turn_off_message_logging = standard_callback_dynamic_params.get(
"turn_off_message_logging"


@ -1,4 +1,5 @@
from typing import Any, Dict, Optional, Set
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
@ -40,7 +41,10 @@ class SensitiveDataMasker:
return result
def mask_dict(
self, data: Dict[str, Any], depth: int = 0, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH
self,
data: Dict[str, Any],
depth: int = 0,
max_depth: int = DEFAULT_MAX_RECURSE_DEPTH,
) -> Dict[str, Any]:
if depth >= max_depth:
return data


@ -104,7 +104,6 @@ class ChunkProcessor:
def get_combined_tool_content(
self, tool_call_chunks: List[Dict[str, Any]]
) -> List[ChatCompletionMessageToolCall]:
argument_list: List[str] = []
delta = tool_call_chunks[0]["choices"][0]["delta"]
id = None


@ -84,9 +84,9 @@ class CustomStreamWrapper:
self.system_fingerprint: Optional[str] = None
self.received_finish_reason: Optional[str] = None
self.intermittent_finish_reason: Optional[str] = (
None # finish reasons that show up mid-stream
)
self.intermittent_finish_reason: Optional[
str
] = None # finish reasons that show up mid-stream
self.special_tokens = [
"<|assistant|>",
"<|system|>",
@ -814,7 +814,6 @@ class CustomStreamWrapper:
model_response: ModelResponseStream,
response_obj: Dict[str, Any],
):
print_verbose(
f"completion_obj: {completion_obj}, model_response.choices[0]: {model_response.choices[0]}, response_obj: {response_obj}"
)
@ -1008,7 +1007,6 @@ class CustomStreamWrapper:
self.custom_llm_provider
and self.custom_llm_provider in litellm._custom_providers
):
if self.received_finish_reason is not None:
if "provider_specific_fields" not in chunk:
raise StopIteration
@ -1379,9 +1377,9 @@ class CustomStreamWrapper:
_json_delta = delta.model_dump()
print_verbose(f"_json_delta: {_json_delta}")
if "role" not in _json_delta or _json_delta["role"] is None:
_json_delta["role"] = (
"assistant" # mistral's api returns role as None
)
_json_delta[
"role"
] = "assistant" # mistral's api returns role as None
if "tool_calls" in _json_delta and isinstance(
_json_delta["tool_calls"], list
):
@ -1758,9 +1756,9 @@ class CustomStreamWrapper:
chunk = next(self.completion_stream)
if chunk is not None and chunk != b"":
print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}")
processed_chunk: Optional[ModelResponseStream] = (
self.chunk_creator(chunk=chunk)
)
processed_chunk: Optional[
ModelResponseStream
] = self.chunk_creator(chunk=chunk)
print_verbose(
f"PROCESSED CHUNK POST CHUNK CREATOR: {processed_chunk}"
)


@ -290,7 +290,6 @@ class AnthropicChatCompletion(BaseLLM):
headers={},
client=None,
):
optional_params = copy.deepcopy(optional_params)
stream = optional_params.pop("stream", None)
json_mode: bool = optional_params.pop("json_mode", False)
@ -491,7 +490,6 @@ class ModelResponseIterator:
def _handle_usage(
self, anthropic_usage_chunk: Union[dict, UsageDelta]
) -> AnthropicChatCompletionUsageBlock:
usage_block = AnthropicChatCompletionUsageBlock(
prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0),
completion_tokens=anthropic_usage_chunk.get("output_tokens", 0),
@ -515,7 +513,9 @@ class ModelResponseIterator:
return usage_block
def _content_block_delta_helper(self, chunk: dict) -> Tuple[
def _content_block_delta_helper(
self, chunk: dict
) -> Tuple[
str,
Optional[ChatCompletionToolCallChunk],
List[ChatCompletionThinkingBlock],
@ -592,9 +592,12 @@ class ModelResponseIterator:
Anthropic content chunk
chunk = {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Hello'}}
"""
text, tool_use, thinking_blocks, provider_specific_fields = (
self._content_block_delta_helper(chunk=chunk)
)
(
text,
tool_use,
thinking_blocks,
provider_specific_fields,
) = self._content_block_delta_helper(chunk=chunk)
if thinking_blocks:
reasoning_content = self._handle_reasoning_content(
thinking_blocks=thinking_blocks
@ -620,7 +623,6 @@ class ModelResponseIterator:
"index": self.tool_index,
}
elif type_chunk == "content_block_stop":
ContentBlockStop(**chunk) # type: ignore
# check if tool call content block
is_empty = self.check_empty_tool_call_args()


@ -49,9 +49,9 @@ class AnthropicConfig(BaseConfig):
to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
"""
max_tokens: Optional[int] = (
4096 # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default)
)
max_tokens: Optional[
int
] = 4096 # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default)
stop_sequences: Optional[list] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
@ -104,7 +104,6 @@ class AnthropicConfig(BaseConfig):
def get_json_schema_from_pydantic_object(
self, response_format: Union[Any, Dict, None]
) -> Optional[dict]:
return type_to_response_format_param(
response_format, ref_template="/$defs/{model}"
) # Relevant issue: https://github.com/BerriAI/litellm/issues/7755
@ -125,7 +124,6 @@ class AnthropicConfig(BaseConfig):
is_vertex_request: bool = False,
user_anthropic_beta_headers: Optional[List[str]] = None,
) -> dict:
betas = set()
if prompt_caching_set:
betas.add("prompt-caching-2024-07-31")
@ -300,7 +298,6 @@ class AnthropicConfig(BaseConfig):
model: str,
drop_params: bool,
) -> dict:
is_thinking_enabled = self.is_thinking_enabled(
non_default_params=non_default_params
)
@ -321,11 +318,11 @@ class AnthropicConfig(BaseConfig):
optional_params=optional_params, tools=tool_value
)
if param == "tool_choice" or param == "parallel_tool_calls":
_tool_choice: Optional[AnthropicMessagesToolChoice] = (
self._map_tool_choice(
tool_choice=non_default_params.get("tool_choice"),
parallel_tool_use=non_default_params.get("parallel_tool_calls"),
)
_tool_choice: Optional[
AnthropicMessagesToolChoice
] = self._map_tool_choice(
tool_choice=non_default_params.get("tool_choice"),
parallel_tool_use=non_default_params.get("parallel_tool_calls"),
)
if _tool_choice is not None:
@ -341,7 +338,6 @@ class AnthropicConfig(BaseConfig):
if param == "top_p":
optional_params["top_p"] = value
if param == "response_format" and isinstance(value, dict):
ignore_response_format_types = ["text"]
if value["type"] in ignore_response_format_types: # value is a no-op
continue
@ -470,9 +466,9 @@ class AnthropicConfig(BaseConfig):
text=system_message_block["content"],
)
if "cache_control" in system_message_block:
anthropic_system_message_content["cache_control"] = (
system_message_block["cache_control"]
)
anthropic_system_message_content[
"cache_control"
] = system_message_block["cache_control"]
anthropic_system_message_list.append(
anthropic_system_message_content
)
@ -486,9 +482,9 @@ class AnthropicConfig(BaseConfig):
)
)
if "cache_control" in _content:
anthropic_system_message_content["cache_control"] = (
_content["cache_control"]
)
anthropic_system_message_content[
"cache_control"
] = _content["cache_control"]
anthropic_system_message_list.append(
anthropic_system_message_content
@ -597,7 +593,9 @@ class AnthropicConfig(BaseConfig):
)
return _message
def extract_response_content(self, completion_response: dict) -> Tuple[
def extract_response_content(
self, completion_response: dict
) -> Tuple[
str,
Optional[List[Any]],
Optional[List[ChatCompletionThinkingBlock]],
@ -693,9 +691,13 @@ class AnthropicConfig(BaseConfig):
reasoning_content: Optional[str] = None
tool_calls: List[ChatCompletionToolCallChunk] = []
text_content, citations, thinking_blocks, reasoning_content, tool_calls = (
self.extract_response_content(completion_response=completion_response)
)
(
text_content,
citations,
thinking_blocks,
reasoning_content,
tool_calls,
) = self.extract_response_content(completion_response=completion_response)
_message = litellm.Message(
tool_calls=tool_calls,


@ -54,9 +54,9 @@ class AnthropicTextConfig(BaseConfig):
to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
"""
max_tokens_to_sample: Optional[int] = (
litellm.max_tokens
) # anthropic requires a default
max_tokens_to_sample: Optional[
int
] = litellm.max_tokens # anthropic requires a default
stop_sequences: Optional[list] = None
temperature: Optional[int] = None
top_p: Optional[int] = None


@ -25,7 +25,6 @@ from litellm.utils import ProviderConfigManager, client
class AnthropicMessagesHandler:
@staticmethod
async def _handle_anthropic_streaming(
response: httpx.Response,
@ -74,19 +73,22 @@ async def anthropic_messages(
"""
# Use provided client or create a new one
optional_params = GenericLiteLLMParams(**kwargs)
model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = (
litellm.get_llm_provider(
model=model,
custom_llm_provider=custom_llm_provider,
api_base=optional_params.api_base,
api_key=optional_params.api_key,
)
(
model,
_custom_llm_provider,
dynamic_api_key,
dynamic_api_base,
) = litellm.get_llm_provider(
model=model,
custom_llm_provider=custom_llm_provider,
api_base=optional_params.api_base,
api_key=optional_params.api_key,
)
anthropic_messages_provider_config: Optional[BaseAnthropicMessagesConfig] = (
ProviderConfigManager.get_provider_anthropic_messages_config(
model=model,
provider=litellm.LlmProviders(_custom_llm_provider),
)
anthropic_messages_provider_config: Optional[
BaseAnthropicMessagesConfig
] = ProviderConfigManager.get_provider_anthropic_messages_config(
model=model,
provider=litellm.LlmProviders(_custom_llm_provider),
)
if anthropic_messages_provider_config is None:
raise ValueError(


@ -654,7 +654,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
) -> EmbeddingResponse:
response = None
try:
openai_aclient = self.get_azure_openai_client(
api_version=api_version,
api_base=api_base,
@ -835,7 +834,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
"2023-10-01-preview",
]
): # CREATE + POLL for azure dall-e-2 calls
api_base = modify_url(
original_url=api_base, new_path="/openai/images/generations:submit"
)
@ -867,7 +865,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
)
while response.json()["status"] not in ["succeeded", "failed"]:
if time.time() - start_time > timeout_secs:
raise AzureOpenAIError(
status_code=408, message="Operation polling timed out."
)
@ -935,7 +932,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
"2023-10-01-preview",
]
): # CREATE + POLL for azure dall-e-2 calls
api_base = modify_url(
original_url=api_base, new_path="/openai/images/generations:submit"
)
@ -1199,7 +1195,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
client=None,
litellm_params: Optional[dict] = None,
) -> HttpxBinaryResponseContent:
max_retries = optional_params.pop("max_retries", 2)
if aspeech is not None and aspeech is True:
@ -1253,7 +1248,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
client=None,
litellm_params: Optional[dict] = None,
) -> HttpxBinaryResponseContent:
azure_client: AsyncAzureOpenAI = self.get_azure_openai_client(
api_base=api_base,
api_version=api_version,


@ -50,15 +50,15 @@ class AzureBatchesAPI(BaseAzureLLM):
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
litellm_params=litellm_params or {},
)
azure_client: Optional[
Union[AzureOpenAI, AsyncAzureOpenAI]
] = self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
litellm_params=litellm_params or {},
)
if azure_client is None:
raise ValueError(
@ -96,15 +96,15 @@ class AzureBatchesAPI(BaseAzureLLM):
client: Optional[AzureOpenAI] = None,
litellm_params: Optional[dict] = None,
):
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
litellm_params=litellm_params or {},
)
azure_client: Optional[
Union[AzureOpenAI, AsyncAzureOpenAI]
] = self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
litellm_params=litellm_params or {},
)
if azure_client is None:
raise ValueError(
@ -144,15 +144,15 @@ class AzureBatchesAPI(BaseAzureLLM):
client: Optional[AzureOpenAI] = None,
litellm_params: Optional[dict] = None,
):
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
litellm_params=litellm_params or {},
)
azure_client: Optional[
Union[AzureOpenAI, AsyncAzureOpenAI]
] = self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
litellm_params=litellm_params or {},
)
if azure_client is None:
raise ValueError(
@ -183,15 +183,15 @@ class AzureBatchesAPI(BaseAzureLLM):
client: Optional[AzureOpenAI] = None,
litellm_params: Optional[dict] = None,
):
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
litellm_params=litellm_params or {},
)
azure_client: Optional[
Union[AzureOpenAI, AsyncAzureOpenAI]
] = self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
litellm_params=litellm_params or {},
)
if azure_client is None:
raise ValueError(


@ -306,7 +306,6 @@ class BaseAzureLLM(BaseOpenAILLM):
api_version: Optional[str],
is_async: bool,
) -> dict:
azure_ad_token_provider: Optional[Callable[[], str]] = None
# If we have api_key, then we have higher priority
azure_ad_token = litellm_params.get("azure_ad_token")


@ -46,16 +46,15 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
)
openai_client: Optional[
Union[AzureOpenAI, AsyncAzureOpenAI]
] = self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
@ -95,15 +94,15 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
) -> Union[
HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
]:
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
)
openai_client: Optional[
Union[AzureOpenAI, AsyncAzureOpenAI]
] = self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
@ -145,15 +144,15 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
):
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
)
openai_client: Optional[
Union[AzureOpenAI, AsyncAzureOpenAI]
] = self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
@ -197,15 +196,15 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
):
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
)
openai_client: Optional[
Union[AzureOpenAI, AsyncAzureOpenAI]
] = self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(
@ -251,15 +250,15 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
):
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
)
openai_client: Optional[
Union[AzureOpenAI, AsyncAzureOpenAI]
] = self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
api_version=api_version,
client=client,
_is_async=_is_async,
)
if openai_client is None:
raise ValueError(


@ -25,14 +25,7 @@ class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI, BaseAzureLLM):
_is_async: bool = False,
api_version: Optional[str] = None,
litellm_params: Optional[dict] = None,
) -> Optional[
Union[
OpenAI,
AsyncOpenAI,
AzureOpenAI,
AsyncAzureOpenAI,
]
]:
) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI,]]:
# Override to use Azure-specific client initialization
if isinstance(client, OpenAI) or isinstance(client, AsyncOpenAI):
client = None


@ -145,7 +145,6 @@ class AzureAIStudioConfig(OpenAIConfig):
2. If message contains an image or audio, send as is (user-intended)
"""
for message in messages:
# Do nothing if the message contains an image or audio
if _audio_or_image_in_message_content(message):
continue


@ -22,7 +22,6 @@ class AzureAICohereConfig:
pass
def _map_azure_model_group(self, model: str) -> str:
if model == "offer-cohere-embed-multili-paygo":
return "Cohere-embed-v3-multilingual"
elif model == "offer-cohere-embed-english-paygo":


@ -17,7 +17,6 @@ from .cohere_transformation import AzureAICohereConfig
class AzureAIEmbedding(OpenAIChatCompletion):
def _process_response(
self,
image_embedding_responses: Optional[List],
@ -145,7 +144,6 @@ class AzureAIEmbedding(OpenAIChatCompletion):
api_base: Optional[str] = None,
client=None,
) -> EmbeddingResponse:
(
image_embeddings_request,
v1_embeddings_request,


@ -17,6 +17,7 @@ class AzureAIRerankConfig(CohereRerankConfig):
"""
Azure AI Rerank - Follows the same Spec as Cohere Rerank
"""
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
if api_base is None:
raise ValueError(


@ -9,7 +9,6 @@ from litellm.types.utils import ModelResponse, TextCompletionResponse
class BaseLLM:
_client_session: Optional[httpx.Client] = None
def process_response(


@ -218,7 +218,6 @@ class BaseConfig(ABC):
json_schema = value["json_schema"]["schema"]
if json_schema and not is_response_format_supported:
_tool_choice = ChatCompletionToolChoiceObjectParam(
type="function",
function=ChatCompletionToolChoiceFunctionParam(


@ -58,7 +58,6 @@ class BaseResponsesAPIConfig(ABC):
model: str,
drop_params: bool,
) -> Dict:
pass
@abstractmethod


@ -81,7 +81,6 @@ def make_sync_call(
class BedrockConverseLLM(BaseAWSLLM):
def __init__(self) -> None:
super().__init__()
@ -114,7 +113,6 @@ class BedrockConverseLLM(BaseAWSLLM):
fake_stream: bool = False,
json_mode: Optional[bool] = False,
) -> CustomStreamWrapper:
request_data = await litellm.AmazonConverseConfig()._async_transform_request(
model=model,
messages=messages,
@ -179,7 +177,6 @@ class BedrockConverseLLM(BaseAWSLLM):
headers: dict = {},
client: Optional[AsyncHTTPHandler] = None,
) -> Union[ModelResponse, CustomStreamWrapper]:
request_data = await litellm.AmazonConverseConfig()._async_transform_request(
model=model,
messages=messages,
@ -265,7 +262,6 @@ class BedrockConverseLLM(BaseAWSLLM):
extra_headers: Optional[dict] = None,
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
):
## SETUP ##
stream = optional_params.pop("stream", None)
unencoded_model_id = optional_params.pop("model_id", None)
@ -301,9 +297,9 @@ class BedrockConverseLLM(BaseAWSLLM):
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
optional_params.pop("aws_region_name", None)
litellm_params["aws_region_name"] = (
aws_region_name # [DO NOT DELETE] important for async calls
)
litellm_params[
"aws_region_name"
] = aws_region_name # [DO NOT DELETE] important for async calls
credentials: Credentials = self.get_credentials(
aws_access_key_id=aws_access_key_id,


@ -223,7 +223,6 @@ class AmazonConverseConfig(BaseConfig):
)
for param, value in non_default_params.items():
if param == "response_format" and isinstance(value, dict):
ignore_response_format_types = ["text"]
if value["type"] in ignore_response_format_types: # value is a no-op
continue
@ -715,9 +714,9 @@ class AmazonConverseConfig(BaseConfig):
chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"}
content_str = ""
tools: List[ChatCompletionToolCallChunk] = []
reasoningContentBlocks: Optional[List[BedrockConverseReasoningContentBlock]] = (
None
)
reasoningContentBlocks: Optional[
List[BedrockConverseReasoningContentBlock]
] = None
if message is not None:
for idx, content in enumerate(message["content"]):
@ -727,7 +726,6 @@ class AmazonConverseConfig(BaseConfig):
if "text" in content:
content_str += content["text"]
if "toolUse" in content:
## check tool name was formatted by litellm
_response_tool_name = content["toolUse"]["name"]
response_tool_name = get_bedrock_tool_name(
@ -754,12 +752,12 @@ class AmazonConverseConfig(BaseConfig):
chat_completion_message["provider_specific_fields"] = {
"reasoningContentBlocks": reasoningContentBlocks,
}
chat_completion_message["reasoning_content"] = (
self._transform_reasoning_content(reasoningContentBlocks)
)
chat_completion_message["thinking_blocks"] = (
self._transform_thinking_blocks(reasoningContentBlocks)
)
chat_completion_message[
"reasoning_content"
] = self._transform_reasoning_content(reasoningContentBlocks)
chat_completion_message[
"thinking_blocks"
] = self._transform_thinking_blocks(reasoningContentBlocks)
chat_completion_message["content"] = content_str
if json_mode is True and tools is not None and len(tools) == 1:
# to support 'json_schema' logic on bedrock models


@ -496,9 +496,9 @@ class BedrockLLM(BaseAWSLLM):
content=None,
)
model_response.choices[0].message = _message # type: ignore
model_response._hidden_params["original_response"] = (
outputText # allow user to access raw anthropic tool calling response
)
model_response._hidden_params[
"original_response"
] = outputText # allow user to access raw anthropic tool calling response
if (
_is_function_call is True
and stream is not None
@ -806,9 +806,9 @@ class BedrockLLM(BaseAWSLLM):
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
if stream is True:
inference_params["stream"] = (
True # cohere requires stream = True in inference params
)
inference_params[
"stream"
] = True # cohere requires stream = True in inference params
data = json.dumps({"prompt": prompt, **inference_params})
elif provider == "anthropic":
if model.startswith("anthropic.claude-3"):
@ -1205,7 +1205,6 @@ class BedrockLLM(BaseAWSLLM):
def get_response_stream_shape():
global _response_stream_shape_cache
if _response_stream_shape_cache is None:
from botocore.loaders import Loader
from botocore.model import ServiceModel
@ -1539,7 +1538,6 @@ class AmazonDeepSeekR1StreamDecoder(AWSEventStreamDecoder):
model: str,
sync_stream: bool,
) -> None:
super().__init__(model=model)
from litellm.llms.bedrock.chat.invoke_transformations.amazon_deepseek_transformation import (
AmazonDeepseekR1ResponseIterator,


@ -225,9 +225,9 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
if stream is True:
inference_params["stream"] = (
True # cohere requires stream = True in inference params
)
inference_params[
"stream"
] = True # cohere requires stream = True in inference params
request_data = {"prompt": prompt, **inference_params}
elif provider == "anthropic":
return litellm.AmazonAnthropicClaude3Config().transform_request(
@ -311,7 +311,6 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
try:
completion_response = raw_response.json()
except Exception:


@ -314,7 +314,6 @@ def get_bedrock_tool_name(response_tool_name: str) -> str:
class BedrockModelInfo(BaseLLMModelInfo):
global_config = AmazonBedrockGlobalConfig()
all_global_regions = global_config.get_all_regions()


@ -33,9 +33,9 @@ class AmazonTitanMultimodalEmbeddingG1Config:
) -> dict:
for k, v in non_default_params.items():
if k == "dimensions":
optional_params["embeddingConfig"] = (
AmazonTitanMultimodalEmbeddingConfig(outputEmbeddingLength=v)
)
optional_params[
"embeddingConfig"
] = AmazonTitanMultimodalEmbeddingConfig(outputEmbeddingLength=v)
return optional_params
def _transform_request(
@ -58,7 +58,6 @@ class AmazonTitanMultimodalEmbeddingG1Config:
def _transform_response(
self, response_list: List[dict], model: str
) -> EmbeddingResponse:
total_prompt_tokens = 0
transformed_responses: List[Embedding] = []
for index, response in enumerate(response_list):


@ -1,12 +1,16 @@
import types
from typing import List, Optional
from typing import Any, Dict, List, Optional
from openai.types.image import Image
from litellm.types.llms.bedrock import (
AmazonNovaCanvasTextToImageRequest, AmazonNovaCanvasTextToImageResponse,
AmazonNovaCanvasTextToImageParams, AmazonNovaCanvasRequestBase, AmazonNovaCanvasColorGuidedGenerationParams,
AmazonNovaCanvasColorGuidedGenerationParams,
AmazonNovaCanvasColorGuidedRequest,
AmazonNovaCanvasImageGenerationConfig,
AmazonNovaCanvasRequestBase,
AmazonNovaCanvasTextToImageParams,
AmazonNovaCanvasTextToImageRequest,
AmazonNovaCanvasTextToImageResponse,
)
from litellm.types.utils import ImageResponse
@ -23,7 +27,7 @@ class AmazonNovaCanvasConfig:
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
and not isinstance(
v,
(
types.FunctionType,
@ -32,13 +36,12 @@ class AmazonNovaCanvasConfig:
staticmethod,
),
)
and v is not None
and v is not None
}
@classmethod
def get_supported_openai_params(cls, model: Optional[str] = None) -> List:
"""
"""
""" """
return ["n", "size", "quality"]
@classmethod
@ -56,7 +59,7 @@ class AmazonNovaCanvasConfig:
@classmethod
def transform_request_body(
cls, text: str, optional_params: dict
cls, text: str, optional_params: dict
) -> AmazonNovaCanvasRequestBase:
"""
Transform the request body for Amazon Nova Canvas model
@ -65,18 +68,64 @@ class AmazonNovaCanvasConfig:
image_generation_config = optional_params.pop("imageGenerationConfig", {})
image_generation_config = {**image_generation_config, **optional_params}
if task_type == "TEXT_IMAGE":
text_to_image_params = image_generation_config.pop("textToImageParams", {})
text_to_image_params = {"text" :text, **text_to_image_params}
text_to_image_params = AmazonNovaCanvasTextToImageParams(**text_to_image_params)
return AmazonNovaCanvasTextToImageRequest(textToImageParams=text_to_image_params, taskType=task_type,
imageGenerationConfig=image_generation_config)
text_to_image_params: Dict[str, Any] = image_generation_config.pop(
"textToImageParams", {}
)
text_to_image_params = {"text": text, **text_to_image_params}
try:
text_to_image_params_typed = AmazonNovaCanvasTextToImageParams(
**text_to_image_params # type: ignore
)
except Exception as e:
raise ValueError(
f"Error transforming text to image params: {e}. Got params: {text_to_image_params}, Expected params: {AmazonNovaCanvasTextToImageParams.__annotations__}"
)
try:
image_generation_config_typed = AmazonNovaCanvasImageGenerationConfig(
**image_generation_config
)
except Exception as e:
raise ValueError(
f"Error transforming image generation config: {e}. Got params: {image_generation_config}, Expected params: {AmazonNovaCanvasImageGenerationConfig.__annotations__}"
)
return AmazonNovaCanvasTextToImageRequest(
textToImageParams=text_to_image_params_typed,
taskType=task_type,
imageGenerationConfig=image_generation_config_typed,
)
if task_type == "COLOR_GUIDED_GENERATION":
color_guided_generation_params = image_generation_config.pop("colorGuidedGenerationParams", {})
color_guided_generation_params = {"text": text, **color_guided_generation_params}
color_guided_generation_params = AmazonNovaCanvasColorGuidedGenerationParams(**color_guided_generation_params)
return AmazonNovaCanvasColorGuidedRequest(taskType=task_type,
colorGuidedGenerationParams=color_guided_generation_params,
imageGenerationConfig=image_generation_config)
color_guided_generation_params: Dict[
str, Any
] = image_generation_config.pop("colorGuidedGenerationParams", {})
color_guided_generation_params = {
"text": text,
**color_guided_generation_params,
}
try:
color_guided_generation_params_typed = AmazonNovaCanvasColorGuidedGenerationParams(
**color_guided_generation_params # type: ignore
)
except Exception as e:
raise ValueError(
f"Error transforming color guided generation params: {e}. Got params: {color_guided_generation_params}, Expected params: {AmazonNovaCanvasColorGuidedGenerationParams.__annotations__}"
)
try:
image_generation_config_typed = AmazonNovaCanvasImageGenerationConfig(
**image_generation_config
)
except Exception as e:
raise ValueError(
f"Error transforming image generation config: {e}. Got params: {image_generation_config}, Expected params: {AmazonNovaCanvasImageGenerationConfig.__annotations__}"
)
return AmazonNovaCanvasColorGuidedRequest(
taskType=task_type,
colorGuidedGenerationParams=color_guided_generation_params_typed,
imageGenerationConfig=image_generation_config_typed,
)
raise NotImplementedError(f"Task type {task_type} is not supported")
@classmethod
@ -87,7 +136,9 @@ class AmazonNovaCanvasConfig:
_size = non_default_params.get("size")
if _size is not None:
width, height = _size.split("x")
optional_params["width"], optional_params["height"] = int(width), int(height)
optional_params["width"], optional_params["height"] = int(width), int(
height
)
if non_default_params.get("n") is not None:
optional_params["numberOfImages"] = non_default_params.get("n")
if non_default_params.get("quality") is not None:
@ -99,7 +150,7 @@ class AmazonNovaCanvasConfig:
@classmethod
def transform_response_dict_to_openai_response(
cls, model_response: ImageResponse, response_dict: dict
cls, model_response: ImageResponse, response_dict: dict
) -> ImageResponse:
"""
Transform the response dict to the OpenAI response


@ -267,7 +267,11 @@ class BedrockImageGeneration(BaseAWSLLM):
**inference_params,
}
elif provider == "amazon":
return dict(litellm.AmazonNovaCanvasConfig.transform_request_body(text=prompt, optional_params=optional_params))
return dict(
litellm.AmazonNovaCanvasConfig.transform_request_body(
text=prompt, optional_params=optional_params
)
)
else:
raise BedrockError(
status_code=422, message=f"Unsupported model={model}, passed in"
@ -303,8 +307,11 @@ class BedrockImageGeneration(BaseAWSLLM):
config_class = (
litellm.AmazonStability3Config
if litellm.AmazonStability3Config._is_stability_3_model(model=model)
else litellm.AmazonNovaCanvasConfig if litellm.AmazonNovaCanvasConfig._is_nova_model(model=model)
else litellm.AmazonStabilityConfig
else (
litellm.AmazonNovaCanvasConfig
if litellm.AmazonNovaCanvasConfig._is_nova_model(model=model)
else litellm.AmazonStabilityConfig
)
)
config_class.transform_response_dict_to_openai_response(
model_response=model_response,


@ -60,7 +60,6 @@ class BedrockRerankHandler(BaseAWSLLM):
extra_headers: Optional[dict] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
) -> RerankResponse:
request_data = RerankRequest(
model=model,
query=query,


@ -29,7 +29,6 @@ from litellm.types.rerank import (
class BedrockRerankConfig:
def _transform_sources(
self, documents: List[Union[str, dict]]
) -> List[BedrockRerankSource]:


@ -314,7 +314,6 @@ class CodestralTextCompletion:
return _response
### SYNC COMPLETION
else:
response = litellm.module_level_client.post(
url=completion_url,
headers=headers,
@ -352,13 +351,11 @@ class CodestralTextCompletion:
logger_fn=None,
headers={},
) -> TextCompletionResponse:
async_handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.TEXT_COMPLETION_CODESTRAL,
params={"timeout": timeout},
)
try:
response = await async_handler.post(
api_base, headers=headers, data=json.dumps(data)
)


@ -78,7 +78,6 @@ class CodestralTextCompletionConfig(OpenAITextCompletionConfig):
return optional_params
def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:
text = ""
is_finished = False
finish_reason = None


@ -180,7 +180,6 @@ class CohereChatConfig(BaseConfig):
litellm_params: dict,
headers: dict,
) -> dict:
## Load Config
for k, v in litellm.CohereChatConfig.get_config().items():
if (
@ -222,7 +221,6 @@ class CohereChatConfig(BaseConfig):
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
try:
raw_response_json = raw_response.json()
model_response.choices[0].message.content = raw_response_json["text"] # type: ignore


@ -56,7 +56,6 @@ async def async_embedding(
encoding: Callable,
client: Optional[AsyncHTTPHandler] = None,
):
## LOGGING
logging_obj.pre_call(
input=input,


@ -72,7 +72,6 @@ class CohereEmbeddingConfig:
return transformed_request
def _calculate_usage(self, input: List[str], encoding: Any, meta: dict) -> Usage:
input_tokens = 0
text_tokens: Optional[int] = meta.get("billed_units", {}).get("input_tokens")
@ -111,7 +110,6 @@ class CohereEmbeddingConfig:
encoding: Any,
input: list,
) -> EmbeddingResponse:
response_json = response.json()
## LOGGING
logging_obj.post_call(


@ -148,4 +148,4 @@ class CohereRerankConfig(BaseRerankConfig):
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return CohereError(message=error_message, status_code=status_code)
return CohereError(message=error_message, status_code=status_code)


@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Union
from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
from litellm.types.rerank import OptionalRerankParams, RerankRequest
class CohereRerankV2Config(CohereRerankConfig):
"""
Reference: https://docs.cohere.com/v2/reference/rerank
@ -77,4 +78,4 @@ class CohereRerankV2Config(CohereRerankConfig):
return_documents=optional_rerank_params.get("return_documents", None),
max_tokens_per_doc=optional_rerank_params.get("max_tokens_per_doc", None),
)
return rerank_request.model_dump(exclude_none=True)
return rerank_request.model_dump(exclude_none=True)


@ -32,7 +32,6 @@ DEFAULT_TIMEOUT = 600
class BaseLLMAIOHTTPHandler:
def __init__(self):
self.client_session: Optional[aiohttp.ClientSession] = None
@ -110,7 +109,6 @@ class BaseLLMAIOHTTPHandler:
content: Any = None,
params: Optional[dict] = None,
) -> httpx.Response:
max_retry_on_unprocessable_entity_error = (
provider_config.max_retry_on_unprocessable_entity_error
)


@ -114,7 +114,6 @@ class AsyncHTTPHandler:
event_hooks: Optional[Mapping[str, List[Callable[..., Any]]]],
ssl_verify: Optional[VerifyTypes] = None,
) -> httpx.AsyncClient:
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
# /path/to/certificate.pem
if ssl_verify is None:
@ -590,7 +589,6 @@ class HTTPHandler:
timeout: Optional[Union[float, httpx.Timeout]] = None,
):
try:
if timeout is not None:
req = self.client.build_request(
"PATCH", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore
@ -609,7 +607,6 @@ class HTTPHandler:
llm_provider="litellm-httpx-handler",
)
except httpx.HTTPStatusError as e:
if stream is True:
setattr(e, "message", mask_sensitive_info(e.response.read()))
setattr(e, "text", mask_sensitive_info(e.response.read()))
@ -635,7 +632,6 @@ class HTTPHandler:
timeout: Optional[Union[float, httpx.Timeout]] = None,
):
try:
if timeout is not None:
req = self.client.build_request(
"PUT", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore


@ -41,7 +41,6 @@ else:
class BaseLLMHTTPHandler:
async def _make_common_async_call(
self,
async_httpx_client: AsyncHTTPHandler,
@ -109,7 +108,6 @@ class BaseLLMHTTPHandler:
logging_obj: LiteLLMLoggingObj,
stream: bool = False,
) -> httpx.Response:
max_retry_on_unprocessable_entity_error = (
provider_config.max_retry_on_unprocessable_entity_error
)
@ -599,7 +597,6 @@ class BaseLLMHTTPHandler:
aembedding: bool = False,
headers={},
) -> EmbeddingResponse:
provider_config = ProviderConfigManager.get_provider_embedding_config(
model=model, provider=litellm.LlmProviders(custom_llm_provider)
)
@ -742,7 +739,6 @@ class BaseLLMHTTPHandler:
api_base: Optional[str] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
) -> RerankResponse:
# get config from model, custom llm provider
headers = provider_config.validate_environment(
api_key=api_key,
@ -828,7 +824,6 @@ class BaseLLMHTTPHandler:
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
) -> RerankResponse:
if client is None or not isinstance(client, AsyncHTTPHandler):
async_httpx_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders(custom_llm_provider)


@ -16,9 +16,9 @@ class DatabricksBase:
api_base = api_base or f"{databricks_client.config.host}/serving-endpoints"
if api_key is None:
databricks_auth_headers: dict[str, str] = (
databricks_client.config.authenticate()
)
databricks_auth_headers: dict[
str, str
] = databricks_client.config.authenticate()
headers = {**databricks_auth_headers, **headers}
return api_base, headers


@ -11,9 +11,9 @@ class DatabricksEmbeddingConfig:
Reference: https://learn.microsoft.com/en-us/azure/databricks/machine-learning/foundation-models/api-reference#--embedding-task
"""
instruction: Optional[str] = (
None # An optional instruction to pass to the embedding model. BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries
)
instruction: Optional[
str
] = None # An optional instruction to pass to the embedding model. BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries
def __init__(self, instruction: Optional[str] = None) -> None:
locals_ = locals().copy()


@ -55,7 +55,6 @@ class ModelResponseIterator:
usage_chunk: Optional[Usage] = getattr(processed_chunk, "usage", None)
if usage_chunk is not None:
usage = ChatCompletionUsageBlock(
prompt_tokens=usage_chunk.prompt_tokens,
completion_tokens=usage_chunk.completion_tokens,


@ -126,9 +126,9 @@ class DeepgramAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
# Add additional metadata matching OpenAI format
response["task"] = "transcribe"
response["language"] = (
"english" # Deepgram auto-detects but doesn't return language
)
response[
"language"
] = "english" # Deepgram auto-detects but doesn't return language
response["duration"] = response_json["metadata"]["duration"]
# Transform words to match OpenAI format

Some files were not shown because too many files have changed in this diff.