Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)
build(pyproject.toml): add new dev dependencies - for type checking (#9631)
* build(pyproject.toml): add new dev dependencies - for type checking
* build: reformat files to fit black
* ci: reformat to fit black
* ci(test-litellm.yml): make tests run clear
* build(pyproject.toml): add ruff
* fix: fix ruff checks
* build(mypy/): fix mypy linting errors
* fix(hashicorp_secret_manager.py): fix passing cert for tls auth
* build(mypy/): resolve all mypy errors
* test: update test
* fix: fix black formatting
* build(pre-commit-config.yaml): use poetry run black
* fix(proxy_server.py): fix linting error
* fix: fix ruff safe representation error
parent: 95e5dfae5a
commit: 9b7ebb6a7d
214 changed files with 1553 additions and 1433 deletions
53  .github/workflows/test-linting.yml  (vendored, new file)
@@ -0,0 +1,53 @@
name: LiteLLM Linting

on:
  pull_request:
    branches: [ main ]

jobs:
  lint:
    runs-on: ubuntu-latest
    timeout-minutes: 5

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.12'

      - name: Install Poetry
        uses: snok/install-poetry@v1

      - name: Install dependencies
        run: |
          poetry install --with dev

      - name: Run Black formatting check
        run: |
          cd litellm
          poetry run black . --check
          cd ..

      - name: Run Ruff linting
        run: |
          cd litellm
          poetry run ruff check .
          cd ..

      - name: Run MyPy type checking
        run: |
          cd litellm
          poetry run mypy . --ignore-missing-imports
          cd ..

      - name: Check for circular imports
        run: |
          cd litellm
          poetry run python ../tests/documentation_tests/test_circular_imports.py
          cd ..

      - name: Check import safety
        run: |
          poetry run python -c "from litellm import *" || (echo '🚨 import failed, this means you introduced unprotected imports! 🚨'; exit 1)
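The import-safety step above fails the build if "from litellm import *" raises, that is, if a module-level import of an optional dependency is left unprotected. As a rough illustration of the pattern such a check enforces (the module below and the use of orjson as the optional dependency are made-up examples, not code from this commit):

# Illustrative module showing a "protected" optional import.
# If the optional package is missing, importing this module still succeeds,
# so a CI check like `python -c "from mypackage import *"` keeps passing.
from typing import Any

try:
    import orjson  # optional speedup, may not be installed
except ImportError:
    orjson = None  # degrade gracefully instead of failing at import time


def dumps(obj: Any) -> str:
    """Serialize obj, preferring the optional fast path when available."""
    if orjson is not None:
        return orjson.dumps(obj).decode("utf-8")
    import json

    return json.dumps(obj)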
|
2  .github/workflows/test-litellm.yml  (vendored)
|
@@ -1,4 +1,4 @@
-name: LiteLLM Tests
+name: LiteLLM Mock Tests (folder - tests/litellm)

 on:
   pull_request:
|
||||
|
|
|
.pre-commit-config.yaml
@@ -14,10 +14,12 @@ repos:
         types: [python]
         files: litellm/.*\.py
         exclude: ^litellm/__init__.py$
-  - repo: https://github.com/psf/black
-    rev: 24.2.0
-    hooks:
-      - id: black
+      - id: black
+        name: black
+        entry: poetry run black
+        language: system
+        types: [python]
+        files: litellm/.*\.py
   - repo: https://github.com/pycqa/flake8
     rev: 7.0.0 # The version of flake8 to use
     hooks:
|
||||
|
|
|
@@ -444,9 +444,7 @@ class _ENTERPRISE_SecretDetection(CustomGuardrail):

         detected_secrets = []
         for file in secrets.files:
-
             for found_secret in secrets[file]:
-
                 if found_secret.secret_value is None:
                     continue
                 detected_secrets.append(
|
||||
|
@@ -471,14 +469,12 @@ class _ENTERPRISE_SecretDetection(CustomGuardrail):
         data: dict,
         call_type: str,  # "completion", "embeddings", "image_generation", "moderation"
     ):
-
         if await self.should_run_check(user_api_key_dict) is False:
             return

         if "messages" in data and isinstance(data["messages"], list):
             for message in data["messages"]:
                 if "content" in message and isinstance(message["content"], str):
-
                     detected_secrets = self.scan_message_for_secrets(message["content"])

                     for secret in detected_secrets:
|
||||
|
|
|
@ -122,19 +122,19 @@ langsmith_batch_size: Optional[int] = None
|
|||
prometheus_initialize_budget_metrics: Optional[bool] = False
|
||||
argilla_batch_size: Optional[int] = None
|
||||
datadog_use_v1: Optional[bool] = False # if you want to use v1 datadog logged payload
|
||||
gcs_pub_sub_use_v1: Optional[bool] = (
|
||||
False # if you want to use v1 gcs pubsub logged payload
|
||||
)
|
||||
gcs_pub_sub_use_v1: Optional[
|
||||
bool
|
||||
] = False # if you want to use v1 gcs pubsub logged payload
|
||||
argilla_transformation_object: Optional[Dict[str, Any]] = None
|
||||
_async_input_callback: List[Union[str, Callable, CustomLogger]] = (
|
||||
[]
|
||||
) # internal variable - async custom callbacks are routed here.
|
||||
_async_success_callback: List[Union[str, Callable, CustomLogger]] = (
|
||||
[]
|
||||
) # internal variable - async custom callbacks are routed here.
|
||||
_async_failure_callback: List[Union[str, Callable, CustomLogger]] = (
|
||||
[]
|
||||
) # internal variable - async custom callbacks are routed here.
|
||||
_async_input_callback: List[
|
||||
Union[str, Callable, CustomLogger]
|
||||
] = [] # internal variable - async custom callbacks are routed here.
|
||||
_async_success_callback: List[
|
||||
Union[str, Callable, CustomLogger]
|
||||
] = [] # internal variable - async custom callbacks are routed here.
|
||||
_async_failure_callback: List[
|
||||
Union[str, Callable, CustomLogger]
|
||||
] = [] # internal variable - async custom callbacks are routed here.
|
||||
pre_call_rules: List[Callable] = []
|
||||
post_call_rules: List[Callable] = []
|
||||
turn_off_message_logging: Optional[bool] = False
|
||||
|
@ -142,18 +142,18 @@ log_raw_request_response: bool = False
|
|||
redact_messages_in_exceptions: Optional[bool] = False
|
||||
redact_user_api_key_info: Optional[bool] = False
|
||||
filter_invalid_headers: Optional[bool] = False
|
||||
add_user_information_to_llm_headers: Optional[bool] = (
|
||||
None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
|
||||
)
|
||||
add_user_information_to_llm_headers: Optional[
|
||||
bool
|
||||
] = None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
|
||||
store_audit_logs = False # Enterprise feature, allow users to see audit logs
|
||||
### end of callbacks #############
|
||||
|
||||
email: Optional[str] = (
|
||||
None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||
)
|
||||
token: Optional[str] = (
|
||||
None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||
)
|
||||
email: Optional[
|
||||
str
|
||||
] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||
token: Optional[
|
||||
str
|
||||
] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||
telemetry = True
|
||||
max_tokens = 256 # OpenAI Defaults
|
||||
drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False))
|
||||
|
@ -229,24 +229,20 @@ enable_loadbalancing_on_batch_endpoints: Optional[bool] = None
|
|||
enable_caching_on_provider_specific_optional_params: bool = (
|
||||
False # feature-flag for caching on optional params - e.g. 'top_k'
|
||||
)
|
||||
caching: bool = (
|
||||
False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||
)
|
||||
caching_with_models: bool = (
|
||||
False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||
)
|
||||
cache: Optional[Cache] = (
|
||||
None # cache object <- use this - https://docs.litellm.ai/docs/caching
|
||||
)
|
||||
caching: bool = False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||
caching_with_models: bool = False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
||||
cache: Optional[
|
||||
Cache
|
||||
] = None # cache object <- use this - https://docs.litellm.ai/docs/caching
|
||||
default_in_memory_ttl: Optional[float] = None
|
||||
default_redis_ttl: Optional[float] = None
|
||||
default_redis_batch_cache_expiry: Optional[float] = None
|
||||
model_alias_map: Dict[str, str] = {}
|
||||
model_group_alias_map: Dict[str, str] = {}
|
||||
max_budget: float = 0.0 # set the max budget across all providers
|
||||
budget_duration: Optional[str] = (
|
||||
None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
|
||||
)
|
||||
budget_duration: Optional[
|
||||
str
|
||||
] = None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
|
||||
default_soft_budget: float = (
|
||||
50.0 # by default all litellm proxy keys have a soft budget of 50.0
|
||||
)
|
||||
|
@ -255,15 +251,11 @@ forward_traceparent_to_llm_provider: bool = False
|
|||
|
||||
_current_cost = 0.0 # private variable, used if max budget is set
|
||||
error_logs: Dict = {}
|
||||
add_function_to_prompt: bool = (
|
||||
False # if function calling not supported by api, append function call details to system prompt
|
||||
)
|
||||
add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt
|
||||
client_session: Optional[httpx.Client] = None
|
||||
aclient_session: Optional[httpx.AsyncClient] = None
|
||||
model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks'
|
||||
model_cost_map_url: str = (
|
||||
"https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
||||
)
|
||||
model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
||||
suppress_debug_info = False
|
||||
dynamodb_table_name: Optional[str] = None
|
||||
s3_callback_params: Optional[Dict] = None
|
||||
|
@ -285,9 +277,7 @@ disable_end_user_cost_tracking_prometheus_only: Optional[bool] = None
|
|||
custom_prometheus_metadata_labels: List[str] = []
|
||||
#### REQUEST PRIORITIZATION ####
|
||||
priority_reservation: Optional[Dict[str, float]] = None
|
||||
force_ipv4: bool = (
|
||||
False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
|
||||
)
|
||||
force_ipv4: bool = False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
|
||||
module_level_aclient = AsyncHTTPHandler(
|
||||
timeout=request_timeout, client_alias="module level aclient"
|
||||
)
|
||||
|
@ -301,13 +291,13 @@ fallbacks: Optional[List] = None
|
|||
context_window_fallbacks: Optional[List] = None
|
||||
content_policy_fallbacks: Optional[List] = None
|
||||
allowed_fails: int = 3
|
||||
num_retries_per_request: Optional[int] = (
|
||||
None # for the request overall (incl. fallbacks + model retries)
|
||||
)
|
||||
num_retries_per_request: Optional[
|
||||
int
|
||||
] = None # for the request overall (incl. fallbacks + model retries)
|
||||
####### SECRET MANAGERS #####################
|
||||
secret_manager_client: Optional[Any] = (
|
||||
None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
|
||||
)
|
||||
secret_manager_client: Optional[
|
||||
Any
|
||||
] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
|
||||
_google_kms_resource_name: Optional[str] = None
|
||||
_key_management_system: Optional[KeyManagementSystem] = None
|
||||
_key_management_settings: KeyManagementSettings = KeyManagementSettings()
|
||||
|
@ -1056,10 +1046,10 @@ from .types.llms.custom_llm import CustomLLMItem
|
|||
from .types.utils import GenericStreamingChunk
|
||||
|
||||
custom_provider_map: List[CustomLLMItem] = []
|
||||
_custom_providers: List[str] = (
|
||||
[]
|
||||
) # internal helper util, used to track names of custom providers
|
||||
disable_hf_tokenizer_download: Optional[bool] = (
|
||||
None # disable huggingface tokenizer download. Defaults to openai clk100
|
||||
)
|
||||
_custom_providers: List[
|
||||
str
|
||||
] = [] # internal helper util, used to track names of custom providers
|
||||
disable_hf_tokenizer_download: Optional[
|
||||
bool
|
||||
] = None # disable huggingface tokenizer download. Defaults to openai clk100
|
||||
global_disable_no_log_param: bool = False
|
||||
|
|
|
@@ -15,7 +15,7 @@ from .types.services import ServiceLoggerPayload, ServiceTypes
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span

-    Span = _Span
+    Span = Union[_Span, Any]
     OTELClass = OpenTelemetry
 else:
     Span = Any
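The Span = Union[_Span, Any] change, repeated across many of these files, keeps the alias that mypy sees as permissive as the runtime fallback: type checkers only analyze the TYPE_CHECKING branch, while at runtime Span degrades to Any when opentelemetry is not installed. A self-contained sketch of the pattern (the record_latency helper is illustrative, not part of litellm):

# Minimal sketch of the TYPE_CHECKING alias pattern, assuming opentelemetry
# is an optional dependency. mypy reads only the first branch, so widening it
# to Union[_Span, Any] matches what the runtime alias actually allows.
from typing import TYPE_CHECKING, Any, Optional, Union

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    Span = Union[_Span, Any]
else:
    Span = Any  # opentelemetry may be absent at runtime


def record_latency(span: Optional[Span], millis: float) -> None:
    """Attach a latency attribute when a span was provided."""
    if span is not None:
        span.set_attribute("latency_ms", millis)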
|
||||
|
|
|
@ -153,7 +153,6 @@ def create_batch(
|
|||
)
|
||||
api_base: Optional[str] = None
|
||||
if custom_llm_provider == "openai":
|
||||
|
||||
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
api_base = (
|
||||
optional_params.api_base
|
||||
|
@ -358,7 +357,6 @@ def retrieve_batch(
|
|||
_is_async = kwargs.pop("aretrieve_batch", False) is True
|
||||
api_base: Optional[str] = None
|
||||
if custom_llm_provider == "openai":
|
||||
|
||||
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
api_base = (
|
||||
optional_params.api_base
|
||||
|
|
|
@ -9,12 +9,12 @@ Has 4 methods:
|
|||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = _Span
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
|
|
|
@ -66,9 +66,7 @@ class CachingHandlerResponse(BaseModel):
|
|||
|
||||
cached_result: Optional[Any] = None
|
||||
final_embedding_cached_response: Optional[EmbeddingResponse] = None
|
||||
embedding_all_elements_cache_hit: bool = (
|
||||
False # this is set to True when all elements in the list have a cache hit in the embedding cache, if true return the final_embedding_cached_response no need to make an API call
|
||||
)
|
||||
embedding_all_elements_cache_hit: bool = False # this is set to True when all elements in the list have a cache hit in the embedding cache, if true return the final_embedding_cached_response no need to make an API call
|
||||
|
||||
|
||||
class LLMCachingHandler:
|
||||
|
@ -738,7 +736,6 @@ class LLMCachingHandler:
|
|||
if self._should_store_result_in_cache(
|
||||
original_function=self.original_function, kwargs=new_kwargs
|
||||
):
|
||||
|
||||
litellm.cache.add_cache(result, **new_kwargs)
|
||||
|
||||
return
|
||||
|
@ -865,9 +862,9 @@ class LLMCachingHandler:
|
|||
}
|
||||
|
||||
if litellm.cache is not None:
|
||||
litellm_params["preset_cache_key"] = (
|
||||
litellm.cache._get_preset_cache_key_from_kwargs(**kwargs)
|
||||
)
|
||||
litellm_params[
|
||||
"preset_cache_key"
|
||||
] = litellm.cache._get_preset_cache_key_from_kwargs(**kwargs)
|
||||
else:
|
||||
litellm_params["preset_cache_key"] = None
|
||||
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
import json
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from .base_cache import BaseCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = _Span
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ import asyncio
|
|||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose, verbose_logger
|
||||
|
@ -24,7 +24,7 @@ from .redis_cache import RedisCache
|
|||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = _Span
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
|
|
|
@ -8,7 +8,6 @@ from .in_memory_cache import InMemoryCache
|
|||
|
||||
|
||||
class LLMClientCache(InMemoryCache):
|
||||
|
||||
def update_cache_key_with_event_loop(self, key):
|
||||
"""
|
||||
Add the event loop to the cache key, to prevent event loop closed errors.
|
||||
|
|
|
@ -34,7 +34,7 @@ if TYPE_CHECKING:
|
|||
cluster_pipeline = ClusterPipeline
|
||||
async_redis_client = Redis
|
||||
async_redis_cluster_client = RedisCluster
|
||||
Span = _Span
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
pipeline = Any
|
||||
cluster_pipeline = Any
|
||||
|
@ -57,7 +57,6 @@ class RedisCache(BaseCache):
|
|||
socket_timeout: Optional[float] = 5.0, # default 5 second timeout
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
from litellm._service_logger import ServiceLogging
|
||||
|
||||
from .._redis import get_redis_client, get_redis_connection_pool
|
||||
|
|
|
@ -5,7 +5,7 @@ Key differences:
|
|||
- RedisClient NEEDs to be re-used across requests, adds 3000ms latency if it's re-created
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
from litellm.caching.redis_cache import RedisCache
|
||||
|
||||
|
@ -16,7 +16,7 @@ if TYPE_CHECKING:
|
|||
|
||||
pipeline = Pipeline
|
||||
async_redis_client = Redis
|
||||
Span = _Span
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
pipeline = Any
|
||||
async_redis_client = Any
|
||||
|
|
|
@ -13,23 +13,27 @@ import ast
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import get_str_from_messages
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
get_str_from_messages,
|
||||
)
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
|
||||
from .base_cache import BaseCache
|
||||
|
||||
|
||||
class RedisSemanticCache(BaseCache):
|
||||
"""
|
||||
Redis-backed semantic cache for LLM responses.
|
||||
|
||||
This cache uses vector similarity to find semantically similar prompts that have been
|
||||
Redis-backed semantic cache for LLM responses.
|
||||
|
||||
This cache uses vector similarity to find semantically similar prompts that have been
|
||||
previously sent to the LLM, allowing for cache hits even when prompts are not identical
|
||||
but carry similar meaning.
|
||||
"""
|
||||
|
||||
|
||||
DEFAULT_REDIS_INDEX_NAME: str = "litellm_semantic_cache_index"
|
||||
|
||||
def __init__(
|
||||
|
@ -57,7 +61,7 @@ class RedisSemanticCache(BaseCache):
|
|||
index_name: Name for the Redis index
|
||||
ttl: Default time-to-live for cache entries in seconds
|
||||
**kwargs: Additional arguments passed to the Redis client
|
||||
|
||||
|
||||
Raises:
|
||||
Exception: If similarity_threshold is not provided or required Redis
|
||||
connection information is missing
|
||||
|
@ -69,14 +73,14 @@ class RedisSemanticCache(BaseCache):
|
|||
index_name = self.DEFAULT_REDIS_INDEX_NAME
|
||||
|
||||
print_verbose(f"Redis semantic-cache initializing index - {index_name}")
|
||||
|
||||
|
||||
# Validate similarity threshold
|
||||
if similarity_threshold is None:
|
||||
raise ValueError("similarity_threshold must be provided, passed None")
|
||||
|
||||
|
||||
# Store configuration
|
||||
self.similarity_threshold = similarity_threshold
|
||||
|
||||
|
||||
# Convert similarity threshold [0,1] to distance threshold [0,2]
|
||||
# For cosine distance: 0 = most similar, 2 = least similar
|
||||
# While similarity: 1 = most similar, 0 = least similar
|
||||
|
@ -87,14 +91,16 @@ class RedisSemanticCache(BaseCache):
|
|||
if redis_url is None:
|
||||
try:
|
||||
# Attempt to use provided parameters or fallback to environment variables
|
||||
host = host or os.environ['REDIS_HOST']
|
||||
port = port or os.environ['REDIS_PORT']
|
||||
password = password or os.environ['REDIS_PASSWORD']
|
||||
host = host or os.environ["REDIS_HOST"]
|
||||
port = port or os.environ["REDIS_PORT"]
|
||||
password = password or os.environ["REDIS_PASSWORD"]
|
||||
except KeyError as e:
|
||||
# Raise a more informative exception if any of the required keys are missing
|
||||
missing_var = e.args[0]
|
||||
raise ValueError(f"Missing required Redis configuration: {missing_var}. "
|
||||
f"Provide {missing_var} or redis_url.") from e
|
||||
raise ValueError(
|
||||
f"Missing required Redis configuration: {missing_var}. "
|
||||
f"Provide {missing_var} or redis_url."
|
||||
) from e
|
||||
|
||||
redis_url = f"redis://:{password}@{host}:{port}"
|
||||
|
||||
|
@ -114,7 +120,7 @@ class RedisSemanticCache(BaseCache):
|
|||
def _get_ttl(self, **kwargs) -> Optional[int]:
|
||||
"""
|
||||
Get the TTL (time-to-live) value for cache entries.
|
||||
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments that may contain a custom TTL
|
||||
|
||||
|
@ -125,22 +131,25 @@ class RedisSemanticCache(BaseCache):
|
|||
if ttl is not None:
|
||||
ttl = int(ttl)
|
||||
return ttl
|
||||
|
||||
|
||||
def _get_embedding(self, prompt: str) -> List[float]:
|
||||
"""
|
||||
Generate an embedding vector for the given prompt using the configured embedding model.
|
||||
|
||||
|
||||
Args:
|
||||
prompt: The text to generate an embedding for
|
||||
|
||||
|
||||
Returns:
|
||||
List[float]: The embedding vector
|
||||
"""
|
||||
# Create an embedding from prompt
|
||||
embedding_response = litellm.embedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
embedding_response = cast(
|
||||
EmbeddingResponse,
|
||||
litellm.embedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
),
|
||||
)
|
||||
embedding = embedding_response["data"][0]["embedding"]
|
||||
return embedding
|
||||
|
@ -148,10 +157,10 @@ class RedisSemanticCache(BaseCache):
|
|||
def _get_cache_logic(self, cached_response: Any) -> Any:
|
||||
"""
|
||||
Process the cached response to prepare it for use.
|
||||
|
||||
|
||||
Args:
|
||||
cached_response: The raw cached response
|
||||
|
||||
|
||||
Returns:
|
||||
The processed cache response, or None if input was None
|
||||
"""
|
||||
|
@ -171,13 +180,13 @@ class RedisSemanticCache(BaseCache):
|
|||
except (ValueError, SyntaxError) as e:
|
||||
print_verbose(f"Error parsing cached response: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
return cached_response
|
||||
|
||||
def set_cache(self, key: str, value: Any, **kwargs) -> None:
|
||||
"""
|
||||
Store a value in the semantic cache.
|
||||
|
||||
|
||||
Args:
|
||||
key: The cache key (not directly used in semantic caching)
|
||||
value: The response value to cache
|
||||
|
@ -186,13 +195,14 @@ class RedisSemanticCache(BaseCache):
|
|||
"""
|
||||
print_verbose(f"Redis semantic-cache set_cache, kwargs: {kwargs}")
|
||||
|
||||
value_str: Optional[str] = None
|
||||
try:
|
||||
# Extract the prompt from messages
|
||||
messages = kwargs.get("messages", [])
|
||||
if not messages:
|
||||
print_verbose("No messages provided for semantic caching")
|
||||
return
|
||||
|
||||
|
||||
prompt = get_str_from_messages(messages)
|
||||
value_str = str(value)
|
||||
|
||||
|
@ -203,16 +213,18 @@ class RedisSemanticCache(BaseCache):
|
|||
else:
|
||||
self.llmcache.store(prompt, value_str)
|
||||
except Exception as e:
|
||||
print_verbose(f"Error setting {value_str} in the Redis semantic cache: {str(e)}")
|
||||
print_verbose(
|
||||
f"Error setting {value_str or value} in the Redis semantic cache: {str(e)}"
|
||||
)
|
||||
|
||||
def get_cache(self, key: str, **kwargs) -> Any:
|
||||
"""
|
||||
Retrieve a semantically similar cached response.
|
||||
|
||||
|
||||
Args:
|
||||
key: The cache key (not directly used in semantic caching)
|
||||
**kwargs: Additional arguments including 'messages' for the prompt
|
||||
|
||||
|
||||
Returns:
|
||||
The cached response if a semantically similar prompt is found, else None
|
||||
"""
|
||||
|
@ -224,7 +236,7 @@ class RedisSemanticCache(BaseCache):
|
|||
if not messages:
|
||||
print_verbose("No messages provided for semantic cache lookup")
|
||||
return None
|
||||
|
||||
|
||||
prompt = get_str_from_messages(messages)
|
||||
# Check the cache for semantically similar prompts
|
||||
results = self.llmcache.check(prompt=prompt)
|
||||
|
@ -236,12 +248,12 @@ class RedisSemanticCache(BaseCache):
|
|||
# Process the best matching result
|
||||
cache_hit = results[0]
|
||||
vector_distance = float(cache_hit["vector_distance"])
|
||||
|
||||
|
||||
# Convert vector distance back to similarity score
|
||||
# For cosine distance: 0 = most similar, 2 = least similar
|
||||
# While similarity: 1 = most similar, 0 = least similar
|
||||
similarity = 1 - vector_distance
|
||||
|
||||
|
||||
cached_prompt = cache_hit["prompt"]
|
||||
cached_response = cache_hit["response"]
|
||||
|
||||
|
@ -251,19 +263,19 @@ class RedisSemanticCache(BaseCache):
|
|||
f"current prompt: {prompt}, "
|
||||
f"cached prompt: {cached_prompt}"
|
||||
)
|
||||
|
||||
|
||||
return self._get_cache_logic(cached_response=cached_response)
|
||||
except Exception as e:
|
||||
print_verbose(f"Error retrieving from Redis semantic cache: {str(e)}")
|
||||
|
||||
|
||||
async def _get_async_embedding(self, prompt: str, **kwargs) -> List[float]:
|
||||
"""
|
||||
Asynchronously generate an embedding for the given prompt.
|
||||
|
||||
|
||||
Args:
|
||||
prompt: The text to generate an embedding for
|
||||
**kwargs: Additional arguments that may contain metadata
|
||||
|
||||
|
||||
Returns:
|
||||
List[float]: The embedding vector
|
||||
"""
|
||||
|
@ -275,7 +287,7 @@ class RedisSemanticCache(BaseCache):
|
|||
if llm_model_list is not None
|
||||
else []
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
if llm_router is not None and self.embedding_model in router_model_names:
|
||||
# Use the router for embedding generation
|
||||
|
@ -307,7 +319,7 @@ class RedisSemanticCache(BaseCache):
|
|||
async def async_set_cache(self, key: str, value: Any, **kwargs) -> None:
|
||||
"""
|
||||
Asynchronously store a value in the semantic cache.
|
||||
|
||||
|
||||
Args:
|
||||
key: The cache key (not directly used in semantic caching)
|
||||
value: The response value to cache
|
||||
|
@ -322,13 +334,13 @@ class RedisSemanticCache(BaseCache):
|
|||
if not messages:
|
||||
print_verbose("No messages provided for semantic caching")
|
||||
return
|
||||
|
||||
|
||||
prompt = get_str_from_messages(messages)
|
||||
value_str = str(value)
|
||||
|
||||
# Generate embedding for the value (response) to cache
|
||||
prompt_embedding = await self._get_async_embedding(prompt, **kwargs)
|
||||
|
||||
|
||||
# Get TTL and store in Redis semantic cache
|
||||
ttl = self._get_ttl(**kwargs)
|
||||
if ttl is not None:
|
||||
|
@ -336,13 +348,13 @@ class RedisSemanticCache(BaseCache):
|
|||
prompt,
|
||||
value_str,
|
||||
vector=prompt_embedding, # Pass through custom embedding
|
||||
ttl=ttl
|
||||
ttl=ttl,
|
||||
)
|
||||
else:
|
||||
await self.llmcache.astore(
|
||||
prompt,
|
||||
value_str,
|
||||
vector=prompt_embedding # Pass through custom embedding
|
||||
vector=prompt_embedding, # Pass through custom embedding
|
||||
)
|
||||
except Exception as e:
|
||||
print_verbose(f"Error in async_set_cache: {str(e)}")
|
||||
|
@ -350,11 +362,11 @@ class RedisSemanticCache(BaseCache):
|
|||
async def async_get_cache(self, key: str, **kwargs) -> Any:
|
||||
"""
|
||||
Asynchronously retrieve a semantically similar cached response.
|
||||
|
||||
|
||||
Args:
|
||||
key: The cache key (not directly used in semantic caching)
|
||||
**kwargs: Additional arguments including 'messages' for the prompt
|
||||
|
||||
|
||||
Returns:
|
||||
The cached response if a semantically similar prompt is found, else None
|
||||
"""
|
||||
|
@ -367,21 +379,20 @@ class RedisSemanticCache(BaseCache):
|
|||
print_verbose("No messages provided for semantic cache lookup")
|
||||
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
|
||||
return None
|
||||
|
||||
|
||||
prompt = get_str_from_messages(messages)
|
||||
|
||||
|
||||
# Generate embedding for the prompt
|
||||
prompt_embedding = await self._get_async_embedding(prompt, **kwargs)
|
||||
|
||||
# Check the cache for semantically similar prompts
|
||||
results = await self.llmcache.acheck(
|
||||
prompt=prompt,
|
||||
vector=prompt_embedding
|
||||
)
|
||||
results = await self.llmcache.acheck(prompt=prompt, vector=prompt_embedding)
|
||||
|
||||
# handle results / cache hit
|
||||
if not results:
|
||||
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 # TODO why here but not above??
|
||||
kwargs.setdefault("metadata", {})[
|
||||
"semantic-similarity"
|
||||
] = 0.0 # TODO why here but not above??
|
||||
return None
|
||||
|
||||
cache_hit = results[0]
|
||||
|
@ -404,7 +415,7 @@ class RedisSemanticCache(BaseCache):
|
|||
f"current prompt: {prompt}, "
|
||||
f"cached prompt: {cached_prompt}"
|
||||
)
|
||||
|
||||
|
||||
return self._get_cache_logic(cached_response=cached_response)
|
||||
except Exception as e:
|
||||
print_verbose(f"Error in async_get_cache: {str(e)}")
|
||||
|
@ -413,17 +424,19 @@ class RedisSemanticCache(BaseCache):
|
|||
async def _index_info(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get information about the Redis index.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Information about the Redis index
|
||||
"""
|
||||
aindex = await self.llmcache._get_async_index()
|
||||
return await aindex.info()
|
||||
|
||||
async def async_set_cache_pipeline(self, cache_list: List[Tuple[str, Any]], **kwargs) -> None:
|
||||
async def async_set_cache_pipeline(
|
||||
self, cache_list: List[Tuple[str, Any]], **kwargs
|
||||
) -> None:
|
||||
"""
|
||||
Asynchronously store multiple values in the semantic cache.
|
||||
|
||||
|
||||
Args:
|
||||
cache_list: List of (key, value) tuples to cache
|
||||
**kwargs: Additional arguments
|
||||
|
|
|
@@ -123,7 +123,7 @@ class S3Cache(BaseCache):
                 ) # Convert string to dictionary
             except Exception:
                 cached_response = ast.literal_eval(cached_response)
-            if type(cached_response) is not dict:
+            if not isinstance(cached_response, dict):
                 cached_response = dict(cached_response)
             verbose_logger.debug(
                 f"Got S3 Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
|
||||
|
|
|
@ -580,7 +580,6 @@ def completion_cost( # noqa: PLR0915
|
|||
- For un-mapped Replicate models, the cost is calculated based on the total time used for the request.
|
||||
"""
|
||||
try:
|
||||
|
||||
call_type = _infer_call_type(call_type, completion_response) or "completion"
|
||||
|
||||
if (
|
||||
|
|
|
@ -138,7 +138,6 @@ def create_fine_tuning_job(
|
|||
|
||||
# OpenAI
|
||||
if custom_llm_provider == "openai":
|
||||
|
||||
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
api_base = (
|
||||
optional_params.api_base
|
||||
|
@ -360,7 +359,6 @@ def cancel_fine_tuning_job(
|
|||
|
||||
# OpenAI
|
||||
if custom_llm_provider == "openai":
|
||||
|
||||
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
api_base = (
|
||||
optional_params.api_base
|
||||
|
@ -522,7 +520,6 @@ def list_fine_tuning_jobs(
|
|||
|
||||
# OpenAI
|
||||
if custom_llm_provider == "openai":
|
||||
|
||||
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
api_base = (
|
||||
optional_params.api_base
|
||||
|
|
|
@ -19,7 +19,6 @@ else:
|
|||
|
||||
|
||||
def squash_payloads(queue):
|
||||
|
||||
squashed = {}
|
||||
if len(queue) == 0:
|
||||
return squashed
|
||||
|
|
|
@ -195,12 +195,15 @@ class SlackAlerting(CustomBatchLogger):
|
|||
if self.alerting is None or self.alert_types is None:
|
||||
return
|
||||
|
||||
time_difference_float, model, api_base, messages = (
|
||||
self._response_taking_too_long_callback_helper(
|
||||
kwargs=kwargs,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
(
|
||||
time_difference_float,
|
||||
model,
|
||||
api_base,
|
||||
messages,
|
||||
) = self._response_taking_too_long_callback_helper(
|
||||
kwargs=kwargs,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
if litellm.turn_off_message_logging or litellm.redact_messages_in_exceptions:
|
||||
messages = "Message not logged. litellm.redact_messages_in_exceptions=True"
|
||||
|
@ -819,9 +822,9 @@ class SlackAlerting(CustomBatchLogger):
|
|||
### UNIQUE CACHE KEY ###
|
||||
cache_key = provider + region_name
|
||||
|
||||
outage_value: Optional[ProviderRegionOutageModel] = (
|
||||
await self.internal_usage_cache.async_get_cache(key=cache_key)
|
||||
)
|
||||
outage_value: Optional[
|
||||
ProviderRegionOutageModel
|
||||
] = await self.internal_usage_cache.async_get_cache(key=cache_key)
|
||||
|
||||
if (
|
||||
getattr(exception, "status_code", None) is None
|
||||
|
@ -1402,9 +1405,9 @@ Model Info:
|
|||
self.alert_to_webhook_url is not None
|
||||
and alert_type in self.alert_to_webhook_url
|
||||
):
|
||||
slack_webhook_url: Optional[Union[str, List[str]]] = (
|
||||
self.alert_to_webhook_url[alert_type]
|
||||
)
|
||||
slack_webhook_url: Optional[
|
||||
Union[str, List[str]]
|
||||
] = self.alert_to_webhook_url[alert_type]
|
||||
elif self.default_webhook_url is not None:
|
||||
slack_webhook_url = self.default_webhook_url
|
||||
else:
|
||||
|
@ -1768,7 +1771,6 @@ Model Info:
|
|||
- Team Created, Updated, Deleted
|
||||
"""
|
||||
try:
|
||||
|
||||
message = f"`{event_name}`\n"
|
||||
|
||||
key_event_dict = key_event.model_dump()
|
||||
|
|
|
@ -283,4 +283,4 @@ class OpenInferenceSpanKindValues(Enum):
|
|||
|
||||
class OpenInferenceMimeTypeValues(Enum):
|
||||
TEXT = "text/plain"
|
||||
JSON = "application/json"
|
||||
JSON = "application/json"
|
||||
|
|
|
@ -98,7 +98,6 @@ class ArgillaLogger(CustomBatchLogger):
|
|||
argilla_dataset_name: Optional[str],
|
||||
argilla_base_url: Optional[str],
|
||||
) -> ArgillaCredentialsObject:
|
||||
|
||||
_credentials_api_key = argilla_api_key or os.getenv("ARGILLA_API_KEY")
|
||||
if _credentials_api_key is None:
|
||||
raise Exception("Invalid Argilla API Key given. _credentials_api_key=None.")
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
|
@ -7,7 +7,7 @@ from litellm.types.utils import StandardLoggingPayload
|
|||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = _Span
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
|
|
|
@ -19,14 +19,13 @@ if TYPE_CHECKING:
|
|||
from litellm.types.integrations.arize import Protocol as _Protocol
|
||||
|
||||
Protocol = _Protocol
|
||||
Span = _Span
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Protocol = Any
|
||||
Span = Any
|
||||
|
||||
|
||||
class ArizeLogger(OpenTelemetry):
|
||||
|
||||
def set_attributes(self, span: Span, kwargs, response_obj: Optional[Any]):
|
||||
ArizeLogger.set_arize_attributes(span, kwargs, response_obj)
|
||||
return
|
||||
|
|
|
@ -1,17 +1,20 @@
|
|||
import os
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from litellm.integrations.arize import _utils
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.arize import _utils
|
||||
from litellm.types.integrations.arize_phoenix import ArizePhoenixConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .opentelemetry import OpenTelemetryConfig as _OpenTelemetryConfig
|
||||
from litellm.types.integrations.arize import Protocol as _Protocol
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
from litellm.types.integrations.arize import Protocol as _Protocol
|
||||
|
||||
from .opentelemetry import OpenTelemetryConfig as _OpenTelemetryConfig
|
||||
|
||||
Protocol = _Protocol
|
||||
OpenTelemetryConfig = _OpenTelemetryConfig
|
||||
Span = _Span
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Protocol = Any
|
||||
OpenTelemetryConfig = Any
|
||||
|
@ -20,6 +23,7 @@ else:
|
|||
|
||||
ARIZE_HOSTED_PHOENIX_ENDPOINT = "https://app.phoenix.arize.com/v1/traces"
|
||||
|
||||
|
||||
class ArizePhoenixLogger:
|
||||
@staticmethod
|
||||
def set_arize_phoenix_attributes(span: Span, kwargs, response_obj):
|
||||
|
@ -49,7 +53,7 @@ class ArizePhoenixLogger:
|
|||
protocol = "otlp_grpc"
|
||||
else:
|
||||
endpoint = ARIZE_HOSTED_PHOENIX_ENDPOINT
|
||||
protocol = "otlp_http"
|
||||
protocol = "otlp_http"
|
||||
verbose_logger.debug(
|
||||
f"No PHOENIX_COLLECTOR_ENDPOINT or PHOENIX_COLLECTOR_HTTP_ENDPOINT found, using default endpoint with http: {ARIZE_HOSTED_PHOENIX_ENDPOINT}"
|
||||
)
|
||||
|
@ -57,17 +61,16 @@ class ArizePhoenixLogger:
|
|||
otlp_auth_headers = None
|
||||
# If the endpoint is the Arize hosted Phoenix endpoint, use the api_key as the auth header as currently it is uses
|
||||
# a slightly different auth header format than self hosted phoenix
|
||||
if endpoint == ARIZE_HOSTED_PHOENIX_ENDPOINT:
|
||||
if endpoint == ARIZE_HOSTED_PHOENIX_ENDPOINT:
|
||||
if api_key is None:
|
||||
raise ValueError("PHOENIX_API_KEY must be set when the Arize hosted Phoenix endpoint is used.")
|
||||
raise ValueError(
|
||||
"PHOENIX_API_KEY must be set when the Arize hosted Phoenix endpoint is used."
|
||||
)
|
||||
otlp_auth_headers = f"api_key={api_key}"
|
||||
elif api_key is not None:
|
||||
# api_key/auth is optional for self hosted phoenix
|
||||
otlp_auth_headers = f"Authorization=Bearer {api_key}"
|
||||
|
||||
return ArizePhoenixConfig(
|
||||
otlp_auth_headers=otlp_auth_headers,
|
||||
protocol=protocol,
|
||||
endpoint=endpoint
|
||||
otlp_auth_headers=otlp_auth_headers, protocol=protocol, endpoint=endpoint
|
||||
)
|
||||
|
||||
|
|
|
@ -12,7 +12,10 @@ class AthinaLogger:
|
|||
"athina-api-key": self.athina_api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
self.athina_logging_url = os.getenv("ATHINA_BASE_URL", "https://log.athina.ai") + "/api/v1/log/inference"
|
||||
self.athina_logging_url = (
|
||||
os.getenv("ATHINA_BASE_URL", "https://log.athina.ai")
|
||||
+ "/api/v1/log/inference"
|
||||
)
|
||||
self.additional_keys = [
|
||||
"environment",
|
||||
"prompt_slug",
|
||||
|
|
|
@ -50,12 +50,12 @@ class AzureBlobStorageLogger(CustomBatchLogger):
|
|||
self.azure_storage_file_system: str = _azure_storage_file_system
|
||||
|
||||
# Internal variables used for Token based authentication
|
||||
self.azure_auth_token: Optional[str] = (
|
||||
None # the Azure AD token to use for Azure Storage API requests
|
||||
)
|
||||
self.token_expiry: Optional[datetime] = (
|
||||
None # the expiry time of the currentAzure AD token
|
||||
)
|
||||
self.azure_auth_token: Optional[
|
||||
str
|
||||
] = None # the Azure AD token to use for Azure Storage API requests
|
||||
self.token_expiry: Optional[
|
||||
datetime
|
||||
] = None # the expiry time of the currentAzure AD token
|
||||
|
||||
asyncio.create_task(self.periodic_flush())
|
||||
self.flush_lock = asyncio.Lock()
|
||||
|
@ -153,7 +153,6 @@ class AzureBlobStorageLogger(CustomBatchLogger):
|
|||
3. Flush the data
|
||||
"""
|
||||
try:
|
||||
|
||||
if self.azure_storage_account_key:
|
||||
await self.upload_to_azure_data_lake_with_azure_account_key(
|
||||
payload=payload
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
import copy
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
@ -19,7 +19,9 @@ from litellm.llms.custom_httpx.http_handler import (
|
|||
)
|
||||
from litellm.utils import print_verbose
|
||||
|
||||
global_braintrust_http_handler = get_async_httpx_client(llm_provider=httpxSpecialProvider.LoggingCallback)
|
||||
global_braintrust_http_handler = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.LoggingCallback
|
||||
)
|
||||
global_braintrust_sync_http_handler = HTTPHandler()
|
||||
API_BASE = "https://api.braintrustdata.com/v1"
|
||||
|
||||
|
@ -35,7 +37,9 @@ def get_utc_datetime():
|
|||
|
||||
|
||||
class BraintrustLogger(CustomLogger):
|
||||
def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None) -> None:
|
||||
def __init__(
|
||||
self, api_key: Optional[str] = None, api_base: Optional[str] = None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.validate_environment(api_key=api_key)
|
||||
self.api_base = api_base or API_BASE
|
||||
|
@ -45,7 +49,9 @@ class BraintrustLogger(CustomLogger):
|
|||
"Authorization": "Bearer " + self.api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
self._project_id_cache: Dict[str, str] = {} # Cache mapping project names to IDs
|
||||
self._project_id_cache: Dict[
|
||||
str, str
|
||||
] = {} # Cache mapping project names to IDs
|
||||
|
||||
def validate_environment(self, api_key: Optional[str]):
|
||||
"""
|
||||
|
@ -71,7 +77,9 @@ class BraintrustLogger(CustomLogger):
|
|||
|
||||
try:
|
||||
response = global_braintrust_sync_http_handler.post(
|
||||
f"{self.api_base}/project", headers=self.headers, json={"name": project_name}
|
||||
f"{self.api_base}/project",
|
||||
headers=self.headers,
|
||||
json={"name": project_name},
|
||||
)
|
||||
project_dict = response.json()
|
||||
project_id = project_dict["id"]
|
||||
|
@ -89,7 +97,9 @@ class BraintrustLogger(CustomLogger):
|
|||
|
||||
try:
|
||||
response = await global_braintrust_http_handler.post(
|
||||
f"{self.api_base}/project/register", headers=self.headers, json={"name": project_name}
|
||||
f"{self.api_base}/project/register",
|
||||
headers=self.headers,
|
||||
json={"name": project_name},
|
||||
)
|
||||
project_dict = response.json()
|
||||
project_id = project_dict["id"]
|
||||
|
@ -116,15 +126,21 @@ class BraintrustLogger(CustomLogger):
|
|||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
proxy_headers = litellm_params.get("proxy_server_request", {}).get("headers", {}) or {}
|
||||
proxy_headers = (
|
||||
litellm_params.get("proxy_server_request", {}).get("headers", {}) or {}
|
||||
)
|
||||
|
||||
for metadata_param_key in proxy_headers:
|
||||
if metadata_param_key.startswith("braintrust"):
|
||||
trace_param_key = metadata_param_key.replace("braintrust", "", 1)
|
||||
if trace_param_key in metadata:
|
||||
verbose_logger.warning(f"Overwriting Braintrust `{trace_param_key}` from request header")
|
||||
verbose_logger.warning(
|
||||
f"Overwriting Braintrust `{trace_param_key}` from request header"
|
||||
)
|
||||
else:
|
||||
verbose_logger.debug(f"Found Braintrust `{trace_param_key}` in request header")
|
||||
verbose_logger.debug(
|
||||
f"Found Braintrust `{trace_param_key}` in request header"
|
||||
)
|
||||
metadata[trace_param_key] = proxy_headers.get(metadata_param_key)
|
||||
|
||||
return metadata
|
||||
|
@ -157,24 +173,35 @@ class BraintrustLogger(CustomLogger):
|
|||
output = None
|
||||
choices = []
|
||||
if response_obj is not None and (
|
||||
kwargs.get("call_type", None) == "embedding" or isinstance(response_obj, litellm.EmbeddingResponse)
|
||||
kwargs.get("call_type", None) == "embedding"
|
||||
or isinstance(response_obj, litellm.EmbeddingResponse)
|
||||
):
|
||||
output = None
|
||||
elif response_obj is not None and isinstance(response_obj, litellm.ModelResponse):
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.ModelResponse
|
||||
):
|
||||
output = response_obj["choices"][0]["message"].json()
|
||||
choices = response_obj["choices"]
|
||||
elif response_obj is not None and isinstance(response_obj, litellm.TextCompletionResponse):
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.TextCompletionResponse
|
||||
):
|
||||
output = response_obj.choices[0].text
|
||||
choices = response_obj.choices
|
||||
elif response_obj is not None and isinstance(response_obj, litellm.ImageResponse):
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.ImageResponse
|
||||
):
|
||||
output = response_obj["data"]
|
||||
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
metadata = litellm_params.get("metadata", {}) or {} # if litellm_params['metadata'] == None
|
||||
metadata = (
|
||||
litellm_params.get("metadata", {}) or {}
|
||||
) # if litellm_params['metadata'] == None
|
||||
metadata = self.add_metadata_from_header(litellm_params, metadata)
|
||||
clean_metadata = {}
|
||||
try:
|
||||
metadata = copy.deepcopy(metadata) # Avoid modifying the original metadata
|
||||
metadata = copy.deepcopy(
|
||||
metadata
|
||||
) # Avoid modifying the original metadata
|
||||
except Exception:
|
||||
new_metadata = {}
|
||||
for key, value in metadata.items():
|
||||
|
@ -192,7 +219,9 @@ class BraintrustLogger(CustomLogger):
|
|||
project_id = metadata.get("project_id")
|
||||
if project_id is None:
|
||||
project_name = metadata.get("project_name")
|
||||
project_id = self.get_project_id_sync(project_name) if project_name else None
|
||||
project_id = (
|
||||
self.get_project_id_sync(project_name) if project_name else None
|
||||
)
|
||||
|
||||
if project_id is None:
|
||||
if self.default_project_id is None:
|
||||
|
@ -234,7 +263,8 @@ class BraintrustLogger(CustomLogger):
|
|||
"completion_tokens": usage_obj.completion_tokens,
|
||||
"total_tokens": usage_obj.total_tokens,
|
||||
"total_cost": cost,
|
||||
"time_to_first_token": end_time.timestamp() - start_time.timestamp(),
|
||||
"time_to_first_token": end_time.timestamp()
|
||||
- start_time.timestamp(),
|
||||
"start": start_time.timestamp(),
|
||||
"end": end_time.timestamp(),
|
||||
}
|
||||
|
@ -255,7 +285,9 @@ class BraintrustLogger(CustomLogger):
|
|||
request_data["metrics"] = metrics
|
||||
|
||||
try:
|
||||
print_verbose(f"global_braintrust_sync_http_handler.post: {global_braintrust_sync_http_handler.post}")
|
||||
print_verbose(
|
||||
f"global_braintrust_sync_http_handler.post: {global_braintrust_sync_http_handler.post}"
|
||||
)
|
||||
global_braintrust_sync_http_handler.post(
|
||||
url=f"{self.api_base}/project_logs/{project_id}/insert",
|
||||
json={"events": [request_data]},
|
||||
|
@ -276,20 +308,29 @@ class BraintrustLogger(CustomLogger):
|
|||
output = None
|
||||
choices = []
|
||||
if response_obj is not None and (
|
||||
kwargs.get("call_type", None) == "embedding" or isinstance(response_obj, litellm.EmbeddingResponse)
|
||||
kwargs.get("call_type", None) == "embedding"
|
||||
or isinstance(response_obj, litellm.EmbeddingResponse)
|
||||
):
|
||||
output = None
|
||||
elif response_obj is not None and isinstance(response_obj, litellm.ModelResponse):
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.ModelResponse
|
||||
):
|
||||
output = response_obj["choices"][0]["message"].json()
|
||||
choices = response_obj["choices"]
|
||||
elif response_obj is not None and isinstance(response_obj, litellm.TextCompletionResponse):
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.TextCompletionResponse
|
||||
):
|
||||
output = response_obj.choices[0].text
|
||||
choices = response_obj.choices
|
||||
elif response_obj is not None and isinstance(response_obj, litellm.ImageResponse):
|
||||
elif response_obj is not None and isinstance(
|
||||
response_obj, litellm.ImageResponse
|
||||
):
|
||||
output = response_obj["data"]
|
||||
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
metadata = litellm_params.get("metadata", {}) or {} # if litellm_params['metadata'] == None
|
||||
metadata = (
|
||||
litellm_params.get("metadata", {}) or {}
|
||||
) # if litellm_params['metadata'] == None
|
||||
metadata = self.add_metadata_from_header(litellm_params, metadata)
|
||||
clean_metadata = {}
|
||||
new_metadata = {}
|
||||
|
@ -313,7 +354,11 @@ class BraintrustLogger(CustomLogger):
|
|||
project_id = metadata.get("project_id")
|
||||
if project_id is None:
|
||||
project_name = metadata.get("project_name")
|
||||
project_id = await self.get_project_id_async(project_name) if project_name else None
|
||||
project_id = (
|
||||
await self.get_project_id_async(project_name)
|
||||
if project_name
|
||||
else None
|
||||
)
|
||||
|
||||
if project_id is None:
|
||||
if self.default_project_id is None:
|
||||
|
@ -362,8 +407,14 @@ class BraintrustLogger(CustomLogger):
|
|||
api_call_start_time = kwargs.get("api_call_start_time")
|
||||
completion_start_time = kwargs.get("completion_start_time")
|
||||
|
||||
if api_call_start_time is not None and completion_start_time is not None:
|
||||
metrics["time_to_first_token"] = completion_start_time.timestamp() - api_call_start_time.timestamp()
|
||||
if (
|
||||
api_call_start_time is not None
|
||||
and completion_start_time is not None
|
||||
):
|
||||
metrics["time_to_first_token"] = (
|
||||
completion_start_time.timestamp()
|
||||
- api_call_start_time.timestamp()
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"id": litellm_call_id,
|
||||
|
|
|
@ -14,7 +14,6 @@ from litellm.integrations.custom_logger import CustomLogger
|
|||
|
||||
|
||||
class CustomBatchLogger(CustomLogger):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
flush_lock: Optional[asyncio.Lock] = None,
|
||||
|
|
|
@ -7,7 +7,6 @@ from litellm.types.utils import StandardLoggingGuardrailInformation
|
|||
|
||||
|
||||
class CustomGuardrail(CustomLogger):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
guardrail_name: Optional[str] = None,
|
||||
|
|
|
@ -31,7 +31,7 @@ from litellm.types.utils import (
|
|||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = _Span
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
|
|
|
@ -233,7 +233,6 @@ class DataDogLogger(
|
|||
pass
|
||||
|
||||
async def _log_async_event(self, kwargs, response_obj, start_time, end_time):
|
||||
|
||||
dd_payload = self.create_datadog_logging_payload(
|
||||
kwargs=kwargs,
|
||||
response_obj=response_obj,
|
||||
|
|
|
@ -125,9 +125,9 @@ class GCSBucketBase(CustomBatchLogger):
|
|||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
|
||||
kwargs.get("standard_callback_dynamic_params", None)
|
||||
)
|
||||
standard_callback_dynamic_params: Optional[
|
||||
StandardCallbackDynamicParams
|
||||
] = kwargs.get("standard_callback_dynamic_params", None)
|
||||
|
||||
bucket_name: str
|
||||
path_service_account: Optional[str]
|
||||
|
|
|
@ -70,12 +70,13 @@ class GcsPubSubLogger(CustomBatchLogger):
|
|||
"""Construct authorization headers using Vertex AI auth"""
|
||||
from litellm import vertex_chat_completion
|
||||
|
||||
_auth_header, vertex_project = (
|
||||
await vertex_chat_completion._ensure_access_token_async(
|
||||
credentials=self.path_service_account_json,
|
||||
project_id=None,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
(
|
||||
_auth_header,
|
||||
vertex_project,
|
||||
) = await vertex_chat_completion._ensure_access_token_async(
|
||||
credentials=self.path_service_account_json,
|
||||
project_id=None,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
|
||||
auth_header, _ = vertex_chat_completion._get_token_and_url(
|
||||
|
|
|
@ -155,11 +155,7 @@ class HumanloopLogger(CustomLogger):
|
|||
prompt_id: str,
|
||||
prompt_variables: Optional[dict],
|
||||
dynamic_callback_params: StandardCallbackDynamicParams,
|
||||
) -> Tuple[
|
||||
str,
|
||||
List[AllMessageValues],
|
||||
dict,
|
||||
]:
|
||||
) -> Tuple[str, List[AllMessageValues], dict,]:
|
||||
humanloop_api_key = dynamic_callback_params.get(
|
||||
"humanloop_api_key"
|
||||
) or get_secret_str("HUMANLOOP_API_KEY")
|
||||
|
|
|
@ -471,9 +471,9 @@ class LangFuseLogger:
|
|||
# we clean out all extra litellm metadata params before logging
|
||||
clean_metadata: Dict[str, Any] = {}
|
||||
if prompt_management_metadata is not None:
|
||||
clean_metadata["prompt_management_metadata"] = (
|
||||
prompt_management_metadata
|
||||
)
|
||||
clean_metadata[
|
||||
"prompt_management_metadata"
|
||||
] = prompt_management_metadata
|
||||
if isinstance(metadata, dict):
|
||||
for key, value in metadata.items():
|
||||
# generate langfuse tags - Default Tags sent to Langfuse from LiteLLM Proxy
|
||||
|
|
|
@ -19,7 +19,6 @@ else:
|
|||
|
||||
|
||||
class LangFuseHandler:
|
||||
|
||||
@staticmethod
|
||||
def get_langfuse_logger_for_request(
|
||||
standard_callback_dynamic_params: StandardCallbackDynamicParams,
|
||||
|
@ -87,7 +86,9 @@ class LangFuseHandler:
|
|||
if globalLangfuseLogger is not None:
|
||||
return globalLangfuseLogger
|
||||
|
||||
credentials_dict: Dict[str, Any] = (
|
||||
credentials_dict: Dict[
|
||||
str, Any
|
||||
] = (
|
||||
{}
|
||||
) # the global langfuse logger uses Environment Variables, there are no dynamic credentials
|
||||
globalLangfuseLogger = in_memory_dynamic_logger_cache.get_cache(
|
||||
|
|
|
@ -172,11 +172,7 @@ class LangfusePromptManagement(LangFuseLogger, PromptManagementBase, CustomLogge
|
|||
prompt_id: str,
|
||||
prompt_variables: Optional[dict],
|
||||
dynamic_callback_params: StandardCallbackDynamicParams,
|
||||
) -> Tuple[
|
||||
str,
|
||||
List[AllMessageValues],
|
||||
dict,
|
||||
]:
|
||||
) -> Tuple[str, List[AllMessageValues], dict,]:
|
||||
return self.get_chat_completion_prompt(
|
||||
model,
|
||||
messages,
|
||||
|
|
|
@ -75,7 +75,6 @@ class LangsmithLogger(CustomBatchLogger):
|
|||
langsmith_project: Optional[str] = None,
|
||||
langsmith_base_url: Optional[str] = None,
|
||||
) -> LangsmithCredentialsObject:
|
||||
|
||||
_credentials_api_key = langsmith_api_key or os.getenv("LANGSMITH_API_KEY")
|
||||
if _credentials_api_key is None:
|
||||
raise Exception(
|
||||
|
@ -443,9 +442,9 @@ class LangsmithLogger(CustomBatchLogger):
|
|||
|
||||
Otherwise, use the default credentials.
|
||||
"""
|
||||
standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
|
||||
kwargs.get("standard_callback_dynamic_params", None)
|
||||
)
|
||||
standard_callback_dynamic_params: Optional[
|
||||
StandardCallbackDynamicParams
|
||||
] = kwargs.get("standard_callback_dynamic_params", None)
|
||||
if standard_callback_dynamic_params is not None:
|
||||
credentials = self.get_credentials_from_env(
|
||||
langsmith_api_key=standard_callback_dynamic_params.get(
|
||||
|
@ -481,7 +480,6 @@ class LangsmithLogger(CustomBatchLogger):
|
|||
asyncio.run(self.async_send_batch())
|
||||
|
||||
def get_run_by_id(self, run_id):
|
||||
|
||||
langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"]
|
||||
|
||||
langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"]
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
import json
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from litellm.proxy._types import SpanAttributes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = _Span
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
|
|
|
@ -20,7 +20,6 @@ def parse_tool_calls(tool_calls):
|
|||
return None
|
||||
|
||||
def clean_tool_call(tool_call):
|
||||
|
||||
serialized = {
|
||||
"type": tool_call.type,
|
||||
"id": tool_call.id,
|
||||
|
@ -36,7 +35,6 @@ def parse_tool_calls(tool_calls):
|
|||
|
||||
|
||||
def parse_messages(input):
|
||||
|
||||
if input is None:
|
||||
return None
|
||||
|
||||
|
|
|
@@ -48,14 +48,17 @@ class MlflowLogger(CustomLogger):

def _extract_and_set_chat_attributes(self, span, kwargs, response_obj):
try:
from mlflow.tracing.utils import set_span_chat_messages, set_span_chat_tools
from mlflow.tracing.utils import set_span_chat_messages # type: ignore
from mlflow.tracing.utils import set_span_chat_tools # type: ignore
except ImportError:
return

inputs = self._construct_input(kwargs)
input_messages = inputs.get("messages", [])
output_messages = [c.message.model_dump(exclude_none=True)
for c in getattr(response_obj, "choices", [])]
output_messages = [
c.message.model_dump(exclude_none=True)
for c in getattr(response_obj, "choices", [])
]
if messages := [*input_messages, *output_messages]:
set_span_chat_messages(span, messages)
if tools := inputs.get("tools"):
@@ -1,7 +1,7 @@
import os
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast

import litellm
from litellm._logging import verbose_logger

@@ -23,10 +23,10 @@ if TYPE_CHECKING:
)
from litellm.proxy.proxy_server import UserAPIKeyAuth as _UserAPIKeyAuth

Span = _Span
SpanExporter = _SpanExporter
UserAPIKeyAuth = _UserAPIKeyAuth
ManagementEndpointLoggingPayload = _ManagementEndpointLoggingPayload
Span = Union[_Span, Any]
SpanExporter = Union[_SpanExporter, Any]
UserAPIKeyAuth = Union[_UserAPIKeyAuth, Any]
ManagementEndpointLoggingPayload = Union[_ManagementEndpointLoggingPayload, Any]
else:
Span = Any
SpanExporter = Any
@ -46,7 +46,6 @@ LITELLM_REQUEST_SPAN_NAME = "litellm_request"
|
|||
|
||||
@dataclass
|
||||
class OpenTelemetryConfig:
|
||||
|
||||
exporter: Union[str, SpanExporter] = "console"
|
||||
endpoint: Optional[str] = None
|
||||
headers: Optional[str] = None
|
||||
|
@ -154,7 +153,6 @@ class OpenTelemetry(CustomLogger):
|
|||
end_time: Optional[Union[datetime, float]] = None,
|
||||
event_metadata: Optional[dict] = None,
|
||||
):
|
||||
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
|
||||
|
@ -215,7 +213,6 @@ class OpenTelemetry(CustomLogger):
|
|||
end_time: Optional[Union[float, datetime]] = None,
|
||||
event_metadata: Optional[dict] = None,
|
||||
):
|
||||
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
|
||||
|
@ -353,9 +350,9 @@ class OpenTelemetry(CustomLogger):
|
|||
"""
|
||||
from opentelemetry import trace
|
||||
|
||||
standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
|
||||
kwargs.get("standard_callback_dynamic_params")
|
||||
)
|
||||
standard_callback_dynamic_params: Optional[
|
||||
StandardCallbackDynamicParams
|
||||
] = kwargs.get("standard_callback_dynamic_params")
|
||||
if not standard_callback_dynamic_params:
|
||||
return
|
||||
|
||||
|
@ -722,7 +719,6 @@ class OpenTelemetry(CustomLogger):
|
|||
span.set_attribute(key, primitive_value)
|
||||
|
||||
def set_raw_request_attributes(self, span: Span, kwargs, response_obj):
|
||||
|
||||
kwargs.get("optional_params", {})
|
||||
litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||
custom_llm_provider = litellm_params.get("custom_llm_provider", "Unknown")
|
||||
|
@ -843,12 +839,14 @@ class OpenTelemetry(CustomLogger):
|
|||
headers=dynamic_headers or self.OTEL_HEADERS
|
||||
)
|
||||
|
||||
if isinstance(self.OTEL_EXPORTER, SpanExporter):
|
||||
if hasattr(
|
||||
self.OTEL_EXPORTER, "export"
|
||||
): # Check if it has the export method that SpanExporter requires
|
||||
verbose_logger.debug(
|
||||
"OpenTelemetry: intiializing SpanExporter. Value of OTEL_EXPORTER: %s",
|
||||
self.OTEL_EXPORTER,
|
||||
)
|
||||
return SimpleSpanProcessor(self.OTEL_EXPORTER)
|
||||
return SimpleSpanProcessor(cast(SpanExporter, self.OTEL_EXPORTER))
|
||||
|
||||
if self.OTEL_EXPORTER == "console":
|
||||
verbose_logger.debug(
|
||||
|
@ -907,7 +905,6 @@ class OpenTelemetry(CustomLogger):
|
|||
logging_payload: ManagementEndpointLoggingPayload,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
):
|
||||
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
|
||||
|
@ -961,7 +958,6 @@ class OpenTelemetry(CustomLogger):
|
|||
logging_payload: ManagementEndpointLoggingPayload,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
):
|
||||
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
|
||||
|
|
|
@ -185,7 +185,6 @@ class OpikLogger(CustomBatchLogger):
|
|||
def _create_opik_payload( # noqa: PLR0915
|
||||
self, kwargs, response_obj, start_time, end_time
|
||||
) -> List[Dict]:
|
||||
|
||||
# Get metadata
|
||||
_litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||
litellm_params_metadata = _litellm_params.get("metadata", {}) or {}
|
||||
|
|
|
@ -988,9 +988,9 @@ class PrometheusLogger(CustomLogger):
|
|||
):
|
||||
try:
|
||||
verbose_logger.debug("setting remaining tokens requests metric")
|
||||
standard_logging_payload: Optional[StandardLoggingPayload] = (
|
||||
request_kwargs.get("standard_logging_object")
|
||||
)
|
||||
standard_logging_payload: Optional[
|
||||
StandardLoggingPayload
|
||||
] = request_kwargs.get("standard_logging_object")
|
||||
|
||||
if standard_logging_payload is None:
|
||||
return
|
||||
|
|
|
@ -14,7 +14,6 @@ class PromptManagementClient(TypedDict):
|
|||
|
||||
|
||||
class PromptManagementBase(ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def integration_name(self) -> str:
|
||||
|
@ -83,11 +82,7 @@ class PromptManagementBase(ABC):
|
|||
prompt_id: str,
|
||||
prompt_variables: Optional[dict],
|
||||
dynamic_callback_params: StandardCallbackDynamicParams,
|
||||
) -> Tuple[
|
||||
str,
|
||||
List[AllMessageValues],
|
||||
dict,
|
||||
]:
|
||||
) -> Tuple[str, List[AllMessageValues], dict,]:
|
||||
if not self.should_run_prompt_management(
|
||||
prompt_id=prompt_id, dynamic_callback_params=dynamic_callback_params
|
||||
):
|
||||
|
|
|
@@ -38,7 +38,7 @@ class S3Logger:
if litellm.s3_callback_params is not None:
# read in .env variables - example os.environ/AWS_BUCKET_NAME
for key, value in litellm.s3_callback_params.items():
if type(value) is str and value.startswith("os.environ/"):
if isinstance(value, str) and value.startswith("os.environ/"):
litellm.s3_callback_params[key] = litellm.get_secret(value)
# now set s3 params from litellm.s3_logger_params
s3_bucket_name = litellm.s3_callback_params.get("s3_bucket_name")
@@ -21,11 +21,11 @@ try:
# contains a (known) object attribute
object: Literal["chat.completion", "edit", "text_completion"]

def __getitem__(self, key: K) -> V: ... # noqa
def __getitem__(self, key: K) -> V:
... # noqa

def get( # noqa
self, key: K, default: Optional[V] = None
) -> Optional[V]: ... # pragma: no cover
def get(self, key: K, default: Optional[V] = None) -> Optional[V]: # noqa
... # pragma: no cover

class OpenAIRequestResponseResolver:
def __call__(
@@ -10,7 +10,7 @@ from litellm.types.llms.openai import AllMessageValues
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span

Span = _Span
Span = Union[_Span, Any]
else:
Span = Any
@@ -11,7 +11,9 @@ except (ImportError, AttributeError):
# Old way to access resources, which setuptools deprecated some time ago
import pkg_resources # type: ignore

filename = pkg_resources.resource_filename(__name__, "litellm_core_utils/tokenizers")
filename = pkg_resources.resource_filename(
__name__, "litellm_core_utils/tokenizers"
)

os.environ["TIKTOKEN_CACHE_DIR"] = os.getenv(
"CUSTOM_TIKTOKEN_CACHE_DIR", filename
@ -239,9 +239,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
self.litellm_trace_id = litellm_trace_id
|
||||
self.function_id = function_id
|
||||
self.streaming_chunks: List[Any] = [] # for generating complete stream response
|
||||
self.sync_streaming_chunks: List[Any] = (
|
||||
[]
|
||||
) # for generating complete stream response
|
||||
self.sync_streaming_chunks: List[
|
||||
Any
|
||||
] = [] # for generating complete stream response
|
||||
self.log_raw_request_response = log_raw_request_response
|
||||
|
||||
# Initialize dynamic callbacks
|
||||
|
@ -452,18 +452,19 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
prompt_id: str,
|
||||
prompt_variables: Optional[dict],
|
||||
) -> Tuple[str, List[AllMessageValues], dict]:
|
||||
|
||||
custom_logger = self.get_custom_logger_for_prompt_management(model)
|
||||
if custom_logger:
|
||||
model, messages, non_default_params = (
|
||||
custom_logger.get_chat_completion_prompt(
|
||||
model=model,
|
||||
messages=messages,
|
||||
non_default_params=non_default_params,
|
||||
prompt_id=prompt_id,
|
||||
prompt_variables=prompt_variables,
|
||||
dynamic_callback_params=self.standard_callback_dynamic_params,
|
||||
)
|
||||
(
|
||||
model,
|
||||
messages,
|
||||
non_default_params,
|
||||
) = custom_logger.get_chat_completion_prompt(
|
||||
model=model,
|
||||
messages=messages,
|
||||
non_default_params=non_default_params,
|
||||
prompt_id=prompt_id,
|
||||
prompt_variables=prompt_variables,
|
||||
dynamic_callback_params=self.standard_callback_dynamic_params,
|
||||
)
|
||||
self.messages = messages
|
||||
return model, messages, non_default_params
|
||||
|
@ -541,12 +542,11 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
model
|
||||
): # if model name was changes pre-call, overwrite the initial model call name with the new one
|
||||
self.model_call_details["model"] = model
|
||||
self.model_call_details["litellm_params"]["api_base"] = (
|
||||
self._get_masked_api_base(additional_args.get("api_base", ""))
|
||||
)
|
||||
self.model_call_details["litellm_params"][
|
||||
"api_base"
|
||||
] = self._get_masked_api_base(additional_args.get("api_base", ""))
|
||||
|
||||
def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR0915
|
||||
|
||||
# Log the exact input to the LLM API
|
||||
litellm.error_logs["PRE_CALL"] = locals()
|
||||
try:
|
||||
|
@ -568,19 +568,16 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
self.log_raw_request_response is True
|
||||
or log_raw_request_response is True
|
||||
):
|
||||
|
||||
_litellm_params = self.model_call_details.get("litellm_params", {})
|
||||
_metadata = _litellm_params.get("metadata", {}) or {}
|
||||
try:
|
||||
# [Non-blocking Extra Debug Information in metadata]
|
||||
if turn_off_message_logging is True:
|
||||
|
||||
_metadata["raw_request"] = (
|
||||
"redacted by litellm. \
|
||||
_metadata[
|
||||
"raw_request"
|
||||
] = "redacted by litellm. \
|
||||
'litellm.turn_off_message_logging=True'"
|
||||
)
|
||||
else:
|
||||
|
||||
curl_command = self._get_request_curl_command(
|
||||
api_base=additional_args.get("api_base", ""),
|
||||
headers=additional_args.get("headers", {}),
|
||||
|
@ -590,33 +587,33 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
|
||||
_metadata["raw_request"] = str(curl_command)
|
||||
# split up, so it's easier to parse in the UI
|
||||
self.model_call_details["raw_request_typed_dict"] = (
|
||||
RawRequestTypedDict(
|
||||
raw_request_api_base=str(
|
||||
additional_args.get("api_base") or ""
|
||||
),
|
||||
raw_request_body=self._get_raw_request_body(
|
||||
additional_args.get("complete_input_dict", {})
|
||||
),
|
||||
raw_request_headers=self._get_masked_headers(
|
||||
additional_args.get("headers", {}) or {},
|
||||
ignore_sensitive_headers=True,
|
||||
),
|
||||
error=None,
|
||||
)
|
||||
self.model_call_details[
|
||||
"raw_request_typed_dict"
|
||||
] = RawRequestTypedDict(
|
||||
raw_request_api_base=str(
|
||||
additional_args.get("api_base") or ""
|
||||
),
|
||||
raw_request_body=self._get_raw_request_body(
|
||||
additional_args.get("complete_input_dict", {})
|
||||
),
|
||||
raw_request_headers=self._get_masked_headers(
|
||||
additional_args.get("headers", {}) or {},
|
||||
ignore_sensitive_headers=True,
|
||||
),
|
||||
error=None,
|
||||
)
|
||||
except Exception as e:
|
||||
self.model_call_details["raw_request_typed_dict"] = (
|
||||
RawRequestTypedDict(
|
||||
error=str(e),
|
||||
)
|
||||
self.model_call_details[
|
||||
"raw_request_typed_dict"
|
||||
] = RawRequestTypedDict(
|
||||
error=str(e),
|
||||
)
|
||||
traceback.print_exc()
|
||||
_metadata["raw_request"] = (
|
||||
"Unable to Log \
|
||||
_metadata[
|
||||
"raw_request"
|
||||
] = "Unable to Log \
|
||||
raw request: {}".format(
|
||||
str(e)
|
||||
)
|
||||
str(e)
|
||||
)
|
||||
if self.logger_fn and callable(self.logger_fn):
|
||||
try:
|
||||
|
@ -941,9 +938,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
verbose_logger.debug(
|
||||
f"response_cost_failure_debug_information: {debug_info}"
|
||||
)
|
||||
self.model_call_details["response_cost_failure_debug_information"] = (
|
||||
debug_info
|
||||
)
|
||||
self.model_call_details[
|
||||
"response_cost_failure_debug_information"
|
||||
] = debug_info
|
||||
return None
|
||||
|
||||
try:
|
||||
|
@ -968,9 +965,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
verbose_logger.debug(
|
||||
f"response_cost_failure_debug_information: {debug_info}"
|
||||
)
|
||||
self.model_call_details["response_cost_failure_debug_information"] = (
|
||||
debug_info
|
||||
)
|
||||
self.model_call_details[
|
||||
"response_cost_failure_debug_information"
|
||||
] = debug_info
|
||||
|
||||
return None
|
||||
|
||||
|
@ -995,7 +992,6 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
def should_run_callback(
|
||||
self, callback: litellm.CALLBACK_TYPES, litellm_params: dict, event_hook: str
|
||||
) -> bool:
|
||||
|
||||
if litellm.global_disable_no_log_param:
|
||||
return True
|
||||
|
||||
|
@ -1027,9 +1023,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
end_time = datetime.datetime.now()
|
||||
if self.completion_start_time is None:
|
||||
self.completion_start_time = end_time
|
||||
self.model_call_details["completion_start_time"] = (
|
||||
self.completion_start_time
|
||||
)
|
||||
self.model_call_details[
|
||||
"completion_start_time"
|
||||
] = self.completion_start_time
|
||||
self.model_call_details["log_event_type"] = "successful_api_call"
|
||||
self.model_call_details["end_time"] = end_time
|
||||
self.model_call_details["cache_hit"] = cache_hit
|
||||
|
@ -1083,39 +1079,39 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
"response_cost"
|
||||
]
|
||||
else:
|
||||
self.model_call_details["response_cost"] = (
|
||||
self._response_cost_calculator(result=result)
|
||||
)
|
||||
self.model_call_details[
|
||||
"response_cost"
|
||||
] = self._response_cost_calculator(result=result)
|
||||
## STANDARDIZED LOGGING PAYLOAD
|
||||
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
get_standard_logging_object_payload(
|
||||
kwargs=self.model_call_details,
|
||||
init_response_obj=result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
status="success",
|
||||
standard_built_in_tools_params=self.standard_built_in_tools_params,
|
||||
)
|
||||
self.model_call_details[
|
||||
"standard_logging_object"
|
||||
] = get_standard_logging_object_payload(
|
||||
kwargs=self.model_call_details,
|
||||
init_response_obj=result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
status="success",
|
||||
standard_built_in_tools_params=self.standard_built_in_tools_params,
|
||||
)
|
||||
elif isinstance(result, dict): # pass-through endpoints
|
||||
## STANDARDIZED LOGGING PAYLOAD
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
get_standard_logging_object_payload(
|
||||
kwargs=self.model_call_details,
|
||||
init_response_obj=result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
status="success",
|
||||
standard_built_in_tools_params=self.standard_built_in_tools_params,
|
||||
)
|
||||
self.model_call_details[
|
||||
"standard_logging_object"
|
||||
] = get_standard_logging_object_payload(
|
||||
kwargs=self.model_call_details,
|
||||
init_response_obj=result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
status="success",
|
||||
standard_built_in_tools_params=self.standard_built_in_tools_params,
|
||||
)
|
||||
elif standard_logging_object is not None:
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
standard_logging_object
|
||||
)
|
||||
self.model_call_details[
|
||||
"standard_logging_object"
|
||||
] = standard_logging_object
|
||||
else: # streaming chunks + image gen.
|
||||
self.model_call_details["response_cost"] = None
|
||||
|
||||
|
@ -1154,7 +1150,6 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
standard_logging_object=kwargs.get("standard_logging_object", None),
|
||||
)
|
||||
try:
|
||||
|
||||
## BUILD COMPLETE STREAMED RESPONSE
|
||||
complete_streaming_response: Optional[
|
||||
Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse]
|
||||
|
@ -1172,23 +1167,23 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
verbose_logger.debug(
|
||||
"Logging Details LiteLLM-Success Call streaming complete"
|
||||
)
|
||||
self.model_call_details["complete_streaming_response"] = (
|
||||
complete_streaming_response
|
||||
)
|
||||
self.model_call_details["response_cost"] = (
|
||||
self._response_cost_calculator(result=complete_streaming_response)
|
||||
)
|
||||
self.model_call_details[
|
||||
"complete_streaming_response"
|
||||
] = complete_streaming_response
|
||||
self.model_call_details[
|
||||
"response_cost"
|
||||
] = self._response_cost_calculator(result=complete_streaming_response)
|
||||
## STANDARDIZED LOGGING PAYLOAD
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
get_standard_logging_object_payload(
|
||||
kwargs=self.model_call_details,
|
||||
init_response_obj=complete_streaming_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
status="success",
|
||||
standard_built_in_tools_params=self.standard_built_in_tools_params,
|
||||
)
|
||||
self.model_call_details[
|
||||
"standard_logging_object"
|
||||
] = get_standard_logging_object_payload(
|
||||
kwargs=self.model_call_details,
|
||||
init_response_obj=complete_streaming_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
status="success",
|
||||
standard_built_in_tools_params=self.standard_built_in_tools_params,
|
||||
)
|
||||
callbacks = self.get_combined_callback_list(
|
||||
dynamic_success_callbacks=self.dynamic_success_callbacks,
|
||||
|
@ -1207,7 +1202,6 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
## LOGGING HOOK ##
|
||||
for callback in callbacks:
|
||||
if isinstance(callback, CustomLogger):
|
||||
|
||||
self.model_call_details, result = callback.logging_hook(
|
||||
kwargs=self.model_call_details,
|
||||
result=result,
|
||||
|
@ -1538,10 +1532,10 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
)
|
||||
else:
|
||||
if self.stream and complete_streaming_response:
|
||||
self.model_call_details["complete_response"] = (
|
||||
self.model_call_details.get(
|
||||
"complete_streaming_response", {}
|
||||
)
|
||||
self.model_call_details[
|
||||
"complete_response"
|
||||
] = self.model_call_details.get(
|
||||
"complete_streaming_response", {}
|
||||
)
|
||||
result = self.model_call_details["complete_response"]
|
||||
openMeterLogger.log_success_event(
|
||||
|
@ -1581,10 +1575,10 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
)
|
||||
else:
|
||||
if self.stream and complete_streaming_response:
|
||||
self.model_call_details["complete_response"] = (
|
||||
self.model_call_details.get(
|
||||
"complete_streaming_response", {}
|
||||
)
|
||||
self.model_call_details[
|
||||
"complete_response"
|
||||
] = self.model_call_details.get(
|
||||
"complete_streaming_response", {}
|
||||
)
|
||||
result = self.model_call_details["complete_response"]
|
||||
|
||||
|
@ -1659,7 +1653,6 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
if self.call_type == CallTypes.aretrieve_batch.value and isinstance(
|
||||
result, LiteLLMBatch
|
||||
):
|
||||
|
||||
response_cost, batch_usage, batch_models = await _handle_completed_batch(
|
||||
batch=result, custom_llm_provider=self.custom_llm_provider
|
||||
)
|
||||
|
@ -1692,9 +1685,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
if complete_streaming_response is not None:
|
||||
print_verbose("Async success callbacks: Got a complete streaming response")
|
||||
|
||||
self.model_call_details["async_complete_streaming_response"] = (
|
||||
complete_streaming_response
|
||||
)
|
||||
self.model_call_details[
|
||||
"async_complete_streaming_response"
|
||||
] = complete_streaming_response
|
||||
try:
|
||||
if self.model_call_details.get("cache_hit", False) is True:
|
||||
self.model_call_details["response_cost"] = 0.0
|
||||
|
@ -1704,10 +1697,10 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
model_call_details=self.model_call_details
|
||||
)
|
||||
# base_model defaults to None if not set on model_info
|
||||
self.model_call_details["response_cost"] = (
|
||||
self._response_cost_calculator(
|
||||
result=complete_streaming_response
|
||||
)
|
||||
self.model_call_details[
|
||||
"response_cost"
|
||||
] = self._response_cost_calculator(
|
||||
result=complete_streaming_response
|
||||
)
|
||||
|
||||
verbose_logger.debug(
|
||||
|
@ -1720,16 +1713,16 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
self.model_call_details["response_cost"] = None
|
||||
|
||||
## STANDARDIZED LOGGING PAYLOAD
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
get_standard_logging_object_payload(
|
||||
kwargs=self.model_call_details,
|
||||
init_response_obj=complete_streaming_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
status="success",
|
||||
standard_built_in_tools_params=self.standard_built_in_tools_params,
|
||||
)
|
||||
self.model_call_details[
|
||||
"standard_logging_object"
|
||||
] = get_standard_logging_object_payload(
|
||||
kwargs=self.model_call_details,
|
||||
init_response_obj=complete_streaming_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
status="success",
|
||||
standard_built_in_tools_params=self.standard_built_in_tools_params,
|
||||
)
|
||||
callbacks = self.get_combined_callback_list(
|
||||
dynamic_success_callbacks=self.dynamic_async_success_callbacks,
|
||||
|
@ -1935,18 +1928,18 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
|
||||
## STANDARDIZED LOGGING PAYLOAD
|
||||
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
get_standard_logging_object_payload(
|
||||
kwargs=self.model_call_details,
|
||||
init_response_obj={},
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
status="failure",
|
||||
error_str=str(exception),
|
||||
original_exception=exception,
|
||||
standard_built_in_tools_params=self.standard_built_in_tools_params,
|
||||
)
|
||||
self.model_call_details[
|
||||
"standard_logging_object"
|
||||
] = get_standard_logging_object_payload(
|
||||
kwargs=self.model_call_details,
|
||||
init_response_obj={},
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=self,
|
||||
status="failure",
|
||||
error_str=str(exception),
|
||||
original_exception=exception,
|
||||
standard_built_in_tools_params=self.standard_built_in_tools_params,
|
||||
)
|
||||
return start_time, end_time
|
||||
|
||||
|
@ -2084,7 +2077,6 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
)
|
||||
is not True
|
||||
): # custom logger class
|
||||
|
||||
callback.log_failure_event(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
|
@ -2713,9 +2705,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
|
|||
endpoint=arize_config.endpoint,
|
||||
)
|
||||
|
||||
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
|
||||
f"space_key={arize_config.space_key},api_key={arize_config.api_key}"
|
||||
)
|
||||
os.environ[
|
||||
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
|
||||
] = f"space_key={arize_config.space_key},api_key={arize_config.api_key}"
|
||||
for callback in _in_memory_loggers:
|
||||
if (
|
||||
isinstance(callback, ArizeLogger)
|
||||
|
@ -2739,9 +2731,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
|
|||
|
||||
# auth can be disabled on local deployments of arize phoenix
|
||||
if arize_phoenix_config.otlp_auth_headers is not None:
|
||||
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
|
||||
arize_phoenix_config.otlp_auth_headers
|
||||
)
|
||||
os.environ[
|
||||
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
|
||||
] = arize_phoenix_config.otlp_auth_headers
|
||||
|
||||
for callback in _in_memory_loggers:
|
||||
if (
|
||||
|
@ -2832,9 +2824,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
|
|||
exporter="otlp_http",
|
||||
endpoint="https://langtrace.ai/api/trace",
|
||||
)
|
||||
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
|
||||
f"api_key={os.getenv('LANGTRACE_API_KEY')}"
|
||||
)
|
||||
os.environ[
|
||||
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
|
||||
] = f"api_key={os.getenv('LANGTRACE_API_KEY')}"
|
||||
for callback in _in_memory_loggers:
|
||||
if (
|
||||
isinstance(callback, OpenTelemetry)
|
||||
|
@ -3223,7 +3215,6 @@ class StandardLoggingPayloadSetup:
|
|||
custom_llm_provider: Optional[str],
|
||||
init_response_obj: Union[Any, BaseModel, dict],
|
||||
) -> StandardLoggingModelInformation:
|
||||
|
||||
model_cost_name = _select_model_name_for_cost_calc(
|
||||
model=None,
|
||||
completion_response=init_response_obj, # type: ignore
|
||||
|
@ -3286,7 +3277,6 @@ class StandardLoggingPayloadSetup:
|
|||
def get_additional_headers(
|
||||
additiona_headers: Optional[dict],
|
||||
) -> Optional[StandardLoggingAdditionalHeaders]:
|
||||
|
||||
if additiona_headers is None:
|
||||
return None
|
||||
|
||||
|
@ -3322,10 +3312,10 @@ class StandardLoggingPayloadSetup:
|
|||
for key in StandardLoggingHiddenParams.__annotations__.keys():
|
||||
if key in hidden_params:
|
||||
if key == "additional_headers":
|
||||
clean_hidden_params["additional_headers"] = (
|
||||
StandardLoggingPayloadSetup.get_additional_headers(
|
||||
hidden_params[key]
|
||||
)
|
||||
clean_hidden_params[
|
||||
"additional_headers"
|
||||
] = StandardLoggingPayloadSetup.get_additional_headers(
|
||||
hidden_params[key]
|
||||
)
|
||||
else:
|
||||
clean_hidden_params[key] = hidden_params[key] # type: ignore
|
||||
|
@ -3463,12 +3453,14 @@ def get_standard_logging_object_payload(
|
|||
)
|
||||
|
||||
# cleanup timestamps
|
||||
start_time_float, end_time_float, completion_start_time_float = (
|
||||
StandardLoggingPayloadSetup.cleanup_timestamps(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
completion_start_time=completion_start_time,
|
||||
)
|
||||
(
|
||||
start_time_float,
|
||||
end_time_float,
|
||||
completion_start_time_float,
|
||||
) = StandardLoggingPayloadSetup.cleanup_timestamps(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
completion_start_time=completion_start_time,
|
||||
)
|
||||
response_time = StandardLoggingPayloadSetup.get_response_time(
|
||||
start_time_float=start_time_float,
|
||||
|
@ -3495,7 +3487,6 @@ def get_standard_logging_object_payload(
|
|||
|
||||
saved_cache_cost: float = 0.0
|
||||
if cache_hit is True:
|
||||
|
||||
id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id
|
||||
saved_cache_cost = (
|
||||
logging_obj._response_cost_calculator(
|
||||
|
@ -3658,9 +3649,9 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
|
|||
):
|
||||
for k, v in metadata["user_api_key_metadata"].items():
|
||||
if k == "logging": # prevent logging user logging keys
|
||||
cleaned_user_api_key_metadata[k] = (
|
||||
"scrubbed_by_litellm_for_sensitive_keys"
|
||||
)
|
||||
cleaned_user_api_key_metadata[
|
||||
k
|
||||
] = "scrubbed_by_litellm_for_sensitive_keys"
|
||||
else:
|
||||
cleaned_user_api_key_metadata[k] = v
|
||||
|
||||
|
|
|
@ -258,14 +258,12 @@ def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[s
|
|||
|
||||
|
||||
class LiteLLMResponseObjectHandler:
|
||||
|
||||
@staticmethod
|
||||
def convert_to_image_response(
|
||||
response_object: dict,
|
||||
model_response_object: Optional[ImageResponse] = None,
|
||||
hidden_params: Optional[dict] = None,
|
||||
) -> ImageResponse:
|
||||
|
||||
response_object.update({"hidden_params": hidden_params})
|
||||
|
||||
if model_response_object is None:
|
||||
|
@ -481,9 +479,9 @@ def convert_to_model_response_object( # noqa: PLR0915
|
|||
provider_specific_fields["thinking_blocks"] = thinking_blocks
|
||||
|
||||
if reasoning_content:
|
||||
provider_specific_fields["reasoning_content"] = (
|
||||
reasoning_content
|
||||
)
|
||||
provider_specific_fields[
|
||||
"reasoning_content"
|
||||
] = reasoning_content
|
||||
|
||||
message = Message(
|
||||
content=content,
|
||||
|
|
|
@ -17,7 +17,6 @@ from litellm.types.rerank import RerankRequest
|
|||
|
||||
|
||||
class ModelParamHelper:
|
||||
|
||||
@staticmethod
|
||||
def get_standard_logging_model_parameters(
|
||||
model_parameters: dict,
|
||||
|
|
|
@ -257,7 +257,6 @@ def _insert_assistant_continue_message(
|
|||
and message.get("role") == "user" # Current is user
|
||||
and messages[i + 1].get("role") == "user"
|
||||
): # Next is user
|
||||
|
||||
# Insert assistant message
|
||||
continue_message = (
|
||||
assistant_continue_message or DEFAULT_ASSISTANT_CONTINUE_MESSAGE
|
||||
|
|
|
@ -1042,10 +1042,10 @@ def convert_to_gemini_tool_call_invoke(
|
|||
if tool_calls is not None:
|
||||
for tool in tool_calls:
|
||||
if "function" in tool:
|
||||
gemini_function_call: Optional[VertexFunctionCall] = (
|
||||
_gemini_tool_call_invoke_helper(
|
||||
function_call_params=tool["function"]
|
||||
)
|
||||
gemini_function_call: Optional[
|
||||
VertexFunctionCall
|
||||
] = _gemini_tool_call_invoke_helper(
|
||||
function_call_params=tool["function"]
|
||||
)
|
||||
if gemini_function_call is not None:
|
||||
_parts_list.append(
|
||||
|
@ -1432,9 +1432,9 @@ def anthropic_messages_pt( # noqa: PLR0915
|
|||
)
|
||||
|
||||
if "cache_control" in _content_element:
|
||||
_anthropic_content_element["cache_control"] = (
|
||||
_content_element["cache_control"]
|
||||
)
|
||||
_anthropic_content_element[
|
||||
"cache_control"
|
||||
] = _content_element["cache_control"]
|
||||
user_content.append(_anthropic_content_element)
|
||||
elif m.get("type", "") == "text":
|
||||
m = cast(ChatCompletionTextObject, m)
|
||||
|
@ -1466,9 +1466,9 @@ def anthropic_messages_pt( # noqa: PLR0915
|
|||
)
|
||||
|
||||
if "cache_control" in _content_element:
|
||||
_anthropic_content_text_element["cache_control"] = (
|
||||
_content_element["cache_control"]
|
||||
)
|
||||
_anthropic_content_text_element[
|
||||
"cache_control"
|
||||
] = _content_element["cache_control"]
|
||||
|
||||
user_content.append(_anthropic_content_text_element)
|
||||
|
||||
|
@ -1533,7 +1533,6 @@ def anthropic_messages_pt( # noqa: PLR0915
|
|||
"content"
|
||||
] # don't pass empty text blocks. anthropic api raises errors.
|
||||
):
|
||||
|
||||
_anthropic_text_content_element = AnthropicMessagesTextParam(
|
||||
type="text",
|
||||
text=assistant_content_block["content"],
|
||||
|
@ -1569,7 +1568,6 @@ def anthropic_messages_pt( # noqa: PLR0915
|
|||
msg_i += 1
|
||||
|
||||
if assistant_content:
|
||||
|
||||
new_messages.append({"role": "assistant", "content": assistant_content})
|
||||
|
||||
if msg_i == init_msg_i: # prevent infinite loops
|
||||
|
@ -2245,7 +2243,6 @@ class BedrockImageProcessor:
|
|||
@staticmethod
|
||||
async def get_image_details_async(image_url) -> Tuple[str, str]:
|
||||
try:
|
||||
|
||||
client = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.PromptFactory,
|
||||
params={"concurrent_limit": 1},
|
||||
|
@ -2612,7 +2609,6 @@ def get_user_message_block_or_continue_message(
|
|||
for item in modified_content_block:
|
||||
# Check if the list is empty
|
||||
if item["type"] == "text":
|
||||
|
||||
if not item["text"].strip():
|
||||
# Replace empty text with continue message
|
||||
_user_continue_message = ChatCompletionUserMessage(
|
||||
|
@ -3207,7 +3203,6 @@ def _bedrock_converse_messages_pt( # noqa: PLR0915
|
|||
assistant_content: List[BedrockContentBlock] = []
|
||||
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
|
||||
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
|
||||
|
||||
assistant_message_block = get_assistant_message_block_or_continue_message(
|
||||
message=messages[msg_i],
|
||||
assistant_continue_message=assistant_continue_message,
|
||||
|
@ -3410,7 +3405,6 @@ def response_schema_prompt(model: str, response_schema: dict) -> str:
|
|||
{"role": "user", "content": "{}".format(response_schema)}
|
||||
]
|
||||
if f"{model}/response_schema_prompt" in litellm.custom_prompt_dict:
|
||||
|
||||
custom_prompt_details = litellm.custom_prompt_dict[
|
||||
f"{model}/response_schema_prompt"
|
||||
] # allow user to define custom response schema prompt by model
|
||||
|
|
|
@ -122,7 +122,6 @@ class RealTimeStreaming:
|
|||
pass
|
||||
|
||||
async def bidirectional_forward(self):
|
||||
|
||||
forward_task = asyncio.create_task(self.backend_to_client_send_messages())
|
||||
try:
|
||||
await self.client_ack_messages()
|
||||
|
|
|
@ -135,9 +135,9 @@ def _get_turn_off_message_logging_from_dynamic_params(
|
|||
|
||||
handles boolean and string values of `turn_off_message_logging`
|
||||
"""
|
||||
standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
|
||||
model_call_details.get("standard_callback_dynamic_params", None)
|
||||
)
|
||||
standard_callback_dynamic_params: Optional[
|
||||
StandardCallbackDynamicParams
|
||||
] = model_call_details.get("standard_callback_dynamic_params", None)
|
||||
if standard_callback_dynamic_params:
|
||||
_turn_off_message_logging = standard_callback_dynamic_params.get(
|
||||
"turn_off_message_logging"
|
||||
|
|
|
@@ -1,4 +1,5 @@
from typing import Any, Dict, Optional, Set

from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH


@@ -40,7 +41,10 @@ class SensitiveDataMasker:
return result

def mask_dict(
self, data: Dict[str, Any], depth: int = 0, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH
self,
data: Dict[str, Any],
depth: int = 0,
max_depth: int = DEFAULT_MAX_RECURSE_DEPTH,
) -> Dict[str, Any]:
if depth >= max_depth:
return data
@ -104,7 +104,6 @@ class ChunkProcessor:
|
|||
def get_combined_tool_content(
|
||||
self, tool_call_chunks: List[Dict[str, Any]]
|
||||
) -> List[ChatCompletionMessageToolCall]:
|
||||
|
||||
argument_list: List[str] = []
|
||||
delta = tool_call_chunks[0]["choices"][0]["delta"]
|
||||
id = None
|
||||
|
|
|
@ -84,9 +84,9 @@ class CustomStreamWrapper:
|
|||
|
||||
self.system_fingerprint: Optional[str] = None
|
||||
self.received_finish_reason: Optional[str] = None
|
||||
self.intermittent_finish_reason: Optional[str] = (
|
||||
None # finish reasons that show up mid-stream
|
||||
)
|
||||
self.intermittent_finish_reason: Optional[
|
||||
str
|
||||
] = None # finish reasons that show up mid-stream
|
||||
self.special_tokens = [
|
||||
"<|assistant|>",
|
||||
"<|system|>",
|
||||
|
@ -814,7 +814,6 @@ class CustomStreamWrapper:
|
|||
model_response: ModelResponseStream,
|
||||
response_obj: Dict[str, Any],
|
||||
):
|
||||
|
||||
print_verbose(
|
||||
f"completion_obj: {completion_obj}, model_response.choices[0]: {model_response.choices[0]}, response_obj: {response_obj}"
|
||||
)
|
||||
|
@ -1008,7 +1007,6 @@ class CustomStreamWrapper:
|
|||
self.custom_llm_provider
|
||||
and self.custom_llm_provider in litellm._custom_providers
|
||||
):
|
||||
|
||||
if self.received_finish_reason is not None:
|
||||
if "provider_specific_fields" not in chunk:
|
||||
raise StopIteration
|
||||
|
@ -1379,9 +1377,9 @@ class CustomStreamWrapper:
|
|||
_json_delta = delta.model_dump()
|
||||
print_verbose(f"_json_delta: {_json_delta}")
|
||||
if "role" not in _json_delta or _json_delta["role"] is None:
|
||||
_json_delta["role"] = (
|
||||
"assistant" # mistral's api returns role as None
|
||||
)
|
||||
_json_delta[
|
||||
"role"
|
||||
] = "assistant" # mistral's api returns role as None
|
||||
if "tool_calls" in _json_delta and isinstance(
|
||||
_json_delta["tool_calls"], list
|
||||
):
|
||||
|
@ -1758,9 +1756,9 @@ class CustomStreamWrapper:
|
|||
chunk = next(self.completion_stream)
|
||||
if chunk is not None and chunk != b"":
|
||||
print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}")
|
||||
processed_chunk: Optional[ModelResponseStream] = (
|
||||
self.chunk_creator(chunk=chunk)
|
||||
)
|
||||
processed_chunk: Optional[
|
||||
ModelResponseStream
|
||||
] = self.chunk_creator(chunk=chunk)
|
||||
print_verbose(
|
||||
f"PROCESSED CHUNK POST CHUNK CREATOR: {processed_chunk}"
|
||||
)
|
||||
|
|
|
@ -290,7 +290,6 @@ class AnthropicChatCompletion(BaseLLM):
|
|||
headers={},
|
||||
client=None,
|
||||
):
|
||||
|
||||
optional_params = copy.deepcopy(optional_params)
|
||||
stream = optional_params.pop("stream", None)
|
||||
json_mode: bool = optional_params.pop("json_mode", False)
|
||||
|
@ -491,7 +490,6 @@ class ModelResponseIterator:
|
|||
def _handle_usage(
|
||||
self, anthropic_usage_chunk: Union[dict, UsageDelta]
|
||||
) -> AnthropicChatCompletionUsageBlock:
|
||||
|
||||
usage_block = AnthropicChatCompletionUsageBlock(
|
||||
prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0),
|
||||
completion_tokens=anthropic_usage_chunk.get("output_tokens", 0),
|
||||
|
@ -515,7 +513,9 @@ class ModelResponseIterator:
|
|||
|
||||
return usage_block
|
||||
|
||||
def _content_block_delta_helper(self, chunk: dict) -> Tuple[
|
||||
def _content_block_delta_helper(
|
||||
self, chunk: dict
|
||||
) -> Tuple[
|
||||
str,
|
||||
Optional[ChatCompletionToolCallChunk],
|
||||
List[ChatCompletionThinkingBlock],
|
||||
|
@ -592,9 +592,12 @@ class ModelResponseIterator:
|
|||
Anthropic content chunk
|
||||
chunk = {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Hello'}}
|
||||
"""
|
||||
text, tool_use, thinking_blocks, provider_specific_fields = (
|
||||
self._content_block_delta_helper(chunk=chunk)
|
||||
)
|
||||
(
|
||||
text,
|
||||
tool_use,
|
||||
thinking_blocks,
|
||||
provider_specific_fields,
|
||||
) = self._content_block_delta_helper(chunk=chunk)
|
||||
if thinking_blocks:
|
||||
reasoning_content = self._handle_reasoning_content(
|
||||
thinking_blocks=thinking_blocks
|
||||
|
@ -620,7 +623,6 @@ class ModelResponseIterator:
|
|||
"index": self.tool_index,
|
||||
}
|
||||
elif type_chunk == "content_block_stop":
|
||||
|
||||
ContentBlockStop(**chunk) # type: ignore
|
||||
# check if tool call content block
|
||||
is_empty = self.check_empty_tool_call_args()
|
||||
|
|
|
@ -49,9 +49,9 @@ class AnthropicConfig(BaseConfig):
|
|||
to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
|
||||
"""
|
||||
|
||||
max_tokens: Optional[int] = (
|
||||
4096 # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default)
|
||||
)
|
||||
max_tokens: Optional[
|
||||
int
|
||||
] = 4096 # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default)
|
||||
stop_sequences: Optional[list] = None
|
||||
temperature: Optional[int] = None
|
||||
top_p: Optional[int] = None
|
||||
|
@ -104,7 +104,6 @@ class AnthropicConfig(BaseConfig):
|
|||
def get_json_schema_from_pydantic_object(
|
||||
self, response_format: Union[Any, Dict, None]
|
||||
) -> Optional[dict]:
|
||||
|
||||
return type_to_response_format_param(
|
||||
response_format, ref_template="/$defs/{model}"
|
||||
) # Relevant issue: https://github.com/BerriAI/litellm/issues/7755
|
||||
|
@ -125,7 +124,6 @@ class AnthropicConfig(BaseConfig):
|
|||
is_vertex_request: bool = False,
|
||||
user_anthropic_beta_headers: Optional[List[str]] = None,
|
||||
) -> dict:
|
||||
|
||||
betas = set()
|
||||
if prompt_caching_set:
|
||||
betas.add("prompt-caching-2024-07-31")
|
||||
|
@ -300,7 +298,6 @@ class AnthropicConfig(BaseConfig):
|
|||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
|
||||
is_thinking_enabled = self.is_thinking_enabled(
|
||||
non_default_params=non_default_params
|
||||
)
|
||||
|
@ -321,11 +318,11 @@ class AnthropicConfig(BaseConfig):
|
|||
optional_params=optional_params, tools=tool_value
|
||||
)
|
||||
if param == "tool_choice" or param == "parallel_tool_calls":
|
||||
_tool_choice: Optional[AnthropicMessagesToolChoice] = (
|
||||
self._map_tool_choice(
|
||||
tool_choice=non_default_params.get("tool_choice"),
|
||||
parallel_tool_use=non_default_params.get("parallel_tool_calls"),
|
||||
)
|
||||
_tool_choice: Optional[
|
||||
AnthropicMessagesToolChoice
|
||||
] = self._map_tool_choice(
|
||||
tool_choice=non_default_params.get("tool_choice"),
|
||||
parallel_tool_use=non_default_params.get("parallel_tool_calls"),
|
||||
)
|
||||
|
||||
if _tool_choice is not None:
|
||||
|
@ -341,7 +338,6 @@ class AnthropicConfig(BaseConfig):
|
|||
if param == "top_p":
|
||||
optional_params["top_p"] = value
|
||||
if param == "response_format" and isinstance(value, dict):
|
||||
|
||||
ignore_response_format_types = ["text"]
|
||||
if value["type"] in ignore_response_format_types: # value is a no-op
|
||||
continue
|
||||
|
@ -470,9 +466,9 @@ class AnthropicConfig(BaseConfig):
|
|||
text=system_message_block["content"],
|
||||
)
|
||||
if "cache_control" in system_message_block:
|
||||
anthropic_system_message_content["cache_control"] = (
|
||||
system_message_block["cache_control"]
|
||||
)
|
||||
anthropic_system_message_content[
|
||||
"cache_control"
|
||||
] = system_message_block["cache_control"]
|
||||
anthropic_system_message_list.append(
|
||||
anthropic_system_message_content
|
||||
)
|
||||
|
@ -486,9 +482,9 @@ class AnthropicConfig(BaseConfig):
|
|||
)
|
||||
)
|
||||
if "cache_control" in _content:
|
||||
anthropic_system_message_content["cache_control"] = (
|
||||
_content["cache_control"]
|
||||
)
|
||||
anthropic_system_message_content[
|
||||
"cache_control"
|
||||
] = _content["cache_control"]
|
||||
|
||||
anthropic_system_message_list.append(
|
||||
anthropic_system_message_content
|
||||
|
@ -597,7 +593,9 @@ class AnthropicConfig(BaseConfig):
|
|||
)
|
||||
return _message
|
||||
|
||||
def extract_response_content(self, completion_response: dict) -> Tuple[
|
||||
def extract_response_content(
|
||||
self, completion_response: dict
|
||||
) -> Tuple[
|
||||
str,
|
||||
Optional[List[Any]],
|
||||
Optional[List[ChatCompletionThinkingBlock]],
|
||||
|
@ -693,9 +691,13 @@ class AnthropicConfig(BaseConfig):
|
|||
reasoning_content: Optional[str] = None
|
||||
tool_calls: List[ChatCompletionToolCallChunk] = []
|
||||
|
||||
text_content, citations, thinking_blocks, reasoning_content, tool_calls = (
|
||||
self.extract_response_content(completion_response=completion_response)
|
||||
)
|
||||
(
|
||||
text_content,
|
||||
citations,
|
||||
thinking_blocks,
|
||||
reasoning_content,
|
||||
tool_calls,
|
||||
) = self.extract_response_content(completion_response=completion_response)
|
||||
|
||||
_message = litellm.Message(
|
||||
tool_calls=tool_calls,
|
||||
|
|
|
@ -54,9 +54,9 @@ class AnthropicTextConfig(BaseConfig):
|
|||
to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
|
||||
"""
|
||||
|
||||
max_tokens_to_sample: Optional[int] = (
|
||||
litellm.max_tokens
|
||||
) # anthropic requires a default
|
||||
max_tokens_to_sample: Optional[
|
||||
int
|
||||
] = litellm.max_tokens # anthropic requires a default
|
||||
stop_sequences: Optional[list] = None
|
||||
temperature: Optional[int] = None
|
||||
top_p: Optional[int] = None
|
||||
|
|
|
@ -25,7 +25,6 @@ from litellm.utils import ProviderConfigManager, client
|
|||
|
||||
|
||||
class AnthropicMessagesHandler:
|
||||
|
||||
@staticmethod
|
||||
async def _handle_anthropic_streaming(
|
||||
response: httpx.Response,
|
||||
|
@ -74,19 +73,22 @@ async def anthropic_messages(
|
|||
"""
|
||||
# Use provided client or create a new one
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = (
|
||||
litellm.get_llm_provider(
|
||||
model=model,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
api_base=optional_params.api_base,
|
||||
api_key=optional_params.api_key,
|
||||
)
|
||||
(
|
||||
model,
|
||||
_custom_llm_provider,
|
||||
dynamic_api_key,
|
||||
dynamic_api_base,
|
||||
) = litellm.get_llm_provider(
|
||||
model=model,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
api_base=optional_params.api_base,
|
||||
api_key=optional_params.api_key,
|
||||
)
|
||||
anthropic_messages_provider_config: Optional[BaseAnthropicMessagesConfig] = (
|
||||
ProviderConfigManager.get_provider_anthropic_messages_config(
|
||||
model=model,
|
||||
provider=litellm.LlmProviders(_custom_llm_provider),
|
||||
)
|
||||
anthropic_messages_provider_config: Optional[
|
||||
BaseAnthropicMessagesConfig
|
||||
] = ProviderConfigManager.get_provider_anthropic_messages_config(
|
||||
model=model,
|
||||
provider=litellm.LlmProviders(_custom_llm_provider),
|
||||
)
|
||||
if anthropic_messages_provider_config is None:
|
||||
raise ValueError(
|
||||
|
|
|
@ -654,7 +654,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
|
|||
) -> EmbeddingResponse:
|
||||
response = None
|
||||
try:
|
||||
|
||||
openai_aclient = self.get_azure_openai_client(
|
||||
api_version=api_version,
|
||||
api_base=api_base,
|
||||
|
@ -835,7 +834,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
|
|||
"2023-10-01-preview",
|
||||
]
|
||||
): # CREATE + POLL for azure dall-e-2 calls
|
||||
|
||||
api_base = modify_url(
|
||||
original_url=api_base, new_path="/openai/images/generations:submit"
|
||||
)
|
||||
|
@ -867,7 +865,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
|
|||
)
|
||||
while response.json()["status"] not in ["succeeded", "failed"]:
|
||||
if time.time() - start_time > timeout_secs:
|
||||
|
||||
raise AzureOpenAIError(
|
||||
status_code=408, message="Operation polling timed out."
|
||||
)
|
||||
|
@ -935,7 +932,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
|
|||
"2023-10-01-preview",
|
||||
]
|
||||
): # CREATE + POLL for azure dall-e-2 calls
|
||||
|
||||
api_base = modify_url(
|
||||
original_url=api_base, new_path="/openai/images/generations:submit"
|
||||
)
|
||||
|
@ -1199,7 +1195,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
|
|||
client=None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> HttpxBinaryResponseContent:
|
||||
|
||||
max_retries = optional_params.pop("max_retries", 2)
|
||||
|
||||
if aspeech is not None and aspeech is True:
|
||||
|
@ -1253,7 +1248,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
|
|||
client=None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> HttpxBinaryResponseContent:
|
||||
|
||||
azure_client: AsyncAzureOpenAI = self.get_azure_openai_client(
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
|
|
|
@ -50,15 +50,15 @@ class AzureBatchesAPI(BaseAzureLLM):
|
|||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
|
||||
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||
self.get_azure_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
litellm_params=litellm_params or {},
|
||||
)
|
||||
azure_client: Optional[
|
||||
Union[AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = self.get_azure_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
litellm_params=litellm_params or {},
|
||||
)
|
||||
if azure_client is None:
|
||||
raise ValueError(
|
||||
|
@ -96,15 +96,15 @@ class AzureBatchesAPI(BaseAzureLLM):
|
|||
client: Optional[AzureOpenAI] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||
self.get_azure_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
litellm_params=litellm_params or {},
|
||||
)
|
||||
azure_client: Optional[
|
||||
Union[AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = self.get_azure_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
litellm_params=litellm_params or {},
|
||||
)
|
||||
if azure_client is None:
|
||||
raise ValueError(
|
||||
|
@ -144,15 +144,15 @@ class AzureBatchesAPI(BaseAzureLLM):
|
|||
client: Optional[AzureOpenAI] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||
self.get_azure_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
litellm_params=litellm_params or {},
|
||||
)
|
||||
azure_client: Optional[
|
||||
Union[AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = self.get_azure_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
litellm_params=litellm_params or {},
|
||||
)
|
||||
if azure_client is None:
|
||||
raise ValueError(
|
||||
|
@ -183,15 +183,15 @@ class AzureBatchesAPI(BaseAzureLLM):
|
|||
client: Optional[AzureOpenAI] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||
self.get_azure_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
litellm_params=litellm_params or {},
|
||||
)
|
||||
azure_client: Optional[
|
||||
Union[AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = self.get_azure_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
litellm_params=litellm_params or {},
|
||||
)
|
||||
if azure_client is None:
|
||||
raise ValueError(
|
||||
|
|
|
@ -306,7 +306,6 @@ class BaseAzureLLM(BaseOpenAILLM):
|
|||
api_version: Optional[str],
|
||||
is_async: bool,
|
||||
) -> dict:
|
||||
|
||||
azure_ad_token_provider: Optional[Callable[[], str]] = None
|
||||
# If we have api_key, then we have higher priority
|
||||
azure_ad_token = litellm_params.get("azure_ad_token")
|
||||
|
|
|
@ -46,16 +46,15 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
|
|||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
|
||||
|
||||
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||
self.get_azure_openai_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
openai_client: Optional[
|
||||
Union[AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = self.get_azure_openai_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
|
@ -95,15 +94,15 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
|
|||
) -> Union[
|
||||
HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
|
||||
]:
|
||||
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||
self.get_azure_openai_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
openai_client: Optional[
|
||||
Union[AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = self.get_azure_openai_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
|
@ -145,15 +144,15 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
|
|||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||
self.get_azure_openai_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
openai_client: Optional[
|
||||
Union[AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = self.get_azure_openai_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
|
@ -197,15 +196,15 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
|
|||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||
self.get_azure_openai_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
openai_client: Optional[
|
||||
Union[AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = self.get_azure_openai_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
|
@ -251,15 +250,15 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
|
|||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
|
||||
self.get_azure_openai_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
openai_client: Optional[
|
||||
Union[AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = self.get_azure_openai_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
|
|
|
@ -25,14 +25,7 @@ class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI, BaseAzureLLM):
|
|||
_is_async: bool = False,
|
||||
api_version: Optional[str] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> Optional[
|
||||
Union[
|
||||
OpenAI,
|
||||
AsyncOpenAI,
|
||||
AzureOpenAI,
|
||||
AsyncAzureOpenAI,
|
||||
]
|
||||
]:
|
||||
) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI,]]:
|
||||
# Override to use Azure-specific client initialization
|
||||
if isinstance(client, OpenAI) or isinstance(client, AsyncOpenAI):
|
||||
client = None
|
||||
|
|
|
@ -145,7 +145,6 @@ class AzureAIStudioConfig(OpenAIConfig):
|
|||
2. If message contains an image or audio, send as is (user-intended)
|
||||
"""
|
||||
for message in messages:
|
||||
|
||||
# Do nothing if the message contains an image or audio
|
||||
if _audio_or_image_in_message_content(message):
|
||||
continue
|
||||
|
|
|
@ -22,7 +22,6 @@ class AzureAICohereConfig:
|
|||
pass
|
||||
|
||||
def _map_azure_model_group(self, model: str) -> str:
|
||||
|
||||
if model == "offer-cohere-embed-multili-paygo":
|
||||
return "Cohere-embed-v3-multilingual"
|
||||
elif model == "offer-cohere-embed-english-paygo":
|
||||
|
|
|
@@ -17,7 +17,6 @@ from .cohere_transformation import AzureAICohereConfig


class AzureAIEmbedding(OpenAIChatCompletion):

    def _process_response(
        self,
        image_embedding_responses: Optional[List],

@@ -145,7 +144,6 @@ class AzureAIEmbedding(OpenAIChatCompletion):
        api_base: Optional[str] = None,
        client=None,
    ) -> EmbeddingResponse:

        (
            image_embeddings_request,
            v1_embeddings_request,

@@ -17,6 +17,7 @@ class AzureAIRerankConfig(CohereRerankConfig):
    """
    Azure AI Rerank - Follows the same Spec as Cohere Rerank
    """

    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        if api_base is None:
            raise ValueError(

@@ -9,7 +9,6 @@ from litellm.types.utils import ModelResponse, TextCompletionResponse


class BaseLLM:

    _client_session: Optional[httpx.Client] = None

    def process_response(

@@ -218,7 +218,6 @@ class BaseConfig(ABC):
                json_schema = value["json_schema"]["schema"]

                if json_schema and not is_response_format_supported:

                    _tool_choice = ChatCompletionToolChoiceObjectParam(
                        type="function",
                        function=ChatCompletionToolChoiceFunctionParam(

@@ -58,7 +58,6 @@ class BaseResponsesAPIConfig(ABC):
        model: str,
        drop_params: bool,
    ) -> Dict:

        pass

    @abstractmethod

@@ -81,7 +81,6 @@ def make_sync_call(


class BedrockConverseLLM(BaseAWSLLM):

    def __init__(self) -> None:
        super().__init__()

@@ -114,7 +113,6 @@ class BedrockConverseLLM(BaseAWSLLM):
        fake_stream: bool = False,
        json_mode: Optional[bool] = False,
    ) -> CustomStreamWrapper:

        request_data = await litellm.AmazonConverseConfig()._async_transform_request(
            model=model,
            messages=messages,

@@ -179,7 +177,6 @@
        headers: dict = {},
        client: Optional[AsyncHTTPHandler] = None,
    ) -> Union[ModelResponse, CustomStreamWrapper]:

        request_data = await litellm.AmazonConverseConfig()._async_transform_request(
            model=model,
            messages=messages,

@@ -265,7 +262,6 @@
        extra_headers: Optional[dict] = None,
        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
    ):

        ## SETUP ##
        stream = optional_params.pop("stream", None)
        unencoded_model_id = optional_params.pop("model_id", None)

@@ -301,9 +297,9 @@
        aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
        optional_params.pop("aws_region_name", None)

        litellm_params["aws_region_name"] = (
            aws_region_name  # [DO NOT DELETE] important for async calls
        )
        litellm_params[
            "aws_region_name"
        ] = aws_region_name  # [DO NOT DELETE] important for async calls

        credentials: Credentials = self.get_credentials(
            aws_access_key_id=aws_access_key_id,

@@ -223,7 +223,6 @@ class AmazonConverseConfig(BaseConfig):
        )
        for param, value in non_default_params.items():
            if param == "response_format" and isinstance(value, dict):

                ignore_response_format_types = ["text"]
                if value["type"] in ignore_response_format_types:  # value is a no-op
                    continue

@@ -715,9 +714,9 @@
        chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"}
        content_str = ""
        tools: List[ChatCompletionToolCallChunk] = []
        reasoningContentBlocks: Optional[List[BedrockConverseReasoningContentBlock]] = (
            None
        )
        reasoningContentBlocks: Optional[
            List[BedrockConverseReasoningContentBlock]
        ] = None

        if message is not None:
            for idx, content in enumerate(message["content"]):

@@ -727,7 +726,6 @@
                if "text" in content:
                    content_str += content["text"]
                if "toolUse" in content:

                    ## check tool name was formatted by litellm
                    _response_tool_name = content["toolUse"]["name"]
                    response_tool_name = get_bedrock_tool_name(

@@ -754,12 +752,12 @@
            chat_completion_message["provider_specific_fields"] = {
                "reasoningContentBlocks": reasoningContentBlocks,
            }
            chat_completion_message["reasoning_content"] = (
                self._transform_reasoning_content(reasoningContentBlocks)
            )
            chat_completion_message["thinking_blocks"] = (
                self._transform_thinking_blocks(reasoningContentBlocks)
            )
            chat_completion_message[
                "reasoning_content"
            ] = self._transform_reasoning_content(reasoningContentBlocks)
            chat_completion_message[
                "thinking_blocks"
            ] = self._transform_thinking_blocks(reasoningContentBlocks)
        chat_completion_message["content"] = content_str
        if json_mode is True and tools is not None and len(tools) == 1:
            # to support 'json_schema' logic on bedrock models

@@ -496,9 +496,9 @@ class BedrockLLM(BaseAWSLLM):
                    content=None,
                )
                model_response.choices[0].message = _message  # type: ignore
                model_response._hidden_params["original_response"] = (
                    outputText  # allow user to access raw anthropic tool calling response
                )
                model_response._hidden_params[
                    "original_response"
                ] = outputText  # allow user to access raw anthropic tool calling response
                if (
                    _is_function_call is True
                    and stream is not None

@@ -806,9 +806,9 @@
                ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                    inference_params[k] = v
                if stream is True:
                    inference_params["stream"] = (
                        True  # cohere requires stream = True in inference params
                    )
                    inference_params[
                        "stream"
                    ] = True  # cohere requires stream = True in inference params
                data = json.dumps({"prompt": prompt, **inference_params})
            elif provider == "anthropic":
                if model.startswith("anthropic.claude-3"):

@@ -1205,7 +1205,6 @@
def get_response_stream_shape():
    global _response_stream_shape_cache
    if _response_stream_shape_cache is None:

        from botocore.loaders import Loader
        from botocore.model import ServiceModel

@@ -1539,7 +1538,6 @@ class AmazonDeepSeekR1StreamDecoder(AWSEventStreamDecoder):
        model: str,
        sync_stream: bool,
    ) -> None:

        super().__init__(model=model)
        from litellm.llms.bedrock.chat.invoke_transformations.amazon_deepseek_transformation import (
            AmazonDeepseekR1ResponseIterator,

@@ -225,9 +225,9 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
            ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                inference_params[k] = v
            if stream is True:
                inference_params["stream"] = (
                    True  # cohere requires stream = True in inference params
                )
                inference_params[
                    "stream"
                ] = True  # cohere requires stream = True in inference params
            request_data = {"prompt": prompt, **inference_params}
        elif provider == "anthropic":
            return litellm.AmazonAnthropicClaude3Config().transform_request(

@@ -311,7 +311,6 @@
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:

        try:
            completion_response = raw_response.json()
        except Exception:

@@ -314,7 +314,6 @@ def get_bedrock_tool_name(response_tool_name: str) -> str:


class BedrockModelInfo(BaseLLMModelInfo):

    global_config = AmazonBedrockGlobalConfig()
    all_global_regions = global_config.get_all_regions()

@@ -33,9 +33,9 @@ class AmazonTitanMultimodalEmbeddingG1Config:
    ) -> dict:
        for k, v in non_default_params.items():
            if k == "dimensions":
                optional_params["embeddingConfig"] = (
                    AmazonTitanMultimodalEmbeddingConfig(outputEmbeddingLength=v)
                )
                optional_params[
                    "embeddingConfig"
                ] = AmazonTitanMultimodalEmbeddingConfig(outputEmbeddingLength=v)
        return optional_params

    def _transform_request(

@@ -58,7 +58,6 @@
    def _transform_response(
        self, response_list: List[dict], model: str
    ) -> EmbeddingResponse:

        total_prompt_tokens = 0
        transformed_responses: List[Embedding] = []
        for index, response in enumerate(response_list):

@@ -1,12 +1,16 @@
import types
from typing import List, Optional
from typing import Any, Dict, List, Optional

from openai.types.image import Image

from litellm.types.llms.bedrock import (
    AmazonNovaCanvasTextToImageRequest, AmazonNovaCanvasTextToImageResponse,
    AmazonNovaCanvasTextToImageParams, AmazonNovaCanvasRequestBase, AmazonNovaCanvasColorGuidedGenerationParams,
    AmazonNovaCanvasColorGuidedGenerationParams,
    AmazonNovaCanvasColorGuidedRequest,
    AmazonNovaCanvasImageGenerationConfig,
    AmazonNovaCanvasRequestBase,
    AmazonNovaCanvasTextToImageParams,
    AmazonNovaCanvasTextToImageRequest,
    AmazonNovaCanvasTextToImageResponse,
)
from litellm.types.utils import ImageResponse

@@ -23,7 +27,7 @@ class AmazonNovaCanvasConfig:
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
            and not isinstance(
                v,
                (
                    types.FunctionType,

@@ -32,13 +36,12 @@
                    staticmethod,
                ),
            )
            and v is not None
            and v is not None
        }

    @classmethod
    def get_supported_openai_params(cls, model: Optional[str] = None) -> List:
        """
        """
        """ """
        return ["n", "size", "quality"]

    @classmethod

@@ -56,7 +59,7 @@

    @classmethod
    def transform_request_body(
        cls, text: str, optional_params: dict
        cls, text: str, optional_params: dict
    ) -> AmazonNovaCanvasRequestBase:
        """
        Transform the request body for Amazon Nova Canvas model

@@ -65,18 +68,64 @@
        image_generation_config = optional_params.pop("imageGenerationConfig", {})
        image_generation_config = {**image_generation_config, **optional_params}
        if task_type == "TEXT_IMAGE":
            text_to_image_params = image_generation_config.pop("textToImageParams", {})
            text_to_image_params = {"text" :text, **text_to_image_params}
            text_to_image_params = AmazonNovaCanvasTextToImageParams(**text_to_image_params)
            return AmazonNovaCanvasTextToImageRequest(textToImageParams=text_to_image_params, taskType=task_type,
                                                      imageGenerationConfig=image_generation_config)
            text_to_image_params: Dict[str, Any] = image_generation_config.pop(
                "textToImageParams", {}
            )
            text_to_image_params = {"text": text, **text_to_image_params}
            try:
                text_to_image_params_typed = AmazonNovaCanvasTextToImageParams(
                    **text_to_image_params  # type: ignore
                )
            except Exception as e:
                raise ValueError(
                    f"Error transforming text to image params: {e}. Got params: {text_to_image_params}, Expected params: {AmazonNovaCanvasTextToImageParams.__annotations__}"
                )

            try:
                image_generation_config_typed = AmazonNovaCanvasImageGenerationConfig(
                    **image_generation_config
                )
            except Exception as e:
                raise ValueError(
                    f"Error transforming image generation config: {e}. Got params: {image_generation_config}, Expected params: {AmazonNovaCanvasImageGenerationConfig.__annotations__}"
                )

            return AmazonNovaCanvasTextToImageRequest(
                textToImageParams=text_to_image_params_typed,
                taskType=task_type,
                imageGenerationConfig=image_generation_config_typed,
            )
        if task_type == "COLOR_GUIDED_GENERATION":
            color_guided_generation_params = image_generation_config.pop("colorGuidedGenerationParams", {})
            color_guided_generation_params = {"text": text, **color_guided_generation_params}
            color_guided_generation_params = AmazonNovaCanvasColorGuidedGenerationParams(**color_guided_generation_params)
            return AmazonNovaCanvasColorGuidedRequest(taskType=task_type,
                                                      colorGuidedGenerationParams=color_guided_generation_params,
                                                      imageGenerationConfig=image_generation_config)
            color_guided_generation_params: Dict[
                str, Any
            ] = image_generation_config.pop("colorGuidedGenerationParams", {})
            color_guided_generation_params = {
                "text": text,
                **color_guided_generation_params,
            }
            try:
                color_guided_generation_params_typed = AmazonNovaCanvasColorGuidedGenerationParams(
                    **color_guided_generation_params  # type: ignore
                )
            except Exception as e:
                raise ValueError(
                    f"Error transforming color guided generation params: {e}. Got params: {color_guided_generation_params}, Expected params: {AmazonNovaCanvasColorGuidedGenerationParams.__annotations__}"
                )

            try:
                image_generation_config_typed = AmazonNovaCanvasImageGenerationConfig(
                    **image_generation_config
                )
            except Exception as e:
                raise ValueError(
                    f"Error transforming image generation config: {e}. Got params: {image_generation_config}, Expected params: {AmazonNovaCanvasImageGenerationConfig.__annotations__}"
                )

            return AmazonNovaCanvasColorGuidedRequest(
                taskType=task_type,
                colorGuidedGenerationParams=color_guided_generation_params_typed,
                imageGenerationConfig=image_generation_config_typed,
            )
        raise NotImplementedError(f"Task type {task_type} is not supported")

    @classmethod

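The reworked TEXT_IMAGE and COLOR_GUIDED_GENERATION branches above build the provider TypedDicts inside try/except and surface the expected keys when construction fails. Below is a rough, self-contained sketch of that validate-then-construct pattern using a stand-in TypedDict; the field names are illustrative, not the real AmazonNovaCanvas* definitions.

from typing import Any, Dict, TypedDict


class TextToImageParams(TypedDict, total=False):
    # Stand-in for the real AmazonNovaCanvasTextToImageParams TypedDict.
    text: str
    negativeText: str


def build_text_to_image_params(text: str, optional_params: Dict[str, Any]) -> TextToImageParams:
    merged = {"text": text, **optional_params}
    try:
        # TypedDicts do not validate at runtime; the try/except mirrors the diff
        # and guards against unexpected construction errors, while mypy checks the keys.
        return TextToImageParams(**merged)  # type: ignore[typeddict-item]
    except Exception as e:
        raise ValueError(
            f"Error transforming text to image params: {e}. "
            f"Got params: {merged}, Expected params: {TextToImageParams.__annotations__}"
        )


print(build_text_to_image_params("a lighthouse at dusk", {"negativeText": "blurry"}))
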
@@ -87,7 +136,9 @@
        _size = non_default_params.get("size")
        if _size is not None:
            width, height = _size.split("x")
            optional_params["width"], optional_params["height"] = int(width), int(height)
            optional_params["width"], optional_params["height"] = int(width), int(
                height
            )
        if non_default_params.get("n") is not None:
            optional_params["numberOfImages"] = non_default_params.get("n")
        if non_default_params.get("quality") is not None:

@@ -99,7 +150,7 @@

    @classmethod
    def transform_response_dict_to_openai_response(
        cls, model_response: ImageResponse, response_dict: dict
        cls, model_response: ImageResponse, response_dict: dict
    ) -> ImageResponse:
        """
        Transform the response dict to the OpenAI response

@@ -267,7 +267,11 @@ class BedrockImageGeneration(BaseAWSLLM):
                **inference_params,
            }
        elif provider == "amazon":
            return dict(litellm.AmazonNovaCanvasConfig.transform_request_body(text=prompt, optional_params=optional_params))
            return dict(
                litellm.AmazonNovaCanvasConfig.transform_request_body(
                    text=prompt, optional_params=optional_params
                )
            )
        else:
            raise BedrockError(
                status_code=422, message=f"Unsupported model={model}, passed in"

@@ -303,8 +307,11 @@
        config_class = (
            litellm.AmazonStability3Config
            if litellm.AmazonStability3Config._is_stability_3_model(model=model)
            else litellm.AmazonNovaCanvasConfig if litellm.AmazonNovaCanvasConfig._is_nova_model(model=model)
            else litellm.AmazonStabilityConfig
            else (
                litellm.AmazonNovaCanvasConfig
                if litellm.AmazonNovaCanvasConfig._is_nova_model(model=model)
                else litellm.AmazonStabilityConfig
            )
        )
        config_class.transform_response_dict_to_openai_response(
            model_response=model_response,

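The hunk above only re-nests the conditional expression that picks the image-generation config class; the selection logic is unchanged. A simplified stand-alone sketch of that dispatch follows; the model prefixes and return values are illustrative, not the actual litellm config classes.

def pick_image_config(model: str) -> str:
    # Nested conditional expression, parenthesized the same way as in the hunk above.
    return (
        "stability3"
        if model.startswith("stability.sd3")
        else (
            "nova_canvas"
            if model.startswith("amazon.nova-canvas")
            else "stability_v1"
        )
    )


print(pick_image_config("amazon.nova-canvas-v1:0"))  # -> nova_canvas
print(pick_image_config("stability.stable-diffusion-xl-v1"))  # -> stability_v1
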
@@ -60,7 +60,6 @@ class BedrockRerankHandler(BaseAWSLLM):
        extra_headers: Optional[dict] = None,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    ) -> RerankResponse:

        request_data = RerankRequest(
            model=model,
            query=query,

@@ -29,7 +29,6 @@ from litellm.types.rerank import (


class BedrockRerankConfig:

    def _transform_sources(
        self, documents: List[Union[str, dict]]
    ) -> List[BedrockRerankSource]:

@@ -314,7 +314,6 @@ class CodestralTextCompletion:
            return _response
        ### SYNC COMPLETION
        else:

            response = litellm.module_level_client.post(
                url=completion_url,
                headers=headers,

@@ -352,13 +351,11 @@
        logger_fn=None,
        headers={},
    ) -> TextCompletionResponse:

        async_handler = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.TEXT_COMPLETION_CODESTRAL,
            params={"timeout": timeout},
        )
        try:

            response = await async_handler.post(
                api_base, headers=headers, data=json.dumps(data)
            )

@@ -78,7 +78,6 @@ class CodestralTextCompletionConfig(OpenAITextCompletionConfig):
        return optional_params

    def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:

        text = ""
        is_finished = False
        finish_reason = None

@@ -180,7 +180,6 @@ class CohereChatConfig(BaseConfig):
        litellm_params: dict,
        headers: dict,
    ) -> dict:

        ## Load Config
        for k, v in litellm.CohereChatConfig.get_config().items():
            if (

@@ -222,7 +221,6 @@
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:

        try:
            raw_response_json = raw_response.json()
            model_response.choices[0].message.content = raw_response_json["text"]  # type: ignore

@@ -56,7 +56,6 @@ async def async_embedding(
    encoding: Callable,
    client: Optional[AsyncHTTPHandler] = None,
):

    ## LOGGING
    logging_obj.pre_call(
        input=input,

@@ -72,7 +72,6 @@ class CohereEmbeddingConfig:
        return transformed_request

    def _calculate_usage(self, input: List[str], encoding: Any, meta: dict) -> Usage:

        input_tokens = 0

        text_tokens: Optional[int] = meta.get("billed_units", {}).get("input_tokens")

@@ -111,7 +110,6 @@
        encoding: Any,
        input: list,
    ) -> EmbeddingResponse:

        response_json = response.json()
        ## LOGGING
        logging_obj.post_call(

@@ -148,4 +148,4 @@ class CohereRerankConfig(BaseRerankConfig):
    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return CohereError(message=error_message, status_code=status_code)
        return CohereError(message=error_message, status_code=status_code)

@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Union
from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
from litellm.types.rerank import OptionalRerankParams, RerankRequest


class CohereRerankV2Config(CohereRerankConfig):
    """
    Reference: https://docs.cohere.com/v2/reference/rerank

@@ -77,4 +78,4 @@ class CohereRerankV2Config(CohereRerankConfig):
            return_documents=optional_rerank_params.get("return_documents", None),
            max_tokens_per_doc=optional_rerank_params.get("max_tokens_per_doc", None),
        )
        return rerank_request.model_dump(exclude_none=True)
        return rerank_request.model_dump(exclude_none=True)

@@ -32,7 +32,6 @@ DEFAULT_TIMEOUT = 600


class BaseLLMAIOHTTPHandler:

    def __init__(self):
        self.client_session: Optional[aiohttp.ClientSession] = None

@@ -110,7 +109,6 @@ class BaseLLMAIOHTTPHandler:
        content: Any = None,
        params: Optional[dict] = None,
    ) -> httpx.Response:

        max_retry_on_unprocessable_entity_error = (
            provider_config.max_retry_on_unprocessable_entity_error
        )

@@ -114,7 +114,6 @@ class AsyncHTTPHandler:
        event_hooks: Optional[Mapping[str, List[Callable[..., Any]]]],
        ssl_verify: Optional[VerifyTypes] = None,
    ) -> httpx.AsyncClient:

        # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
        # /path/to/certificate.pem
        if ssl_verify is None:

@@ -590,7 +589,6 @@ class HTTPHandler:
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ):
        try:

            if timeout is not None:
                req = self.client.build_request(
                    "PATCH", url, data=data, json=json, params=params, headers=headers, timeout=timeout  # type: ignore

@@ -609,7 +607,6 @@
                llm_provider="litellm-httpx-handler",
            )
        except httpx.HTTPStatusError as e:

            if stream is True:
                setattr(e, "message", mask_sensitive_info(e.response.read()))
                setattr(e, "text", mask_sensitive_info(e.response.read()))

@@ -635,7 +632,6 @@
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ):
        try:

            if timeout is not None:
                req = self.client.build_request(
                    "PUT", url, data=data, json=json, params=params, headers=headers, timeout=timeout  # type: ignore

@@ -41,7 +41,6 @@ else:


class BaseLLMHTTPHandler:

    async def _make_common_async_call(
        self,
        async_httpx_client: AsyncHTTPHandler,

@@ -109,7 +108,6 @@ class BaseLLMHTTPHandler:
        logging_obj: LiteLLMLoggingObj,
        stream: bool = False,
    ) -> httpx.Response:

        max_retry_on_unprocessable_entity_error = (
            provider_config.max_retry_on_unprocessable_entity_error
        )

@@ -599,7 +597,6 @@
        aembedding: bool = False,
        headers={},
    ) -> EmbeddingResponse:

        provider_config = ProviderConfigManager.get_provider_embedding_config(
            model=model, provider=litellm.LlmProviders(custom_llm_provider)
        )

@@ -742,7 +739,6 @@
        api_base: Optional[str] = None,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    ) -> RerankResponse:

        # get config from model, custom llm provider
        headers = provider_config.validate_environment(
            api_key=api_key,

@@ -828,7 +824,6 @@
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    ) -> RerankResponse:

        if client is None or not isinstance(client, AsyncHTTPHandler):
            async_httpx_client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders(custom_llm_provider)

@@ -16,9 +16,9 @@ class DatabricksBase:
        api_base = api_base or f"{databricks_client.config.host}/serving-endpoints"

        if api_key is None:
            databricks_auth_headers: dict[str, str] = (
                databricks_client.config.authenticate()
            )
            databricks_auth_headers: dict[
                str, str
            ] = databricks_client.config.authenticate()
            headers = {**databricks_auth_headers, **headers}

        return api_base, headers

@@ -11,9 +11,9 @@ class DatabricksEmbeddingConfig:
    Reference: https://learn.microsoft.com/en-us/azure/databricks/machine-learning/foundation-models/api-reference#--embedding-task
    """

    instruction: Optional[str] = (
        None  # An optional instruction to pass to the embedding model. BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries
    )
    instruction: Optional[
        str
    ] = None  # An optional instruction to pass to the embedding model. BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries

    def __init__(self, instruction: Optional[str] = None) -> None:
        locals_ = locals().copy()

@@ -55,7 +55,6 @@ class ModelResponseIterator:

        usage_chunk: Optional[Usage] = getattr(processed_chunk, "usage", None)
        if usage_chunk is not None:

            usage = ChatCompletionUsageBlock(
                prompt_tokens=usage_chunk.prompt_tokens,
                completion_tokens=usage_chunk.completion_tokens,

@@ -126,9 +126,9 @@ class DeepgramAudioTranscriptionConfig(BaseAudioTranscriptionConfig):

        # Add additional metadata matching OpenAI format
        response["task"] = "transcribe"
        response["language"] = (
            "english"  # Deepgram auto-detects but doesn't return language
        )
        response[
            "language"
        ] = "english"  # Deepgram auto-detects but doesn't return language
        response["duration"] = response_json["metadata"]["duration"]

        # Transform words to match OpenAI format
Some files were not shown because too many files have changed in this diff.