Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
build(pyproject.toml): add new dev dependencies - for type checking (#9631)
* build(pyproject.toml): add new dev dependencies - for type checking
* build: reformat files to fit black
* ci: reformat to fit black
* ci(test-litellm.yml): make tests run clear
* build(pyproject.toml): add ruff
* fix: fix ruff checks
* build(mypy/): fix mypy linting errors
* fix(hashicorp_secret_manager.py): fix passing cert for tls auth
* build(mypy/): resolve all mypy errors
* test: update test
* fix: fix black formatting
* build(pre-commit-config.yaml): use poetry run black
* fix(proxy_server.py): fix linting error
* fix: fix ruff safe representation error
parent 72198737f8
commit d7b294dd0a
214 changed files with 1553 additions and 1433 deletions
.github/workflows/test-linting.yml (vendored, new file, 53 lines added)
@@ -0,0 +1,53 @@
+name: LiteLLM Linting
+
+on:
+  pull_request:
+    branches: [ main ]
+
+jobs:
+  lint:
+    runs-on: ubuntu-latest
+    timeout-minutes: 5
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.12'
+
+      - name: Install Poetry
+        uses: snok/install-poetry@v1
+
+      - name: Install dependencies
+        run: |
+          poetry install --with dev
+
+      - name: Run Black formatting check
+        run: |
+          cd litellm
+          poetry run black . --check
+          cd ..
+
+      - name: Run Ruff linting
+        run: |
+          cd litellm
+          poetry run ruff check .
+          cd ..
+
+      - name: Run MyPy type checking
+        run: |
+          cd litellm
+          poetry run mypy . --ignore-missing-imports
+          cd ..
+
+      - name: Check for circular imports
+        run: |
+          cd litellm
+          poetry run python ../tests/documentation_tests/test_circular_imports.py
+          cd ..
+
+      - name: Check import safety
+        run: |
+          poetry run python -c "from litellm import *" || (echo '🚨 import failed, this means you introduced unprotected imports! 🚨'; exit 1)
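The job only runs the standard Black, Ruff, and MyPy commands against the litellm/ directory. A rough helper along these lines (hypothetical script, not part of the commit; assumes Poetry and the dev dependency group are installed, and that it is run from the repository root) reproduces the three lint steps locally:

import subprocess
import sys

# Each command mirrors one workflow step above; cwd="litellm" mirrors the
# `cd litellm` in the job.
CHECKS = [
    ["poetry", "run", "black", ".", "--check"],
    ["poetry", "run", "ruff", "check", "."],
    ["poetry", "run", "mypy", ".", "--ignore-missing-imports"],
]

def main() -> int:
    for cmd in CHECKS:
        print("running:", " ".join(cmd))
        result = subprocess.run(cmd, cwd="litellm")
        if result.returncode != 0:
            return result.returncode  # stop at the first failing check
    return 0

if __name__ == "__main__":
    sys.exit(main())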
.github/workflows/test-litellm.yml (vendored, 2 lines changed)
@@ -1,4 +1,4 @@
-name: LiteLLM Tests
+name: LiteLLM Mock Tests (folder - tests/litellm)

 on:
   pull_request:
@@ -14,10 +14,12 @@ repos:
       types: [python]
       files: litellm/.*\.py
       exclude: ^litellm/__init__.py$
-  - repo: https://github.com/psf/black
-    rev: 24.2.0
-    hooks:
       - id: black
+        name: black
+        entry: poetry run black
+        language: system
+        types: [python]
+        files: litellm/.*\.py
   - repo: https://github.com/pycqa/flake8
     rev: 7.0.0  # The version of flake8 to use
     hooks:
@@ -444,9 +444,7 @@ class _ENTERPRISE_SecretDetection(CustomGuardrail):

         detected_secrets = []
         for file in secrets.files:
-
             for found_secret in secrets[file]:
-
                 if found_secret.secret_value is None:
                     continue
                 detected_secrets.append(
@@ -471,14 +469,12 @@ class _ENTERPRISE_SecretDetection(CustomGuardrail):
         data: dict,
         call_type: str,  # "completion", "embeddings", "image_generation", "moderation"
     ):
-
         if await self.should_run_check(user_api_key_dict) is False:
             return

         if "messages" in data and isinstance(data["messages"], list):
             for message in data["messages"]:
                 if "content" in message and isinstance(message["content"], str):
-
                     detected_secrets = self.scan_message_for_secrets(message["content"])

                     for secret in detected_secrets:
@@ -122,19 +122,19 @@ langsmith_batch_size: Optional[int] = None
 prometheus_initialize_budget_metrics: Optional[bool] = False
 argilla_batch_size: Optional[int] = None
 datadog_use_v1: Optional[bool] = False  # if you want to use v1 datadog logged payload
-gcs_pub_sub_use_v1: Optional[bool] = (
-    False  # if you want to use v1 gcs pubsub logged payload
-)
+gcs_pub_sub_use_v1: Optional[
+    bool
+] = False  # if you want to use v1 gcs pubsub logged payload
 argilla_transformation_object: Optional[Dict[str, Any]] = None
-_async_input_callback: List[Union[str, Callable, CustomLogger]] = (
-    []
-)  # internal variable - async custom callbacks are routed here.
+_async_input_callback: List[
+    Union[str, Callable, CustomLogger]
+] = []  # internal variable - async custom callbacks are routed here.
-_async_success_callback: List[Union[str, Callable, CustomLogger]] = (
-    []
-)  # internal variable - async custom callbacks are routed here.
+_async_success_callback: List[
+    Union[str, Callable, CustomLogger]
+] = []  # internal variable - async custom callbacks are routed here.
-_async_failure_callback: List[Union[str, Callable, CustomLogger]] = (
-    []
-)  # internal variable - async custom callbacks are routed here.
+_async_failure_callback: List[
+    Union[str, Callable, CustomLogger]
+] = []  # internal variable - async custom callbacks are routed here.
 pre_call_rules: List[Callable] = []
 post_call_rules: List[Callable] = []
 turn_off_message_logging: Optional[bool] = False
@@ -142,18 +142,18 @@ log_raw_request_response: bool = False
 redact_messages_in_exceptions: Optional[bool] = False
 redact_user_api_key_info: Optional[bool] = False
 filter_invalid_headers: Optional[bool] = False
-add_user_information_to_llm_headers: Optional[bool] = (
-    None  # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
-)
+add_user_information_to_llm_headers: Optional[
+    bool
+] = None  # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
 store_audit_logs = False  # Enterprise feature, allow users to see audit logs
 ### end of callbacks #############

-email: Optional[str] = (
-    None  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
-)
+email: Optional[
+    str
+] = None  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
-token: Optional[str] = (
-    None  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
-)
+token: Optional[
+    str
+] = None  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
 telemetry = True
 max_tokens = 256  # OpenAI Defaults
 drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False))
@@ -229,24 +229,20 @@ enable_loadbalancing_on_batch_endpoints: Optional[bool] = None
 enable_caching_on_provider_specific_optional_params: bool = (
     False  # feature-flag for caching on optional params - e.g. 'top_k'
 )
-caching: bool = (
-    False  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
-)
-caching_with_models: bool = (
-    False  # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
-)
-cache: Optional[Cache] = (
-    None  # cache object <- use this - https://docs.litellm.ai/docs/caching
-)
+caching: bool = False  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
+caching_with_models: bool = False  # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
+cache: Optional[
+    Cache
+] = None  # cache object <- use this - https://docs.litellm.ai/docs/caching
 default_in_memory_ttl: Optional[float] = None
 default_redis_ttl: Optional[float] = None
 default_redis_batch_cache_expiry: Optional[float] = None
 model_alias_map: Dict[str, str] = {}
 model_group_alias_map: Dict[str, str] = {}
 max_budget: float = 0.0  # set the max budget across all providers
-budget_duration: Optional[str] = (
-    None  # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
-)
+budget_duration: Optional[
+    str
+] = None  # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
 default_soft_budget: float = (
     50.0  # by default all litellm proxy keys have a soft budget of 50.0
 )
@@ -255,15 +251,11 @@ forward_traceparent_to_llm_provider: bool = False

 _current_cost = 0.0  # private variable, used if max budget is set
 error_logs: Dict = {}
-add_function_to_prompt: bool = (
-    False  # if function calling not supported by api, append function call details to system prompt
-)
+add_function_to_prompt: bool = False  # if function calling not supported by api, append function call details to system prompt
 client_session: Optional[httpx.Client] = None
 aclient_session: Optional[httpx.AsyncClient] = None
 model_fallbacks: Optional[List] = None  # Deprecated for 'litellm.fallbacks'
-model_cost_map_url: str = (
-    "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
-)
+model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
 suppress_debug_info = False
 dynamodb_table_name: Optional[str] = None
 s3_callback_params: Optional[Dict] = None
@@ -285,9 +277,7 @@ disable_end_user_cost_tracking_prometheus_only: Optional[bool] = None
 custom_prometheus_metadata_labels: List[str] = []
 #### REQUEST PRIORITIZATION ####
 priority_reservation: Optional[Dict[str, float]] = None
-force_ipv4: bool = (
-    False  # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
-)
+force_ipv4: bool = False  # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
 module_level_aclient = AsyncHTTPHandler(
     timeout=request_timeout, client_alias="module level aclient"
 )
@@ -301,13 +291,13 @@ fallbacks: Optional[List] = None
 context_window_fallbacks: Optional[List] = None
 content_policy_fallbacks: Optional[List] = None
 allowed_fails: int = 3
-num_retries_per_request: Optional[int] = (
-    None  # for the request overall (incl. fallbacks + model retries)
-)
+num_retries_per_request: Optional[
+    int
+] = None  # for the request overall (incl. fallbacks + model retries)
 ####### SECRET MANAGERS #####################
-secret_manager_client: Optional[Any] = (
-    None  # list of instantiated key management clients - e.g. azure kv, infisical, etc.
-)
+secret_manager_client: Optional[
+    Any
+] = None  # list of instantiated key management clients - e.g. azure kv, infisical, etc.
 _google_kms_resource_name: Optional[str] = None
 _key_management_system: Optional[KeyManagementSystem] = None
 _key_management_settings: KeyManagementSettings = KeyManagementSettings()
@@ -1056,10 +1046,10 @@ from .types.llms.custom_llm import CustomLLMItem
 from .types.utils import GenericStreamingChunk

 custom_provider_map: List[CustomLLMItem] = []
-_custom_providers: List[str] = (
-    []
-)  # internal helper util, used to track names of custom providers
+_custom_providers: List[
+    str
+] = []  # internal helper util, used to track names of custom providers
-disable_hf_tokenizer_download: Optional[bool] = (
-    None  # disable huggingface tokenizer download. Defaults to openai clk100
-)
+disable_hf_tokenizer_download: Optional[
+    bool
+] = None  # disable huggingface tokenizer download. Defaults to openai clk100
 global_disable_no_log_param: bool = False
@@ -15,7 +15,7 @@ from .types.services import ServiceLoggerPayload, ServiceTypes
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span

-    Span = _Span
+    Span = Union[_Span, Any]
     OTELClass = OpenTelemetry
 else:
     Span = Any
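This `Span = Union[_Span, Any]` rewrite recurs throughout the commit. Importing the OpenTelemetry `Span` only under `TYPE_CHECKING` keeps opentelemetry an optional runtime dependency, and widening the alias with `Any` means mypy also accepts non-OTEL objects wherever `Span` is used as an annotation. A standalone sketch of the pattern (illustrative only, not code from the repo):

from typing import TYPE_CHECKING, Any, Optional, Union

if TYPE_CHECKING:
    # Only evaluated by type checkers; opentelemetry does not need to be
    # installed for this module to import at runtime.
    from opentelemetry.trace import Span as _Span

    Span = Union[_Span, Any]
else:
    Span = Any

def record_span(span: Optional[Span], name: str) -> None:
    # Accepts a real OTEL span, a stub, or None without upsetting mypy.
    if span is not None and hasattr(span, "set_attribute"):
        span.set_attribute("operation.name", name)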
@@ -153,7 +153,6 @@ def create_batch(
         )
     api_base: Optional[str] = None
     if custom_llm_provider == "openai":
-
         # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
         api_base = (
             optional_params.api_base
@@ -358,7 +357,6 @@ def retrieve_batch(
     _is_async = kwargs.pop("aretrieve_batch", False) is True
     api_base: Optional[str] = None
     if custom_llm_provider == "openai":
-
         # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
         api_base = (
             optional_params.api_base
@@ -9,12 +9,12 @@ Has 4 methods:
 """

 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Optional, Union

 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span

-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     Span = Any

@@ -66,9 +66,7 @@ class CachingHandlerResponse(BaseModel):

     cached_result: Optional[Any] = None
     final_embedding_cached_response: Optional[EmbeddingResponse] = None
-    embedding_all_elements_cache_hit: bool = (
-        False  # this is set to True when all elements in the list have a cache hit in the embedding cache, if true return the final_embedding_cached_response no need to make an API call
-    )
+    embedding_all_elements_cache_hit: bool = False  # this is set to True when all elements in the list have a cache hit in the embedding cache, if true return the final_embedding_cached_response no need to make an API call


 class LLMCachingHandler:
@@ -738,7 +736,6 @@ class LLMCachingHandler:
         if self._should_store_result_in_cache(
             original_function=self.original_function, kwargs=new_kwargs
         ):
-
             litellm.cache.add_cache(result, **new_kwargs)

         return
@@ -865,9 +862,9 @@ class LLMCachingHandler:
         }

         if litellm.cache is not None:
-            litellm_params["preset_cache_key"] = (
-                litellm.cache._get_preset_cache_key_from_kwargs(**kwargs)
-            )
+            litellm_params[
+                "preset_cache_key"
+            ] = litellm.cache._get_preset_cache_key_from_kwargs(**kwargs)
         else:
             litellm_params["preset_cache_key"] = None

@@ -1,12 +1,12 @@
 import json
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Optional, Union

 from .base_cache import BaseCache

 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span

-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     Span = Any

@@ -12,7 +12,7 @@ import asyncio
 import time
 import traceback
 from concurrent.futures import ThreadPoolExecutor
-from typing import TYPE_CHECKING, Any, List, Optional
+from typing import TYPE_CHECKING, Any, List, Optional, Union

 import litellm
 from litellm._logging import print_verbose, verbose_logger
@@ -24,7 +24,7 @@ from .redis_cache import RedisCache
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span

-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     Span = Any

@@ -8,7 +8,6 @@ from .in_memory_cache import InMemoryCache


 class LLMClientCache(InMemoryCache):
-
     def update_cache_key_with_event_loop(self, key):
         """
         Add the event loop to the cache key, to prevent event loop closed errors.
@@ -34,7 +34,7 @@ if TYPE_CHECKING:
     cluster_pipeline = ClusterPipeline
     async_redis_client = Redis
     async_redis_cluster_client = RedisCluster
-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     pipeline = Any
     cluster_pipeline = Any
@@ -57,7 +57,6 @@ class RedisCache(BaseCache):
         socket_timeout: Optional[float] = 5.0,  # default 5 second timeout
         **kwargs,
     ):
-
         from litellm._service_logger import ServiceLogging

         from .._redis import get_redis_client, get_redis_connection_pool
@@ -5,7 +5,7 @@ Key differences:
 - RedisClient NEEDs to be re-used across requests, adds 3000ms latency if it's re-created
 """

-from typing import TYPE_CHECKING, Any, List, Optional
+from typing import TYPE_CHECKING, Any, List, Optional, Union

 from litellm.caching.redis_cache import RedisCache

@@ -16,7 +16,7 @@ if TYPE_CHECKING:

     pipeline = Pipeline
     async_redis_client = Redis
-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     pipeline = Any
     async_redis_client = Any
@@ -13,11 +13,15 @@ import ast
 import asyncio
 import json
 import os
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, cast

 import litellm
 from litellm._logging import print_verbose
-from litellm.litellm_core_utils.prompt_templates.common_utils import get_str_from_messages
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    get_str_from_messages,
+)
+from litellm.types.utils import EmbeddingResponse

 from .base_cache import BaseCache

@@ -87,14 +91,16 @@ class RedisSemanticCache(BaseCache):
         if redis_url is None:
             try:
                 # Attempt to use provided parameters or fallback to environment variables
-                host = host or os.environ['REDIS_HOST']
-                port = port or os.environ['REDIS_PORT']
-                password = password or os.environ['REDIS_PASSWORD']
+                host = host or os.environ["REDIS_HOST"]
+                port = port or os.environ["REDIS_PORT"]
+                password = password or os.environ["REDIS_PASSWORD"]
             except KeyError as e:
                 # Raise a more informative exception if any of the required keys are missing
                 missing_var = e.args[0]
-                raise ValueError(f"Missing required Redis configuration: {missing_var}. "
-                                 f"Provide {missing_var} or redis_url.") from e
+                raise ValueError(
+                    f"Missing required Redis configuration: {missing_var}. "
+                    f"Provide {missing_var} or redis_url."
+                ) from e

             redis_url = f"redis://:{password}@{host}:{port}"

@@ -137,10 +143,13 @@ class RedisSemanticCache(BaseCache):
             List[float]: The embedding vector
         """
         # Create an embedding from prompt
-        embedding_response = litellm.embedding(
-            model=self.embedding_model,
-            input=prompt,
-            cache={"no-store": True, "no-cache": True},
-        )
+        embedding_response = cast(
+            EmbeddingResponse,
+            litellm.embedding(
+                model=self.embedding_model,
+                input=prompt,
+                cache={"no-store": True, "no-cache": True},
+            ),
+        )
         embedding = embedding_response["data"][0]["embedding"]
         return embedding
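The `cast(EmbeddingResponse, ...)` wrapper only affects static analysis: `cast` is a no-op at runtime and simply tells mypy which concrete type to assume for the loosely typed return value of `litellm.embedding()`. A generic illustration of the technique (example code, not taken from the repo):

from typing import Any, cast

class EmbeddingResult:
    def __init__(self, vector: list[float]) -> None:
        self.vector = vector

def call_backend(payload: str) -> Any:
    # A loosely typed client call that, in practice, returns EmbeddingResult.
    return EmbeddingResult([0.1, 0.2, 0.3])

# cast() does nothing at runtime; it only tells the type checker what to assume.
result = cast(EmbeddingResult, call_backend("hello"))
print(result.vector[0])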
@@ -186,6 +195,7 @@ class RedisSemanticCache(BaseCache):
         """
         print_verbose(f"Redis semantic-cache set_cache, kwargs: {kwargs}")

+        value_str: Optional[str] = None
         try:
             # Extract the prompt from messages
             messages = kwargs.get("messages", [])
@@ -203,7 +213,9 @@ class RedisSemanticCache(BaseCache):
             else:
                 self.llmcache.store(prompt, value_str)
         except Exception as e:
-            print_verbose(f"Error setting {value_str} in the Redis semantic cache: {str(e)}")
+            print_verbose(
+                f"Error setting {value_str or value} in the Redis semantic cache: {str(e)}"
+            )

     def get_cache(self, key: str, **kwargs) -> Any:
         """
@@ -336,13 +348,13 @@ class RedisSemanticCache(BaseCache):
                     prompt,
                     value_str,
                     vector=prompt_embedding,  # Pass through custom embedding
-                    ttl=ttl
+                    ttl=ttl,
                 )
             else:
                 await self.llmcache.astore(
                     prompt,
                     value_str,
-                    vector=prompt_embedding  # Pass through custom embedding
+                    vector=prompt_embedding,  # Pass through custom embedding
                 )
         except Exception as e:
             print_verbose(f"Error in async_set_cache: {str(e)}")
@@ -374,14 +386,13 @@ class RedisSemanticCache(BaseCache):
             prompt_embedding = await self._get_async_embedding(prompt, **kwargs)

             # Check the cache for semantically similar prompts
-            results = await self.llmcache.acheck(
-                prompt=prompt,
-                vector=prompt_embedding
-            )
+            results = await self.llmcache.acheck(prompt=prompt, vector=prompt_embedding)

             # handle results / cache hit
             if not results:
-                kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0  # TODO why here but not above??
+                kwargs.setdefault("metadata", {})[
+                    "semantic-similarity"
+                ] = 0.0  # TODO why here but not above??
                 return None

             cache_hit = results[0]
@@ -420,7 +431,9 @@ class RedisSemanticCache(BaseCache):
         aindex = await self.llmcache._get_async_index()
         return await aindex.info()

-    async def async_set_cache_pipeline(self, cache_list: List[Tuple[str, Any]], **kwargs) -> None:
+    async def async_set_cache_pipeline(
+        self, cache_list: List[Tuple[str, Any]], **kwargs
+    ) -> None:
         """
         Asynchronously store multiple values in the semantic cache.

@@ -123,7 +123,7 @@ class S3Cache(BaseCache):
                 )  # Convert string to dictionary
             except Exception:
                 cached_response = ast.literal_eval(cached_response)
-            if type(cached_response) is not dict:
+            if not isinstance(cached_response, dict):
                 cached_response = dict(cached_response)
             verbose_logger.debug(
                 f"Got S3 Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
|
||||||
- For un-mapped Replicate models, the cost is calculated based on the total time used for the request.
|
- For un-mapped Replicate models, the cost is calculated based on the total time used for the request.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
|
||||||
call_type = _infer_call_type(call_type, completion_response) or "completion"
|
call_type = _infer_call_type(call_type, completion_response) or "completion"
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
|
|
@@ -138,7 +138,6 @@ def create_fine_tuning_job(

        # OpenAI
        if custom_llm_provider == "openai":
-
            # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
            api_base = (
                optional_params.api_base
@@ -360,7 +359,6 @@ def cancel_fine_tuning_job(

        # OpenAI
        if custom_llm_provider == "openai":
-
            # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
            api_base = (
                optional_params.api_base
@@ -522,7 +520,6 @@ def list_fine_tuning_jobs(

        # OpenAI
        if custom_llm_provider == "openai":
-
            # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
            api_base = (
                optional_params.api_base
@@ -19,7 +19,6 @@ else:


 def squash_payloads(queue):
-
     squashed = {}
     if len(queue) == 0:
         return squashed
@@ -195,13 +195,16 @@ class SlackAlerting(CustomBatchLogger):
         if self.alerting is None or self.alert_types is None:
             return

-        time_difference_float, model, api_base, messages = (
-            self._response_taking_too_long_callback_helper(
-                kwargs=kwargs,
-                start_time=start_time,
-                end_time=end_time,
-            )
-        )
+        (
+            time_difference_float,
+            model,
+            api_base,
+            messages,
+        ) = self._response_taking_too_long_callback_helper(
+            kwargs=kwargs,
+            start_time=start_time,
+            end_time=end_time,
+        )
         if litellm.turn_off_message_logging or litellm.redact_messages_in_exceptions:
             messages = "Message not logged. litellm.redact_messages_in_exceptions=True"
         request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
@@ -819,9 +822,9 @@ class SlackAlerting(CustomBatchLogger):
         ### UNIQUE CACHE KEY ###
         cache_key = provider + region_name

-        outage_value: Optional[ProviderRegionOutageModel] = (
-            await self.internal_usage_cache.async_get_cache(key=cache_key)
-        )
+        outage_value: Optional[
+            ProviderRegionOutageModel
+        ] = await self.internal_usage_cache.async_get_cache(key=cache_key)

         if (
             getattr(exception, "status_code", None) is None
@@ -1402,9 +1405,9 @@ Model Info:
             self.alert_to_webhook_url is not None
             and alert_type in self.alert_to_webhook_url
         ):
-            slack_webhook_url: Optional[Union[str, List[str]]] = (
-                self.alert_to_webhook_url[alert_type]
-            )
+            slack_webhook_url: Optional[
+                Union[str, List[str]]
+            ] = self.alert_to_webhook_url[alert_type]
         elif self.default_webhook_url is not None:
             slack_webhook_url = self.default_webhook_url
         else:
@@ -1768,7 +1771,6 @@ Model Info:
        - Team Created, Updated, Deleted
        """
        try:
-
            message = f"`{event_name}`\n"

            key_event_dict = key_event.model_dump()
@@ -98,7 +98,6 @@ class ArgillaLogger(CustomBatchLogger):
        argilla_dataset_name: Optional[str],
        argilla_base_url: Optional[str],
    ) -> ArgillaCredentialsObject:
-
        _credentials_api_key = argilla_api_key or os.getenv("ARGILLA_API_KEY")
        if _credentials_api_key is None:
            raise Exception("Invalid Argilla API Key given. _credentials_api_key=None.")
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Optional, Union

 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
@@ -7,7 +7,7 @@ from litellm.types.utils import StandardLoggingPayload
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span

-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     Span = Any

@@ -19,14 +19,13 @@ if TYPE_CHECKING:
     from litellm.types.integrations.arize import Protocol as _Protocol

     Protocol = _Protocol
-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     Protocol = Any
     Span = Any


 class ArizeLogger(OpenTelemetry):
-
     def set_attributes(self, span: Span, kwargs, response_obj: Optional[Any]):
         ArizeLogger.set_arize_attributes(span, kwargs, response_obj)
         return
@@ -1,17 +1,20 @@
 import os
-from typing import TYPE_CHECKING, Any
-from litellm.integrations.arize import _utils
+from typing import TYPE_CHECKING, Any, Union
+
 from litellm._logging import verbose_logger
+from litellm.integrations.arize import _utils
 from litellm.types.integrations.arize_phoenix import ArizePhoenixConfig

 if TYPE_CHECKING:
-    from .opentelemetry import OpenTelemetryConfig as _OpenTelemetryConfig
-    from litellm.types.integrations.arize import Protocol as _Protocol
     from opentelemetry.trace import Span as _Span

+    from litellm.types.integrations.arize import Protocol as _Protocol
+
+    from .opentelemetry import OpenTelemetryConfig as _OpenTelemetryConfig
+
     Protocol = _Protocol
     OpenTelemetryConfig = _OpenTelemetryConfig
-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     Protocol = Any
     OpenTelemetryConfig = Any
@@ -20,6 +23,7 @@ else:

 ARIZE_HOSTED_PHOENIX_ENDPOINT = "https://app.phoenix.arize.com/v1/traces"

+
 class ArizePhoenixLogger:
     @staticmethod
     def set_arize_phoenix_attributes(span: Span, kwargs, response_obj):
@@ -59,15 +63,14 @@ class ArizePhoenixLogger:
         # a slightly different auth header format than self hosted phoenix
         if endpoint == ARIZE_HOSTED_PHOENIX_ENDPOINT:
             if api_key is None:
-                raise ValueError("PHOENIX_API_KEY must be set when the Arize hosted Phoenix endpoint is used.")
+                raise ValueError(
+                    "PHOENIX_API_KEY must be set when the Arize hosted Phoenix endpoint is used."
+                )
             otlp_auth_headers = f"api_key={api_key}"
         elif api_key is not None:
             # api_key/auth is optional for self hosted phoenix
             otlp_auth_headers = f"Authorization=Bearer {api_key}"

         return ArizePhoenixConfig(
-            otlp_auth_headers=otlp_auth_headers,
-            protocol=protocol,
-            endpoint=endpoint
+            otlp_auth_headers=otlp_auth_headers, protocol=protocol, endpoint=endpoint
         )
@@ -12,7 +12,10 @@ class AthinaLogger:
             "athina-api-key": self.athina_api_key,
             "Content-Type": "application/json",
         }
-        self.athina_logging_url = os.getenv("ATHINA_BASE_URL", "https://log.athina.ai") + "/api/v1/log/inference"
+        self.athina_logging_url = (
+            os.getenv("ATHINA_BASE_URL", "https://log.athina.ai")
+            + "/api/v1/log/inference"
+        )
         self.additional_keys = [
             "environment",
             "prompt_slug",
@@ -50,12 +50,12 @@ class AzureBlobStorageLogger(CustomBatchLogger):
         self.azure_storage_file_system: str = _azure_storage_file_system

         # Internal variables used for Token based authentication
-        self.azure_auth_token: Optional[str] = (
-            None  # the Azure AD token to use for Azure Storage API requests
-        )
-        self.token_expiry: Optional[datetime] = (
-            None  # the expiry time of the currentAzure AD token
-        )
+        self.azure_auth_token: Optional[
+            str
+        ] = None  # the Azure AD token to use for Azure Storage API requests
+        self.token_expiry: Optional[
+            datetime
+        ] = None  # the expiry time of the currentAzure AD token

         asyncio.create_task(self.periodic_flush())
         self.flush_lock = asyncio.Lock()
@@ -153,7 +153,6 @@ class AzureBlobStorageLogger(CustomBatchLogger):
        3. Flush the data
        """
        try:
-
            if self.azure_storage_account_key:
                await self.upload_to_azure_data_lake_with_azure_account_key(
                    payload=payload
@@ -4,7 +4,7 @@
 import copy
 import os
 from datetime import datetime
-from typing import Optional, Dict
+from typing import Dict, Optional

 import httpx
 from pydantic import BaseModel
@@ -19,7 +19,9 @@ from litellm.llms.custom_httpx.http_handler import (
 )
 from litellm.utils import print_verbose

-global_braintrust_http_handler = get_async_httpx_client(llm_provider=httpxSpecialProvider.LoggingCallback)
+global_braintrust_http_handler = get_async_httpx_client(
+    llm_provider=httpxSpecialProvider.LoggingCallback
+)
 global_braintrust_sync_http_handler = HTTPHandler()
 API_BASE = "https://api.braintrustdata.com/v1"

@@ -35,7 +37,9 @@ def get_utc_datetime():


 class BraintrustLogger(CustomLogger):
-    def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None) -> None:
+    def __init__(
+        self, api_key: Optional[str] = None, api_base: Optional[str] = None
+    ) -> None:
         super().__init__()
         self.validate_environment(api_key=api_key)
         self.api_base = api_base or API_BASE
@@ -45,7 +49,9 @@ class BraintrustLogger(CustomLogger):
             "Authorization": "Bearer " + self.api_key,
             "Content-Type": "application/json",
         }
-        self._project_id_cache: Dict[str, str] = {}  # Cache mapping project names to IDs
+        self._project_id_cache: Dict[
+            str, str
+        ] = {}  # Cache mapping project names to IDs

     def validate_environment(self, api_key: Optional[str]):
         """
@@ -71,7 +77,9 @@ class BraintrustLogger(CustomLogger):

         try:
             response = global_braintrust_sync_http_handler.post(
-                f"{self.api_base}/project", headers=self.headers, json={"name": project_name}
+                f"{self.api_base}/project",
+                headers=self.headers,
+                json={"name": project_name},
             )
             project_dict = response.json()
             project_id = project_dict["id"]
@@ -89,7 +97,9 @@ class BraintrustLogger(CustomLogger):

         try:
             response = await global_braintrust_http_handler.post(
-                f"{self.api_base}/project/register", headers=self.headers, json={"name": project_name}
+                f"{self.api_base}/project/register",
+                headers=self.headers,
+                json={"name": project_name},
             )
             project_dict = response.json()
             project_id = project_dict["id"]
@@ -116,15 +126,21 @@ class BraintrustLogger(CustomLogger):
         if metadata is None:
             metadata = {}

-        proxy_headers = litellm_params.get("proxy_server_request", {}).get("headers", {}) or {}
+        proxy_headers = (
+            litellm_params.get("proxy_server_request", {}).get("headers", {}) or {}
+        )

         for metadata_param_key in proxy_headers:
             if metadata_param_key.startswith("braintrust"):
                 trace_param_key = metadata_param_key.replace("braintrust", "", 1)
                 if trace_param_key in metadata:
-                    verbose_logger.warning(f"Overwriting Braintrust `{trace_param_key}` from request header")
+                    verbose_logger.warning(
+                        f"Overwriting Braintrust `{trace_param_key}` from request header"
+                    )
                 else:
-                    verbose_logger.debug(f"Found Braintrust `{trace_param_key}` in request header")
+                    verbose_logger.debug(
+                        f"Found Braintrust `{trace_param_key}` in request header"
+                    )
                 metadata[trace_param_key] = proxy_headers.get(metadata_param_key)

         return metadata
@@ -157,24 +173,35 @@ class BraintrustLogger(CustomLogger):
         output = None
         choices = []
         if response_obj is not None and (
-            kwargs.get("call_type", None) == "embedding" or isinstance(response_obj, litellm.EmbeddingResponse)
+            kwargs.get("call_type", None) == "embedding"
+            or isinstance(response_obj, litellm.EmbeddingResponse)
         ):
             output = None
-        elif response_obj is not None and isinstance(response_obj, litellm.ModelResponse):
+        elif response_obj is not None and isinstance(
+            response_obj, litellm.ModelResponse
+        ):
             output = response_obj["choices"][0]["message"].json()
             choices = response_obj["choices"]
-        elif response_obj is not None and isinstance(response_obj, litellm.TextCompletionResponse):
+        elif response_obj is not None and isinstance(
+            response_obj, litellm.TextCompletionResponse
+        ):
             output = response_obj.choices[0].text
             choices = response_obj.choices
-        elif response_obj is not None and isinstance(response_obj, litellm.ImageResponse):
+        elif response_obj is not None and isinstance(
+            response_obj, litellm.ImageResponse
+        ):
             output = response_obj["data"]

         litellm_params = kwargs.get("litellm_params", {})
-        metadata = litellm_params.get("metadata", {}) or {}  # if litellm_params['metadata'] == None
+        metadata = (
+            litellm_params.get("metadata", {}) or {}
+        )  # if litellm_params['metadata'] == None
         metadata = self.add_metadata_from_header(litellm_params, metadata)
         clean_metadata = {}
         try:
-            metadata = copy.deepcopy(metadata)  # Avoid modifying the original metadata
+            metadata = copy.deepcopy(
+                metadata
+            )  # Avoid modifying the original metadata
         except Exception:
             new_metadata = {}
             for key, value in metadata.items():
@@ -192,7 +219,9 @@ class BraintrustLogger(CustomLogger):
         project_id = metadata.get("project_id")
         if project_id is None:
             project_name = metadata.get("project_name")
-            project_id = self.get_project_id_sync(project_name) if project_name else None
+            project_id = (
+                self.get_project_id_sync(project_name) if project_name else None
+            )

         if project_id is None:
             if self.default_project_id is None:
@@ -234,7 +263,8 @@ class BraintrustLogger(CustomLogger):
                 "completion_tokens": usage_obj.completion_tokens,
                 "total_tokens": usage_obj.total_tokens,
                 "total_cost": cost,
-                "time_to_first_token": end_time.timestamp() - start_time.timestamp(),
+                "time_to_first_token": end_time.timestamp()
+                - start_time.timestamp(),
                 "start": start_time.timestamp(),
                 "end": end_time.timestamp(),
             }
@@ -255,7 +285,9 @@ class BraintrustLogger(CustomLogger):
             request_data["metrics"] = metrics

         try:
-            print_verbose(f"global_braintrust_sync_http_handler.post: {global_braintrust_sync_http_handler.post}")
+            print_verbose(
+                f"global_braintrust_sync_http_handler.post: {global_braintrust_sync_http_handler.post}"
+            )
             global_braintrust_sync_http_handler.post(
                 url=f"{self.api_base}/project_logs/{project_id}/insert",
                 json={"events": [request_data]},
@@ -276,20 +308,29 @@ class BraintrustLogger(CustomLogger):
         output = None
         choices = []
         if response_obj is not None and (
-            kwargs.get("call_type", None) == "embedding" or isinstance(response_obj, litellm.EmbeddingResponse)
+            kwargs.get("call_type", None) == "embedding"
+            or isinstance(response_obj, litellm.EmbeddingResponse)
         ):
             output = None
-        elif response_obj is not None and isinstance(response_obj, litellm.ModelResponse):
+        elif response_obj is not None and isinstance(
+            response_obj, litellm.ModelResponse
+        ):
             output = response_obj["choices"][0]["message"].json()
             choices = response_obj["choices"]
-        elif response_obj is not None and isinstance(response_obj, litellm.TextCompletionResponse):
+        elif response_obj is not None and isinstance(
+            response_obj, litellm.TextCompletionResponse
+        ):
             output = response_obj.choices[0].text
             choices = response_obj.choices
-        elif response_obj is not None and isinstance(response_obj, litellm.ImageResponse):
+        elif response_obj is not None and isinstance(
+            response_obj, litellm.ImageResponse
+        ):
             output = response_obj["data"]

         litellm_params = kwargs.get("litellm_params", {})
-        metadata = litellm_params.get("metadata", {}) or {}  # if litellm_params['metadata'] == None
+        metadata = (
+            litellm_params.get("metadata", {}) or {}
+        )  # if litellm_params['metadata'] == None
         metadata = self.add_metadata_from_header(litellm_params, metadata)
         clean_metadata = {}
         new_metadata = {}
@@ -313,7 +354,11 @@ class BraintrustLogger(CustomLogger):
         project_id = metadata.get("project_id")
         if project_id is None:
             project_name = metadata.get("project_name")
-            project_id = await self.get_project_id_async(project_name) if project_name else None
+            project_id = (
+                await self.get_project_id_async(project_name)
+                if project_name
+                else None
+            )

         if project_id is None:
             if self.default_project_id is None:
@@ -362,8 +407,14 @@ class BraintrustLogger(CustomLogger):
         api_call_start_time = kwargs.get("api_call_start_time")
         completion_start_time = kwargs.get("completion_start_time")

-        if api_call_start_time is not None and completion_start_time is not None:
-            metrics["time_to_first_token"] = completion_start_time.timestamp() - api_call_start_time.timestamp()
+        if (
+            api_call_start_time is not None
+            and completion_start_time is not None
+        ):
+            metrics["time_to_first_token"] = (
+                completion_start_time.timestamp()
+                - api_call_start_time.timestamp()
+            )

         request_data = {
             "id": litellm_call_id,
@@ -14,7 +14,6 @@ from litellm.integrations.custom_logger import CustomLogger

-
 class CustomBatchLogger(CustomLogger):

     def __init__(
         self,
         flush_lock: Optional[asyncio.Lock] = None,
@@ -7,7 +7,6 @@ from litellm.types.utils import StandardLoggingGuardrailInformation

-
 class CustomGuardrail(CustomLogger):

     def __init__(
         self,
         guardrail_name: Optional[str] = None,
@@ -31,7 +31,7 @@ from litellm.types.utils import (
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span

-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     Span = Any

@@ -233,7 +233,6 @@ class DataDogLogger(
         pass

     async def _log_async_event(self, kwargs, response_obj, start_time, end_time):
-
         dd_payload = self.create_datadog_logging_payload(
             kwargs=kwargs,
             response_obj=response_obj,
@@ -125,9 +125,9 @@ class GCSBucketBase(CustomBatchLogger):
         if kwargs is None:
             kwargs = {}

-        standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
-            kwargs.get("standard_callback_dynamic_params", None)
-        )
+        standard_callback_dynamic_params: Optional[
+            StandardCallbackDynamicParams
+        ] = kwargs.get("standard_callback_dynamic_params", None)

         bucket_name: str
         path_service_account: Optional[str]
@@ -70,13 +70,14 @@ class GcsPubSubLogger(CustomBatchLogger):
         """Construct authorization headers using Vertex AI auth"""
         from litellm import vertex_chat_completion

-        _auth_header, vertex_project = (
-            await vertex_chat_completion._ensure_access_token_async(
+        (
+            _auth_header,
+            vertex_project,
+        ) = await vertex_chat_completion._ensure_access_token_async(
            credentials=self.path_service_account_json,
            project_id=None,
            custom_llm_provider="vertex_ai",
        )
-        )

         auth_header, _ = vertex_chat_completion._get_token_and_url(
             model="pub-sub",
@@ -155,11 +155,7 @@ class HumanloopLogger(CustomLogger):
         prompt_id: str,
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
-    ) -> Tuple[
-        str,
-        List[AllMessageValues],
-        dict,
-    ]:
+    ) -> Tuple[str, List[AllMessageValues], dict,]:
         humanloop_api_key = dynamic_callback_params.get(
             "humanloop_api_key"
         ) or get_secret_str("HUMANLOOP_API_KEY")
@@ -471,9 +471,9 @@ class LangFuseLogger:
         # we clean out all extra litellm metadata params before logging
         clean_metadata: Dict[str, Any] = {}
         if prompt_management_metadata is not None:
-            clean_metadata["prompt_management_metadata"] = (
-                prompt_management_metadata
-            )
+            clean_metadata[
+                "prompt_management_metadata"
+            ] = prompt_management_metadata
         if isinstance(metadata, dict):
             for key, value in metadata.items():
                 # generate langfuse tags - Default Tags sent to Langfuse from LiteLLM Proxy
@@ -19,7 +19,6 @@ else:

-
 class LangFuseHandler:

     @staticmethod
     def get_langfuse_logger_for_request(
         standard_callback_dynamic_params: StandardCallbackDynamicParams,
@@ -87,7 +86,9 @@ class LangFuseHandler:
         if globalLangfuseLogger is not None:
             return globalLangfuseLogger

-        credentials_dict: Dict[str, Any] = (
+        credentials_dict: Dict[
+            str, Any
+        ] = (
             {}
         )  # the global langfuse logger uses Environment Variables, there are no dynamic credentials
         globalLangfuseLogger = in_memory_dynamic_logger_cache.get_cache(
@@ -172,11 +172,7 @@ class LangfusePromptManagement(LangFuseLogger, PromptManagementBase, CustomLogge
         prompt_id: str,
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
-    ) -> Tuple[
-        str,
-        List[AllMessageValues],
-        dict,
-    ]:
+    ) -> Tuple[str, List[AllMessageValues], dict,]:
         return self.get_chat_completion_prompt(
             model,
             messages,
@@ -75,7 +75,6 @@ class LangsmithLogger(CustomBatchLogger):
         langsmith_project: Optional[str] = None,
         langsmith_base_url: Optional[str] = None,
     ) -> LangsmithCredentialsObject:
-
         _credentials_api_key = langsmith_api_key or os.getenv("LANGSMITH_API_KEY")
         if _credentials_api_key is None:
             raise Exception(
@@ -443,9 +442,9 @@ class LangsmithLogger(CustomBatchLogger):

         Otherwise, use the default credentials.
         """
-        standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
-            kwargs.get("standard_callback_dynamic_params", None)
-        )
+        standard_callback_dynamic_params: Optional[
+            StandardCallbackDynamicParams
+        ] = kwargs.get("standard_callback_dynamic_params", None)
         if standard_callback_dynamic_params is not None:
             credentials = self.get_credentials_from_env(
                 langsmith_api_key=standard_callback_dynamic_params.get(
@@ -481,7 +480,6 @@ class LangsmithLogger(CustomBatchLogger):
         asyncio.run(self.async_send_batch())

     def get_run_by_id(self, run_id):
-
         langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"]

         langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"]

@@ -1,12 +1,12 @@
 import json
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Union

 from litellm.proxy._types import SpanAttributes

 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span

-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     Span = Any

@@ -20,7 +20,6 @@ def parse_tool_calls(tool_calls):
         return None

     def clean_tool_call(tool_call):
-
         serialized = {
             "type": tool_call.type,
             "id": tool_call.id,
@@ -36,7 +35,6 @@ def parse_tool_calls(tool_calls):


 def parse_messages(input):
-
     if input is None:
         return None

@@ -48,14 +48,17 @@ class MlflowLogger(CustomLogger):

     def _extract_and_set_chat_attributes(self, span, kwargs, response_obj):
         try:
-            from mlflow.tracing.utils import set_span_chat_messages, set_span_chat_tools
+            from mlflow.tracing.utils import set_span_chat_messages  # type: ignore
+            from mlflow.tracing.utils import set_span_chat_tools  # type: ignore
         except ImportError:
             return

         inputs = self._construct_input(kwargs)
         input_messages = inputs.get("messages", [])
-        output_messages = [c.message.model_dump(exclude_none=True)
-                           for c in getattr(response_obj, "choices", [])]
+        output_messages = [
+            c.message.model_dump(exclude_none=True)
+            for c in getattr(response_obj, "choices", [])
+        ]
         if messages := [*input_messages, *output_messages]:
             set_span_chat_messages(span, messages)
         if tools := inputs.get("tools"):
@@ -1,7 +1,7 @@
 import os
 from dataclasses import dataclass
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast

 import litellm
 from litellm._logging import verbose_logger
@@ -23,10 +23,10 @@ if TYPE_CHECKING:
     )
     from litellm.proxy.proxy_server import UserAPIKeyAuth as _UserAPIKeyAuth

-    Span = _Span
-    SpanExporter = _SpanExporter
-    UserAPIKeyAuth = _UserAPIKeyAuth
-    ManagementEndpointLoggingPayload = _ManagementEndpointLoggingPayload
+    Span = Union[_Span, Any]
+    SpanExporter = Union[_SpanExporter, Any]
+    UserAPIKeyAuth = Union[_UserAPIKeyAuth, Any]
+    ManagementEndpointLoggingPayload = Union[_ManagementEndpointLoggingPayload, Any]
 else:
     Span = Any
     SpanExporter = Any
@@ -46,7 +46,6 @@ LITELLM_REQUEST_SPAN_NAME = "litellm_request"

 @dataclass
 class OpenTelemetryConfig:
-
     exporter: Union[str, SpanExporter] = "console"
     endpoint: Optional[str] = None
     headers: Optional[str] = None
@@ -154,7 +153,6 @@ class OpenTelemetry(CustomLogger):
         end_time: Optional[Union[datetime, float]] = None,
         event_metadata: Optional[dict] = None,
     ):
-
         from opentelemetry import trace
         from opentelemetry.trace import Status, StatusCode

@@ -215,7 +213,6 @@ class OpenTelemetry(CustomLogger):
         end_time: Optional[Union[float, datetime]] = None,
         event_metadata: Optional[dict] = None,
     ):
-
         from opentelemetry import trace
         from opentelemetry.trace import Status, StatusCode

@@ -353,9 +350,9 @@ class OpenTelemetry(CustomLogger):
         """
         from opentelemetry import trace

-        standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
-            kwargs.get("standard_callback_dynamic_params")
-        )
+        standard_callback_dynamic_params: Optional[
+            StandardCallbackDynamicParams
+        ] = kwargs.get("standard_callback_dynamic_params")
         if not standard_callback_dynamic_params:
             return

@@ -722,7 +719,6 @@ class OpenTelemetry(CustomLogger):
                 span.set_attribute(key, primitive_value)

     def set_raw_request_attributes(self, span: Span, kwargs, response_obj):
-
         kwargs.get("optional_params", {})
         litellm_params = kwargs.get("litellm_params", {}) or {}
         custom_llm_provider = litellm_params.get("custom_llm_provider", "Unknown")
@@ -843,12 +839,14 @@ class OpenTelemetry(CustomLogger):
                 headers=dynamic_headers or self.OTEL_HEADERS
             )

-        if isinstance(self.OTEL_EXPORTER, SpanExporter):
+        if hasattr(
+            self.OTEL_EXPORTER, "export"
+        ):  # Check if it has the export method that SpanExporter requires
             verbose_logger.debug(
                 "OpenTelemetry: intiializing SpanExporter. Value of OTEL_EXPORTER: %s",
                 self.OTEL_EXPORTER,
             )
-            return SimpleSpanProcessor(self.OTEL_EXPORTER)
+            return SimpleSpanProcessor(cast(SpanExporter, self.OTEL_EXPORTER))

         if self.OTEL_EXPORTER == "console":
             verbose_logger.debug(
@@ -907,7 +905,6 @@ class OpenTelemetry(CustomLogger):
         logging_payload: ManagementEndpointLoggingPayload,
         parent_otel_span: Optional[Span] = None,
     ):
-
         from opentelemetry import trace
         from opentelemetry.trace import Status, StatusCode

@@ -961,7 +958,6 @@ class OpenTelemetry(CustomLogger):
         logging_payload: ManagementEndpointLoggingPayload,
         parent_otel_span: Optional[Span] = None,
     ):
-
         from opentelemetry import trace
         from opentelemetry.trace import Status, StatusCode

@@ -185,7 +185,6 @@ class OpikLogger(CustomBatchLogger):
     def _create_opik_payload(  # noqa: PLR0915
         self, kwargs, response_obj, start_time, end_time
     ) -> List[Dict]:
-
         # Get metadata
         _litellm_params = kwargs.get("litellm_params", {}) or {}
         litellm_params_metadata = _litellm_params.get("metadata", {}) or {}
@@ -988,9 +988,9 @@ class PrometheusLogger(CustomLogger):
     ):
         try:
             verbose_logger.debug("setting remaining tokens requests metric")
-            standard_logging_payload: Optional[StandardLoggingPayload] = (
-                request_kwargs.get("standard_logging_object")
-            )
+            standard_logging_payload: Optional[
+                StandardLoggingPayload
+            ] = request_kwargs.get("standard_logging_object")

             if standard_logging_payload is None:
                 return

@@ -14,7 +14,6 @@ class PromptManagementClient(TypedDict):

-
 class PromptManagementBase(ABC):

     @property
     @abstractmethod
     def integration_name(self) -> str:
@@ -83,11 +82,7 @@ class PromptManagementBase(ABC):
         prompt_id: str,
         prompt_variables: Optional[dict],
         dynamic_callback_params: StandardCallbackDynamicParams,
-    ) -> Tuple[
-        str,
-        List[AllMessageValues],
-        dict,
-    ]:
+    ) -> Tuple[str, List[AllMessageValues], dict,]:
         if not self.should_run_prompt_management(
             prompt_id=prompt_id, dynamic_callback_params=dynamic_callback_params
         ):
@@ -38,7 +38,7 @@ class S3Logger:
         if litellm.s3_callback_params is not None:
             # read in .env variables - example os.environ/AWS_BUCKET_NAME
             for key, value in litellm.s3_callback_params.items():
-                if type(value) is str and value.startswith("os.environ/"):
+                if isinstance(value, str) and value.startswith("os.environ/"):
                     litellm.s3_callback_params[key] = litellm.get_secret(value)
             # now set s3 params from litellm.s3_logger_params
             s3_bucket_name = litellm.s3_callback_params.get("s3_bucket_name")
@@ -21,11 +21,11 @@ try:
         # contains a (known) object attribute
         object: Literal["chat.completion", "edit", "text_completion"]

-        def __getitem__(self, key: K) -> V: ...  # noqa
+        def __getitem__(self, key: K) -> V:
+            ...  # noqa

-        def get(  # noqa
-            self, key: K, default: Optional[V] = None
-        ) -> Optional[V]: ...  # pragma: no cover
+        def get(self, key: K, default: Optional[V] = None) -> Optional[V]:  # noqa
+            ...  # pragma: no cover

     class OpenAIRequestResponseResolver:
         def __call__(
@@ -10,7 +10,7 @@ from litellm.types.llms.openai import AllMessageValues
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span

-    Span = _Span
+    Span = Union[_Span, Any]
 else:
     Span = Any

@@ -11,7 +11,9 @@ except (ImportError, AttributeError):
     # Old way to access resources, which setuptools deprecated some time ago
     import pkg_resources  # type: ignore

-    filename = pkg_resources.resource_filename(__name__, "litellm_core_utils/tokenizers")
+    filename = pkg_resources.resource_filename(
+        __name__, "litellm_core_utils/tokenizers"
+    )

 os.environ["TIKTOKEN_CACHE_DIR"] = os.getenv(
     "CUSTOM_TIKTOKEN_CACHE_DIR", filename
@@ -239,9 +239,9 @@ class Logging(LiteLLMLoggingBaseClass):
         self.litellm_trace_id = litellm_trace_id
         self.function_id = function_id
         self.streaming_chunks: List[Any] = []  # for generating complete stream response
-        self.sync_streaming_chunks: List[Any] = (
-            []
-        )  # for generating complete stream response
+        self.sync_streaming_chunks: List[
+            Any
+        ] = []  # for generating complete stream response
         self.log_raw_request_response = log_raw_request_response

         # Initialize dynamic callbacks
@@ -452,11 +452,13 @@ class Logging(LiteLLMLoggingBaseClass):
         prompt_id: str,
         prompt_variables: Optional[dict],
     ) -> Tuple[str, List[AllMessageValues], dict]:
-
         custom_logger = self.get_custom_logger_for_prompt_management(model)
         if custom_logger:
-            model, messages, non_default_params = (
-                custom_logger.get_chat_completion_prompt(
+            (
+                model,
+                messages,
+                non_default_params,
+            ) = custom_logger.get_chat_completion_prompt(
                 model=model,
                 messages=messages,
                 non_default_params=non_default_params,
@@ -464,7 +466,6 @@ class Logging(LiteLLMLoggingBaseClass):
                 prompt_variables=prompt_variables,
                 dynamic_callback_params=self.standard_callback_dynamic_params,
             )
-            )
         self.messages = messages
         return model, messages, non_default_params

@@ -541,12 +542,11 @@ class Logging(LiteLLMLoggingBaseClass):
             model
         ):  # if model name was changes pre-call, overwrite the initial model call name with the new one
             self.model_call_details["model"] = model
-            self.model_call_details["litellm_params"]["api_base"] = (
-                self._get_masked_api_base(additional_args.get("api_base", ""))
-            )
+            self.model_call_details["litellm_params"][
+                "api_base"
+            ] = self._get_masked_api_base(additional_args.get("api_base", ""))

     def pre_call(self, input, api_key, model=None, additional_args={}):  # noqa: PLR0915
-
         # Log the exact input to the LLM API
         litellm.error_logs["PRE_CALL"] = locals()
         try:
@@ -568,19 +568,16 @@ class Logging(LiteLLMLoggingBaseClass):
                 self.log_raw_request_response is True
                 or log_raw_request_response is True
             ):
-
                 _litellm_params = self.model_call_details.get("litellm_params", {})
                 _metadata = _litellm_params.get("metadata", {}) or {}
                 try:
                     # [Non-blocking Extra Debug Information in metadata]
                     if turn_off_message_logging is True:
-                        _metadata["raw_request"] = (
-                            "redacted by litellm. \
+                        _metadata[
+                            "raw_request"
+                        ] = "redacted by litellm. \
                             'litellm.turn_off_message_logging=True'"
-                        )
                     else:
-
                         curl_command = self._get_request_curl_command(
                             api_base=additional_args.get("api_base", ""),
                             headers=additional_args.get("headers", {}),
@@ -590,8 +587,9 @@ class Logging(LiteLLMLoggingBaseClass):

                         _metadata["raw_request"] = str(curl_command)
                         # split up, so it's easier to parse in the UI
-                        self.model_call_details["raw_request_typed_dict"] = (
-                            RawRequestTypedDict(
+                        self.model_call_details[
+                            "raw_request_typed_dict"
+                        ] = RawRequestTypedDict(
                             raw_request_api_base=str(
                                 additional_args.get("api_base") or ""
                             ),
@@ -604,20 +602,19 @@ class Logging(LiteLLMLoggingBaseClass):
                             ),
                             error=None,
                         )
-                        )
                 except Exception as e:
-                    self.model_call_details["raw_request_typed_dict"] = (
-                        RawRequestTypedDict(
+                    self.model_call_details[
+                        "raw_request_typed_dict"
+                    ] = RawRequestTypedDict(
                         error=str(e),
                     )
-                    )
                     traceback.print_exc()
-                    _metadata["raw_request"] = (
-                        "Unable to Log \
+                    _metadata[
+                        "raw_request"
+                    ] = "Unable to Log \
                         raw request: {}".format(
                         str(e)
                     )
-                    )
         if self.logger_fn and callable(self.logger_fn):
             try:
                 self.logger_fn(
@@ -941,9 +938,9 @@ class Logging(LiteLLMLoggingBaseClass):
                 verbose_logger.debug(
                     f"response_cost_failure_debug_information: {debug_info}"
                 )
-                self.model_call_details["response_cost_failure_debug_information"] = (
-                    debug_info
-                )
+                self.model_call_details[
+                    "response_cost_failure_debug_information"
+                ] = debug_info
                 return None

             try:
@@ -968,9 +965,9 @@ class Logging(LiteLLMLoggingBaseClass):
                 verbose_logger.debug(
                     f"response_cost_failure_debug_information: {debug_info}"
                 )
-                self.model_call_details["response_cost_failure_debug_information"] = (
-                    debug_info
-                )
+                self.model_call_details[
+                    "response_cost_failure_debug_information"
+                ] = debug_info

                 return None

@@ -995,7 +992,6 @@ class Logging(LiteLLMLoggingBaseClass):
     def should_run_callback(
         self, callback: litellm.CALLBACK_TYPES, litellm_params: dict, event_hook: str
     ) -> bool:
-
         if litellm.global_disable_no_log_param:
             return True

@@ -1027,9 +1023,9 @@ class Logging(LiteLLMLoggingBaseClass):
         end_time = datetime.datetime.now()
         if self.completion_start_time is None:
             self.completion_start_time = end_time
-            self.model_call_details["completion_start_time"] = (
-                self.completion_start_time
-            )
+            self.model_call_details[
+                "completion_start_time"
+            ] = self.completion_start_time
         self.model_call_details["log_event_type"] = "successful_api_call"
         self.model_call_details["end_time"] = end_time
         self.model_call_details["cache_hit"] = cache_hit
@@ -1083,13 +1079,14 @@ class Logging(LiteLLMLoggingBaseClass):
                     "response_cost"
                 ]
             else:
-                self.model_call_details["response_cost"] = (
-                    self._response_cost_calculator(result=result)
-                )
+                self.model_call_details[
+                    "response_cost"
+                ] = self._response_cost_calculator(result=result)
             ## STANDARDIZED LOGGING PAYLOAD

-            self.model_call_details["standard_logging_object"] = (
-                get_standard_logging_object_payload(
+            self.model_call_details[
+                "standard_logging_object"
+            ] = get_standard_logging_object_payload(
                 kwargs=self.model_call_details,
                 init_response_obj=result,
                 start_time=start_time,
@@ -1098,11 +1095,11 @@ class Logging(LiteLLMLoggingBaseClass):
                 status="success",
                 standard_built_in_tools_params=self.standard_built_in_tools_params,
             )
-            )
         elif isinstance(result, dict):  # pass-through endpoints
             ## STANDARDIZED LOGGING PAYLOAD
-            self.model_call_details["standard_logging_object"] = (
-                get_standard_logging_object_payload(
+            self.model_call_details[
+                "standard_logging_object"
+            ] = get_standard_logging_object_payload(
                 kwargs=self.model_call_details,
                 init_response_obj=result,
                 start_time=start_time,
@@ -1111,11 +1108,10 @@ class Logging(LiteLLMLoggingBaseClass):
                 status="success",
                 standard_built_in_tools_params=self.standard_built_in_tools_params,
             )
-            )
         elif standard_logging_object is not None:
-            self.model_call_details["standard_logging_object"] = (
-                standard_logging_object
-            )
+            self.model_call_details[
+                "standard_logging_object"
+            ] = standard_logging_object
         else:  # streaming chunks + image gen.
             self.model_call_details["response_cost"] = None

@@ -1154,7 +1150,6 @@ class Logging(LiteLLMLoggingBaseClass):
             standard_logging_object=kwargs.get("standard_logging_object", None),
         )
         try:
-
             ## BUILD COMPLETE STREAMED RESPONSE
             complete_streaming_response: Optional[
                 Union[ModelResponse, TextCompletionResponse, ResponsesAPIResponse]
@@ -1172,15 +1167,16 @@ class Logging(LiteLLMLoggingBaseClass):
                 verbose_logger.debug(
                     "Logging Details LiteLLM-Success Call streaming complete"
                 )
-                self.model_call_details["complete_streaming_response"] = (
-                    complete_streaming_response
-                )
-                self.model_call_details["response_cost"] = (
-                    self._response_cost_calculator(result=complete_streaming_response)
-                )
+                self.model_call_details[
+                    "complete_streaming_response"
+                ] = complete_streaming_response
+                self.model_call_details[
+                    "response_cost"
+                ] = self._response_cost_calculator(result=complete_streaming_response)
                 ## STANDARDIZED LOGGING PAYLOAD
-                self.model_call_details["standard_logging_object"] = (
-                    get_standard_logging_object_payload(
+                self.model_call_details[
+                    "standard_logging_object"
+                ] = get_standard_logging_object_payload(
                     kwargs=self.model_call_details,
                     init_response_obj=complete_streaming_response,
                     start_time=start_time,
@@ -1189,7 +1185,6 @@ class Logging(LiteLLMLoggingBaseClass):
                     status="success",
                     standard_built_in_tools_params=self.standard_built_in_tools_params,
                 )
-                )
             callbacks = self.get_combined_callback_list(
                 dynamic_success_callbacks=self.dynamic_success_callbacks,
                 global_callbacks=litellm.success_callback,
@@ -1207,7 +1202,6 @@ class Logging(LiteLLMLoggingBaseClass):
             ## LOGGING HOOK ##
             for callback in callbacks:
                 if isinstance(callback, CustomLogger):
-
                     self.model_call_details, result = callback.logging_hook(
                         kwargs=self.model_call_details,
                         result=result,
@@ -1538,11 +1532,11 @@ class Logging(LiteLLMLoggingBaseClass):
                         )
                     else:
                         if self.stream and complete_streaming_response:
-                            self.model_call_details["complete_response"] = (
-                                self.model_call_details.get(
+                            self.model_call_details[
+                                "complete_response"
+                            ] = self.model_call_details.get(
                                 "complete_streaming_response", {}
                             )
-                            )
                             result = self.model_call_details["complete_response"]
                         openMeterLogger.log_success_event(
                             kwargs=self.model_call_details,
@@ -1581,11 +1575,11 @@ class Logging(LiteLLMLoggingBaseClass):
                         )
                     else:
                         if self.stream and complete_streaming_response:
-                            self.model_call_details["complete_response"] = (
-                                self.model_call_details.get(
+                            self.model_call_details[
+                                "complete_response"
+                            ] = self.model_call_details.get(
                                 "complete_streaming_response", {}
                             )
-                            )
                             result = self.model_call_details["complete_response"]

                         callback.log_success_event(
@@ -1659,7 +1653,6 @@ class Logging(LiteLLMLoggingBaseClass):
         if self.call_type == CallTypes.aretrieve_batch.value and isinstance(
             result, LiteLLMBatch
         ):
-
             response_cost, batch_usage, batch_models = await _handle_completed_batch(
                 batch=result, custom_llm_provider=self.custom_llm_provider
             )
@@ -1692,9 +1685,9 @@ class Logging(LiteLLMLoggingBaseClass):
         if complete_streaming_response is not None:
             print_verbose("Async success callbacks: Got a complete streaming response")

-            self.model_call_details["async_complete_streaming_response"] = (
-                complete_streaming_response
-            )
+            self.model_call_details[
+                "async_complete_streaming_response"
+            ] = complete_streaming_response
             try:
                 if self.model_call_details.get("cache_hit", False) is True:
                     self.model_call_details["response_cost"] = 0.0
@@ -1704,11 +1697,11 @@ class Logging(LiteLLMLoggingBaseClass):
                         model_call_details=self.model_call_details
                     )
                     # base_model defaults to None if not set on model_info
-                    self.model_call_details["response_cost"] = (
-                        self._response_cost_calculator(
+                    self.model_call_details[
+                        "response_cost"
+                    ] = self._response_cost_calculator(
                         result=complete_streaming_response
                     )
-                    )

                 verbose_logger.debug(
                     f"Model={self.model}; cost={self.model_call_details['response_cost']}"
@@ -1720,8 +1713,9 @@ class Logging(LiteLLMLoggingBaseClass):
                 self.model_call_details["response_cost"] = None

             ## STANDARDIZED LOGGING PAYLOAD
-            self.model_call_details["standard_logging_object"] = (
-                get_standard_logging_object_payload(
+            self.model_call_details[
+                "standard_logging_object"
+            ] = get_standard_logging_object_payload(
                 kwargs=self.model_call_details,
                 init_response_obj=complete_streaming_response,
                 start_time=start_time,
@@ -1730,7 +1724,6 @@ class Logging(LiteLLMLoggingBaseClass):
                 status="success",
                 standard_built_in_tools_params=self.standard_built_in_tools_params,
             )
-            )
             callbacks = self.get_combined_callback_list(
                 dynamic_success_callbacks=self.dynamic_async_success_callbacks,
                 global_callbacks=litellm._async_success_callback,
@@ -1935,8 +1928,9 @@ class Logging(LiteLLMLoggingBaseClass):

         ## STANDARDIZED LOGGING PAYLOAD

-        self.model_call_details["standard_logging_object"] = (
-            get_standard_logging_object_payload(
+        self.model_call_details[
+            "standard_logging_object"
+        ] = get_standard_logging_object_payload(
             kwargs=self.model_call_details,
             init_response_obj={},
             start_time=start_time,
@@ -1947,7 +1941,6 @@ class Logging(LiteLLMLoggingBaseClass):
             original_exception=exception,
             standard_built_in_tools_params=self.standard_built_in_tools_params,
         )
-        )
         return start_time, end_time

     async def special_failure_handlers(self, exception: Exception):
@@ -2084,7 +2077,6 @@ class Logging(LiteLLMLoggingBaseClass):
                     )
                     is not True
                 ):  # custom logger class
-
                     callback.log_failure_event(
                         start_time=start_time,
                         end_time=end_time,
@@ -2713,9 +2705,9 @@ def _init_custom_logger_compatible_class(  # noqa: PLR0915
                 endpoint=arize_config.endpoint,
             )

-            os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
-                f"space_key={arize_config.space_key},api_key={arize_config.api_key}"
-            )
+            os.environ[
+                "OTEL_EXPORTER_OTLP_TRACES_HEADERS"
+            ] = f"space_key={arize_config.space_key},api_key={arize_config.api_key}"
             for callback in _in_memory_loggers:
                 if (
                     isinstance(callback, ArizeLogger)
@@ -2739,9 +2731,9 @@ def _init_custom_logger_compatible_class(  # noqa: PLR0915

             # auth can be disabled on local deployments of arize phoenix
             if arize_phoenix_config.otlp_auth_headers is not None:
-                os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
-                    arize_phoenix_config.otlp_auth_headers
-                )
+                os.environ[
+                    "OTEL_EXPORTER_OTLP_TRACES_HEADERS"
+                ] = arize_phoenix_config.otlp_auth_headers

             for callback in _in_memory_loggers:
                 if (
@@ -2832,9 +2824,9 @@ def _init_custom_logger_compatible_class(  # noqa: PLR0915
                 exporter="otlp_http",
                 endpoint="https://langtrace.ai/api/trace",
             )
-            os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
-                f"api_key={os.getenv('LANGTRACE_API_KEY')}"
-            )
+            os.environ[
+                "OTEL_EXPORTER_OTLP_TRACES_HEADERS"
+            ] = f"api_key={os.getenv('LANGTRACE_API_KEY')}"
             for callback in _in_memory_loggers:
                 if (
                     isinstance(callback, OpenTelemetry)
@@ -3223,7 +3215,6 @@ class StandardLoggingPayloadSetup:
         custom_llm_provider: Optional[str],
         init_response_obj: Union[Any, BaseModel, dict],
     ) -> StandardLoggingModelInformation:
-
         model_cost_name = _select_model_name_for_cost_calc(
             model=None,
             completion_response=init_response_obj,  # type: ignore
@@ -3286,7 +3277,6 @@ class StandardLoggingPayloadSetup:
     def get_additional_headers(
         additiona_headers: Optional[dict],
     ) -> Optional[StandardLoggingAdditionalHeaders]:
-
         if additiona_headers is None:
             return None

@@ -3322,11 +3312,11 @@ class StandardLoggingPayloadSetup:
         for key in StandardLoggingHiddenParams.__annotations__.keys():
             if key in hidden_params:
                 if key == "additional_headers":
-                    clean_hidden_params["additional_headers"] = (
-                        StandardLoggingPayloadSetup.get_additional_headers(
+                    clean_hidden_params[
+                        "additional_headers"
+                    ] = StandardLoggingPayloadSetup.get_additional_headers(
                         hidden_params[key]
                     )
-                    )
                 else:
                     clean_hidden_params[key] = hidden_params[key]  # type: ignore
         return clean_hidden_params
@@ -3463,13 +3453,15 @@ def get_standard_logging_object_payload(
     )

     # cleanup timestamps
-    start_time_float, end_time_float, completion_start_time_float = (
-        StandardLoggingPayloadSetup.cleanup_timestamps(
+    (
+        start_time_float,
+        end_time_float,
+        completion_start_time_float,
+    ) = StandardLoggingPayloadSetup.cleanup_timestamps(
         start_time=start_time,
         end_time=end_time,
         completion_start_time=completion_start_time,
     )
-    )
     response_time = StandardLoggingPayloadSetup.get_response_time(
         start_time_float=start_time_float,
         end_time_float=end_time_float,
@@ -3495,7 +3487,6 @@ def get_standard_logging_object_payload(

     saved_cache_cost: float = 0.0
     if cache_hit is True:
-
         id = f"{id}_cache_hit{time.time()}"  # do not duplicate the request id
         saved_cache_cost = (
             logging_obj._response_cost_calculator(
@@ -3658,9 +3649,9 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
     ):
         for k, v in metadata["user_api_key_metadata"].items():
             if k == "logging":  # prevent logging user logging keys
-                cleaned_user_api_key_metadata[k] = (
-                    "scrubbed_by_litellm_for_sensitive_keys"
-                )
+                cleaned_user_api_key_metadata[
+                    k
+                ] = "scrubbed_by_litellm_for_sensitive_keys"
             else:
                 cleaned_user_api_key_metadata[k] = v

@ -258,14 +258,12 @@ def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[s
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMResponseObjectHandler:
|
class LiteLLMResponseObjectHandler:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def convert_to_image_response(
|
def convert_to_image_response(
|
||||||
response_object: dict,
|
response_object: dict,
|
||||||
model_response_object: Optional[ImageResponse] = None,
|
model_response_object: Optional[ImageResponse] = None,
|
||||||
hidden_params: Optional[dict] = None,
|
hidden_params: Optional[dict] = None,
|
||||||
) -> ImageResponse:
|
) -> ImageResponse:
|
||||||
|
|
||||||
response_object.update({"hidden_params": hidden_params})
|
response_object.update({"hidden_params": hidden_params})
|
||||||
|
|
||||||
if model_response_object is None:
|
if model_response_object is None:
|
||||||
|
@ -481,9 +479,9 @@ def convert_to_model_response_object( # noqa: PLR0915
|
||||||
provider_specific_fields["thinking_blocks"] = thinking_blocks
|
provider_specific_fields["thinking_blocks"] = thinking_blocks
|
||||||
|
|
||||||
if reasoning_content:
|
if reasoning_content:
|
||||||
provider_specific_fields["reasoning_content"] = (
|
provider_specific_fields[
|
||||||
reasoning_content
|
"reasoning_content"
|
||||||
)
|
] = reasoning_content
|
||||||
|
|
||||||
message = Message(
|
message = Message(
|
||||||
content=content,
|
content=content,
|
||||||
|
|
|
@ -17,7 +17,6 @@ from litellm.types.rerank import RerankRequest
|
||||||
|
|
||||||
|
|
||||||
class ModelParamHelper:
|
class ModelParamHelper:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_standard_logging_model_parameters(
|
def get_standard_logging_model_parameters(
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
|
|
|
@ -257,7 +257,6 @@ def _insert_assistant_continue_message(
|
||||||
and message.get("role") == "user" # Current is user
|
and message.get("role") == "user" # Current is user
|
||||||
and messages[i + 1].get("role") == "user"
|
and messages[i + 1].get("role") == "user"
|
||||||
): # Next is user
|
): # Next is user
|
||||||
|
|
||||||
# Insert assistant message
|
# Insert assistant message
|
||||||
continue_message = (
|
continue_message = (
|
||||||
assistant_continue_message or DEFAULT_ASSISTANT_CONTINUE_MESSAGE
|
assistant_continue_message or DEFAULT_ASSISTANT_CONTINUE_MESSAGE
|
||||||
|
|
|
@ -1042,11 +1042,11 @@ def convert_to_gemini_tool_call_invoke(
|
||||||
if tool_calls is not None:
|
if tool_calls is not None:
|
||||||
for tool in tool_calls:
|
for tool in tool_calls:
|
||||||
if "function" in tool:
|
if "function" in tool:
|
||||||
gemini_function_call: Optional[VertexFunctionCall] = (
|
gemini_function_call: Optional[
|
||||||
_gemini_tool_call_invoke_helper(
|
VertexFunctionCall
|
||||||
|
] = _gemini_tool_call_invoke_helper(
|
||||||
function_call_params=tool["function"]
|
function_call_params=tool["function"]
|
||||||
)
|
)
|
||||||
)
|
|
||||||
if gemini_function_call is not None:
|
if gemini_function_call is not None:
|
||||||
_parts_list.append(
|
_parts_list.append(
|
||||||
VertexPartType(function_call=gemini_function_call)
|
VertexPartType(function_call=gemini_function_call)
|
||||||
|
@ -1432,9 +1432,9 @@ def anthropic_messages_pt( # noqa: PLR0915
|
||||||
)
|
)
|
||||||
|
|
||||||
if "cache_control" in _content_element:
|
if "cache_control" in _content_element:
|
||||||
_anthropic_content_element["cache_control"] = (
|
_anthropic_content_element[
|
||||||
_content_element["cache_control"]
|
"cache_control"
|
||||||
)
|
] = _content_element["cache_control"]
|
||||||
user_content.append(_anthropic_content_element)
|
user_content.append(_anthropic_content_element)
|
||||||
elif m.get("type", "") == "text":
|
elif m.get("type", "") == "text":
|
||||||
m = cast(ChatCompletionTextObject, m)
|
m = cast(ChatCompletionTextObject, m)
|
||||||
|
@ -1466,9 +1466,9 @@ def anthropic_messages_pt( # noqa: PLR0915
|
||||||
)
|
)
|
||||||
|
|
||||||
if "cache_control" in _content_element:
|
if "cache_control" in _content_element:
|
||||||
_anthropic_content_text_element["cache_control"] = (
|
_anthropic_content_text_element[
|
||||||
_content_element["cache_control"]
|
"cache_control"
|
||||||
)
|
] = _content_element["cache_control"]
|
||||||
|
|
||||||
user_content.append(_anthropic_content_text_element)
|
user_content.append(_anthropic_content_text_element)
|
||||||
|
|
||||||
|
@ -1533,7 +1533,6 @@ def anthropic_messages_pt( # noqa: PLR0915
|
||||||
"content"
|
"content"
|
||||||
] # don't pass empty text blocks. anthropic api raises errors.
|
] # don't pass empty text blocks. anthropic api raises errors.
|
||||||
):
|
):
|
||||||
|
|
||||||
_anthropic_text_content_element = AnthropicMessagesTextParam(
|
_anthropic_text_content_element = AnthropicMessagesTextParam(
|
||||||
type="text",
|
type="text",
|
||||||
text=assistant_content_block["content"],
|
text=assistant_content_block["content"],
|
||||||
|
@ -1569,7 +1568,6 @@ def anthropic_messages_pt( # noqa: PLR0915
|
||||||
msg_i += 1
|
msg_i += 1
|
||||||
|
|
||||||
if assistant_content:
|
if assistant_content:
|
||||||
|
|
||||||
new_messages.append({"role": "assistant", "content": assistant_content})
|
new_messages.append({"role": "assistant", "content": assistant_content})
|
||||||
|
|
||||||
if msg_i == init_msg_i: # prevent infinite loops
|
if msg_i == init_msg_i: # prevent infinite loops
|
||||||
|
@ -2245,7 +2243,6 @@ class BedrockImageProcessor:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def get_image_details_async(image_url) -> Tuple[str, str]:
|
async def get_image_details_async(image_url) -> Tuple[str, str]:
|
||||||
try:
|
try:
|
||||||
|
|
||||||
client = get_async_httpx_client(
|
client = get_async_httpx_client(
|
||||||
llm_provider=httpxSpecialProvider.PromptFactory,
|
llm_provider=httpxSpecialProvider.PromptFactory,
|
||||||
params={"concurrent_limit": 1},
|
params={"concurrent_limit": 1},
|
||||||
|
@ -2612,7 +2609,6 @@ def get_user_message_block_or_continue_message(
|
||||||
for item in modified_content_block:
|
for item in modified_content_block:
|
||||||
# Check if the list is empty
|
# Check if the list is empty
|
||||||
if item["type"] == "text":
|
if item["type"] == "text":
|
||||||
|
|
||||||
if not item["text"].strip():
|
if not item["text"].strip():
|
||||||
# Replace empty text with continue message
|
# Replace empty text with continue message
|
||||||
_user_continue_message = ChatCompletionUserMessage(
|
_user_continue_message = ChatCompletionUserMessage(
|
||||||
|
@ -3207,7 +3203,6 @@ def _bedrock_converse_messages_pt( # noqa: PLR0915
|
||||||
assistant_content: List[BedrockContentBlock] = []
|
assistant_content: List[BedrockContentBlock] = []
|
||||||
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
|
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
|
||||||
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
|
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
|
||||||
|
|
||||||
assistant_message_block = get_assistant_message_block_or_continue_message(
|
assistant_message_block = get_assistant_message_block_or_continue_message(
|
||||||
message=messages[msg_i],
|
message=messages[msg_i],
|
||||||
assistant_continue_message=assistant_continue_message,
|
assistant_continue_message=assistant_continue_message,
|
||||||
|
@ -3410,7 +3405,6 @@ def response_schema_prompt(model: str, response_schema: dict) -> str:
|
||||||
{"role": "user", "content": "{}".format(response_schema)}
|
{"role": "user", "content": "{}".format(response_schema)}
|
||||||
]
|
]
|
||||||
if f"{model}/response_schema_prompt" in litellm.custom_prompt_dict:
|
if f"{model}/response_schema_prompt" in litellm.custom_prompt_dict:
|
||||||
|
|
||||||
custom_prompt_details = litellm.custom_prompt_dict[
|
custom_prompt_details = litellm.custom_prompt_dict[
|
||||||
f"{model}/response_schema_prompt"
|
f"{model}/response_schema_prompt"
|
||||||
] # allow user to define custom response schema prompt by model
|
] # allow user to define custom response schema prompt by model
|
||||||
|
|
|
@@ -122,7 +122,6 @@ class RealTimeStreaming:
pass

async def bidirectional_forward(self):

forward_task = asyncio.create_task(self.backend_to_client_send_messages())
try:
await self.client_ack_messages()
@@ -135,9 +135,9 @@ def _get_turn_off_message_logging_from_dynamic_params(

handles boolean and string values of `turn_off_message_logging`
"""
-standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
-model_call_details.get("standard_callback_dynamic_params", None)
-)
+standard_callback_dynamic_params: Optional[
+StandardCallbackDynamicParams
+] = model_call_details.get("standard_callback_dynamic_params", None)
if standard_callback_dynamic_params:
_turn_off_message_logging = standard_callback_dynamic_params.get(
"turn_off_message_logging"
@@ -1,4 +1,5 @@
from typing import Any, Dict, Optional, Set

from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH


@@ -40,7 +41,10 @@ class SensitiveDataMasker:
return result

def mask_dict(
-self, data: Dict[str, Any], depth: int = 0, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH
+self,
+data: Dict[str, Any],
+depth: int = 0,
+max_depth: int = DEFAULT_MAX_RECURSE_DEPTH,
) -> Dict[str, Any]:
if depth >= max_depth:
return data
@@ -104,7 +104,6 @@ class ChunkProcessor:
def get_combined_tool_content(
self, tool_call_chunks: List[Dict[str, Any]]
) -> List[ChatCompletionMessageToolCall]:

argument_list: List[str] = []
delta = tool_call_chunks[0]["choices"][0]["delta"]
id = None
@@ -84,9 +84,9 @@ class CustomStreamWrapper:

self.system_fingerprint: Optional[str] = None
self.received_finish_reason: Optional[str] = None
-self.intermittent_finish_reason: Optional[str] = (
-None # finish reasons that show up mid-stream
-)
+self.intermittent_finish_reason: Optional[
+str
+] = None # finish reasons that show up mid-stream
self.special_tokens = [
"<|assistant|>",
"<|system|>",
@@ -814,7 +814,6 @@ class CustomStreamWrapper:
model_response: ModelResponseStream,
response_obj: Dict[str, Any],
):

print_verbose(
f"completion_obj: {completion_obj}, model_response.choices[0]: {model_response.choices[0]}, response_obj: {response_obj}"
)
@@ -1008,7 +1007,6 @@ class CustomStreamWrapper:
self.custom_llm_provider
and self.custom_llm_provider in litellm._custom_providers
):

if self.received_finish_reason is not None:
if "provider_specific_fields" not in chunk:
raise StopIteration
@@ -1379,9 +1377,9 @@ class CustomStreamWrapper:
_json_delta = delta.model_dump()
print_verbose(f"_json_delta: {_json_delta}")
if "role" not in _json_delta or _json_delta["role"] is None:
-_json_delta["role"] = (
-"assistant" # mistral's api returns role as None
-)
+_json_delta[
+"role"
+] = "assistant" # mistral's api returns role as None
if "tool_calls" in _json_delta and isinstance(
_json_delta["tool_calls"], list
):
@@ -1758,9 +1756,9 @@ class CustomStreamWrapper:
chunk = next(self.completion_stream)
if chunk is not None and chunk != b"":
print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}")
-processed_chunk: Optional[ModelResponseStream] = (
-self.chunk_creator(chunk=chunk)
-)
+processed_chunk: Optional[
+ModelResponseStream
+] = self.chunk_creator(chunk=chunk)
print_verbose(
f"PROCESSED CHUNK POST CHUNK CREATOR: {processed_chunk}"
)
@@ -290,7 +290,6 @@ class AnthropicChatCompletion(BaseLLM):
headers={},
client=None,
):

optional_params = copy.deepcopy(optional_params)
stream = optional_params.pop("stream", None)
json_mode: bool = optional_params.pop("json_mode", False)
@@ -491,7 +490,6 @@ class ModelResponseIterator:
def _handle_usage(
self, anthropic_usage_chunk: Union[dict, UsageDelta]
) -> AnthropicChatCompletionUsageBlock:

usage_block = AnthropicChatCompletionUsageBlock(
prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0),
completion_tokens=anthropic_usage_chunk.get("output_tokens", 0),
@@ -515,7 +513,9 @@ class ModelResponseIterator:

return usage_block

-def _content_block_delta_helper(self, chunk: dict) -> Tuple[
+def _content_block_delta_helper(
+self, chunk: dict
+) -> Tuple[
str,
Optional[ChatCompletionToolCallChunk],
List[ChatCompletionThinkingBlock],
@@ -592,9 +592,12 @@ class ModelResponseIterator:
Anthropic content chunk
chunk = {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Hello'}}
"""
-text, tool_use, thinking_blocks, provider_specific_fields = (
-self._content_block_delta_helper(chunk=chunk)
-)
+(
+text,
+tool_use,
+thinking_blocks,
+provider_specific_fields,
+) = self._content_block_delta_helper(chunk=chunk)
if thinking_blocks:
reasoning_content = self._handle_reasoning_content(
thinking_blocks=thinking_blocks
@@ -620,7 +623,6 @@ class ModelResponseIterator:
"index": self.tool_index,
}
elif type_chunk == "content_block_stop":

ContentBlockStop(**chunk) # type: ignore
# check if tool call content block
is_empty = self.check_empty_tool_call_args()
@@ -49,9 +49,9 @@ class AnthropicConfig(BaseConfig):
to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
"""

-max_tokens: Optional[int] = (
-4096 # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default)
-)
+max_tokens: Optional[
+int
+] = 4096 # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default)
stop_sequences: Optional[list] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
@@ -104,7 +104,6 @@ class AnthropicConfig(BaseConfig):
def get_json_schema_from_pydantic_object(
self, response_format: Union[Any, Dict, None]
) -> Optional[dict]:

return type_to_response_format_param(
response_format, ref_template="/$defs/{model}"
) # Relevant issue: https://github.com/BerriAI/litellm/issues/7755
@@ -125,7 +124,6 @@ class AnthropicConfig(BaseConfig):
is_vertex_request: bool = False,
user_anthropic_beta_headers: Optional[List[str]] = None,
) -> dict:

betas = set()
if prompt_caching_set:
betas.add("prompt-caching-2024-07-31")
@@ -300,7 +298,6 @@ class AnthropicConfig(BaseConfig):
model: str,
drop_params: bool,
) -> dict:

is_thinking_enabled = self.is_thinking_enabled(
non_default_params=non_default_params
)
@@ -321,12 +318,12 @@ class AnthropicConfig(BaseConfig):
optional_params=optional_params, tools=tool_value
)
if param == "tool_choice" or param == "parallel_tool_calls":
-_tool_choice: Optional[AnthropicMessagesToolChoice] = (
-self._map_tool_choice(
+_tool_choice: Optional[
+AnthropicMessagesToolChoice
+] = self._map_tool_choice(
tool_choice=non_default_params.get("tool_choice"),
parallel_tool_use=non_default_params.get("parallel_tool_calls"),
)
-)

if _tool_choice is not None:
optional_params["tool_choice"] = _tool_choice
@@ -341,7 +338,6 @@ class AnthropicConfig(BaseConfig):
if param == "top_p":
optional_params["top_p"] = value
if param == "response_format" and isinstance(value, dict):

ignore_response_format_types = ["text"]
if value["type"] in ignore_response_format_types: # value is a no-op
continue
@@ -470,9 +466,9 @@ class AnthropicConfig(BaseConfig):
text=system_message_block["content"],
)
if "cache_control" in system_message_block:
-anthropic_system_message_content["cache_control"] = (
-system_message_block["cache_control"]
-)
+anthropic_system_message_content[
+"cache_control"
+] = system_message_block["cache_control"]
anthropic_system_message_list.append(
anthropic_system_message_content
)
@@ -486,9 +482,9 @@ class AnthropicConfig(BaseConfig):
)
)
if "cache_control" in _content:
-anthropic_system_message_content["cache_control"] = (
-_content["cache_control"]
-)
+anthropic_system_message_content[
+"cache_control"
+] = _content["cache_control"]

anthropic_system_message_list.append(
anthropic_system_message_content
@@ -597,7 +593,9 @@ class AnthropicConfig(BaseConfig):
)
return _message

-def extract_response_content(self, completion_response: dict) -> Tuple[
+def extract_response_content(
+self, completion_response: dict
+) -> Tuple[
str,
Optional[List[Any]],
Optional[List[ChatCompletionThinkingBlock]],
@@ -693,9 +691,13 @@ class AnthropicConfig(BaseConfig):
reasoning_content: Optional[str] = None
tool_calls: List[ChatCompletionToolCallChunk] = []

-text_content, citations, thinking_blocks, reasoning_content, tool_calls = (
-self.extract_response_content(completion_response=completion_response)
-)
+(
+text_content,
+citations,
+thinking_blocks,
+reasoning_content,
+tool_calls,
+) = self.extract_response_content(completion_response=completion_response)

_message = litellm.Message(
tool_calls=tool_calls,
@@ -54,9 +54,9 @@ class AnthropicTextConfig(BaseConfig):
to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
"""

-max_tokens_to_sample: Optional[int] = (
-litellm.max_tokens
-) # anthropic requires a default
+max_tokens_to_sample: Optional[
+int
+] = litellm.max_tokens # anthropic requires a default
stop_sequences: Optional[list] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
@@ -25,7 +25,6 @@ from litellm.utils import ProviderConfigManager, client


class AnthropicMessagesHandler:

@staticmethod
async def _handle_anthropic_streaming(
response: httpx.Response,
@@ -74,20 +73,23 @@ async def anthropic_messages(
"""
# Use provided client or create a new one
optional_params = GenericLiteLLMParams(**kwargs)
-model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = (
-litellm.get_llm_provider(
+(
+model,
+_custom_llm_provider,
+dynamic_api_key,
+dynamic_api_base,
+) = litellm.get_llm_provider(
model=model,
custom_llm_provider=custom_llm_provider,
api_base=optional_params.api_base,
api_key=optional_params.api_key,
)
-)
-anthropic_messages_provider_config: Optional[BaseAnthropicMessagesConfig] = (
-ProviderConfigManager.get_provider_anthropic_messages_config(
+anthropic_messages_provider_config: Optional[
+BaseAnthropicMessagesConfig
+] = ProviderConfigManager.get_provider_anthropic_messages_config(
model=model,
provider=litellm.LlmProviders(_custom_llm_provider),
)
-)
if anthropic_messages_provider_config is None:
raise ValueError(
f"Anthropic messages provider config not found for model: {model}"
@@ -654,7 +654,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
) -> EmbeddingResponse:
response = None
try:

openai_aclient = self.get_azure_openai_client(
api_version=api_version,
api_base=api_base,
@@ -835,7 +834,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
"2023-10-01-preview",
]
): # CREATE + POLL for azure dall-e-2 calls

api_base = modify_url(
original_url=api_base, new_path="/openai/images/generations:submit"
)
@@ -867,7 +865,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
)
while response.json()["status"] not in ["succeeded", "failed"]:
if time.time() - start_time > timeout_secs:

raise AzureOpenAIError(
status_code=408, message="Operation polling timed out."
)
@@ -935,7 +932,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
"2023-10-01-preview",
]
): # CREATE + POLL for azure dall-e-2 calls

api_base = modify_url(
original_url=api_base, new_path="/openai/images/generations:submit"
)
@@ -1199,7 +1195,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
client=None,
litellm_params: Optional[dict] = None,
) -> HttpxBinaryResponseContent:

max_retries = optional_params.pop("max_retries", 2)

if aspeech is not None and aspeech is True:
@@ -1253,7 +1248,6 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
client=None,
litellm_params: Optional[dict] = None,
) -> HttpxBinaryResponseContent:

azure_client: AsyncAzureOpenAI = self.get_azure_openai_client(
api_base=api_base,
api_version=api_version,
@@ -50,8 +50,9 @@ class AzureBatchesAPI(BaseAzureLLM):
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
-azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-self.get_azure_openai_client(
+azure_client: Optional[
+Union[AzureOpenAI, AsyncAzureOpenAI]
+] = self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
@@ -59,7 +60,6 @@ class AzureBatchesAPI(BaseAzureLLM):
_is_async=_is_async,
litellm_params=litellm_params or {},
)
-)
if azure_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
@@ -96,8 +96,9 @@ class AzureBatchesAPI(BaseAzureLLM):
client: Optional[AzureOpenAI] = None,
litellm_params: Optional[dict] = None,
):
-azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-self.get_azure_openai_client(
+azure_client: Optional[
+Union[AzureOpenAI, AsyncAzureOpenAI]
+] = self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
@@ -105,7 +106,6 @@ class AzureBatchesAPI(BaseAzureLLM):
_is_async=_is_async,
litellm_params=litellm_params or {},
)
-)
if azure_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
@@ -144,8 +144,9 @@ class AzureBatchesAPI(BaseAzureLLM):
client: Optional[AzureOpenAI] = None,
litellm_params: Optional[dict] = None,
):
-azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-self.get_azure_openai_client(
+azure_client: Optional[
+Union[AzureOpenAI, AsyncAzureOpenAI]
+] = self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
@@ -153,7 +154,6 @@ class AzureBatchesAPI(BaseAzureLLM):
_is_async=_is_async,
litellm_params=litellm_params or {},
)
-)
if azure_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
@@ -183,8 +183,9 @@ class AzureBatchesAPI(BaseAzureLLM):
client: Optional[AzureOpenAI] = None,
litellm_params: Optional[dict] = None,
):
-azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-self.get_azure_openai_client(
+azure_client: Optional[
+Union[AzureOpenAI, AsyncAzureOpenAI]
+] = self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
@@ -192,7 +193,6 @@ class AzureBatchesAPI(BaseAzureLLM):
_is_async=_is_async,
litellm_params=litellm_params or {},
)
-)
if azure_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
@@ -306,7 +306,6 @@ class BaseAzureLLM(BaseOpenAILLM):
api_version: Optional[str],
is_async: bool,
) -> dict:

azure_ad_token_provider: Optional[Callable[[], str]] = None
# If we have api_key, then we have higher priority
azure_ad_token = litellm_params.get("azure_ad_token")
@@ -46,9 +46,9 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
-openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-self.get_azure_openai_client(
+openai_client: Optional[
+Union[AzureOpenAI, AsyncAzureOpenAI]
+] = self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
@@ -56,7 +56,6 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
client=client,
_is_async=_is_async,
)
-)
if openai_client is None:
raise ValueError(
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
@@ -95,8 +94,9 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
) -> Union[
HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
]:
-openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-self.get_azure_openai_client(
+openai_client: Optional[
+Union[AzureOpenAI, AsyncAzureOpenAI]
+] = self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
@@ -104,7 +104,6 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
client=client,
_is_async=_is_async,
)
-)
if openai_client is None:
raise ValueError(
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
@@ -145,8 +144,9 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
):
-openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-self.get_azure_openai_client(
+openai_client: Optional[
+Union[AzureOpenAI, AsyncAzureOpenAI]
+] = self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
@@ -154,7 +154,6 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
client=client,
_is_async=_is_async,
)
-)
if openai_client is None:
raise ValueError(
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
@@ -197,8 +196,9 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
):
-openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-self.get_azure_openai_client(
+openai_client: Optional[
+Union[AzureOpenAI, AsyncAzureOpenAI]
+] = self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
@@ -206,7 +206,6 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
client=client,
_is_async=_is_async,
)
-)
if openai_client is None:
raise ValueError(
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
@@ -251,8 +250,9 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
):
-openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-self.get_azure_openai_client(
+openai_client: Optional[
+Union[AzureOpenAI, AsyncAzureOpenAI]
+] = self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
@@ -260,7 +260,6 @@ class AzureOpenAIFilesAPI(BaseAzureLLM):
client=client,
_is_async=_is_async,
)
-)
if openai_client is None:
raise ValueError(
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
@@ -25,14 +25,7 @@ class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI, BaseAzureLLM):
_is_async: bool = False,
api_version: Optional[str] = None,
litellm_params: Optional[dict] = None,
-) -> Optional[
-Union[
-OpenAI,
-AsyncOpenAI,
-AzureOpenAI,
-AsyncAzureOpenAI,
-]
-]:
+) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI,]]:
# Override to use Azure-specific client initialization
if isinstance(client, OpenAI) or isinstance(client, AsyncOpenAI):
client = None
@@ -145,7 +145,6 @@ class AzureAIStudioConfig(OpenAIConfig):
2. If message contains an image or audio, send as is (user-intended)
"""
for message in messages:

# Do nothing if the message contains an image or audio
if _audio_or_image_in_message_content(message):
continue
@@ -22,7 +22,6 @@ class AzureAICohereConfig:
pass

def _map_azure_model_group(self, model: str) -> str:

if model == "offer-cohere-embed-multili-paygo":
return "Cohere-embed-v3-multilingual"
elif model == "offer-cohere-embed-english-paygo":
@@ -17,7 +17,6 @@ from .cohere_transformation import AzureAICohereConfig


class AzureAIEmbedding(OpenAIChatCompletion):

def _process_response(
self,
image_embedding_responses: Optional[List],
@@ -145,7 +144,6 @@ class AzureAIEmbedding(OpenAIChatCompletion):
api_base: Optional[str] = None,
client=None,
) -> EmbeddingResponse:

(
image_embeddings_request,
v1_embeddings_request,
@@ -17,6 +17,7 @@ class AzureAIRerankConfig(CohereRerankConfig):
"""
Azure AI Rerank - Follows the same Spec as Cohere Rerank
"""

def get_complete_url(self, api_base: Optional[str], model: str) -> str:
if api_base is None:
raise ValueError(
@@ -9,7 +9,6 @@ from litellm.types.utils import ModelResponse, TextCompletionResponse


class BaseLLM:

_client_session: Optional[httpx.Client] = None

def process_response(
@@ -218,7 +218,6 @@ class BaseConfig(ABC):
json_schema = value["json_schema"]["schema"]

if json_schema and not is_response_format_supported:

_tool_choice = ChatCompletionToolChoiceObjectParam(
type="function",
function=ChatCompletionToolChoiceFunctionParam(
@@ -58,7 +58,6 @@ class BaseResponsesAPIConfig(ABC):
model: str,
drop_params: bool,
) -> Dict:

pass

@abstractmethod
@@ -81,7 +81,6 @@ def make_sync_call(


class BedrockConverseLLM(BaseAWSLLM):

def __init__(self) -> None:
super().__init__()

@@ -114,7 +113,6 @@ class BedrockConverseLLM(BaseAWSLLM):
fake_stream: bool = False,
json_mode: Optional[bool] = False,
) -> CustomStreamWrapper:

request_data = await litellm.AmazonConverseConfig()._async_transform_request(
model=model,
messages=messages,
@@ -179,7 +177,6 @@ class BedrockConverseLLM(BaseAWSLLM):
headers: dict = {},
client: Optional[AsyncHTTPHandler] = None,
) -> Union[ModelResponse, CustomStreamWrapper]:

request_data = await litellm.AmazonConverseConfig()._async_transform_request(
model=model,
messages=messages,
@@ -265,7 +262,6 @@ class BedrockConverseLLM(BaseAWSLLM):
extra_headers: Optional[dict] = None,
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
):

## SETUP ##
stream = optional_params.pop("stream", None)
unencoded_model_id = optional_params.pop("model_id", None)
@@ -301,9 +297,9 @@ class BedrockConverseLLM(BaseAWSLLM):
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
optional_params.pop("aws_region_name", None)

-litellm_params["aws_region_name"] = (
-aws_region_name # [DO NOT DELETE] important for async calls
-)
+litellm_params[
+"aws_region_name"
+] = aws_region_name # [DO NOT DELETE] important for async calls

credentials: Credentials = self.get_credentials(
aws_access_key_id=aws_access_key_id,
@@ -223,7 +223,6 @@ class AmazonConverseConfig(BaseConfig):
)
for param, value in non_default_params.items():
if param == "response_format" and isinstance(value, dict):

ignore_response_format_types = ["text"]
if value["type"] in ignore_response_format_types: # value is a no-op
continue
@@ -715,9 +714,9 @@ class AmazonConverseConfig(BaseConfig):
chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"}
content_str = ""
tools: List[ChatCompletionToolCallChunk] = []
-reasoningContentBlocks: Optional[List[BedrockConverseReasoningContentBlock]] = (
-None
-)
+reasoningContentBlocks: Optional[
+List[BedrockConverseReasoningContentBlock]
+] = None

if message is not None:
for idx, content in enumerate(message["content"]):
@@ -727,7 +726,6 @@ class AmazonConverseConfig(BaseConfig):
if "text" in content:
content_str += content["text"]
if "toolUse" in content:

## check tool name was formatted by litellm
_response_tool_name = content["toolUse"]["name"]
response_tool_name = get_bedrock_tool_name(
@@ -754,12 +752,12 @@ class AmazonConverseConfig(BaseConfig):
chat_completion_message["provider_specific_fields"] = {
"reasoningContentBlocks": reasoningContentBlocks,
}
-chat_completion_message["reasoning_content"] = (
-self._transform_reasoning_content(reasoningContentBlocks)
-)
+chat_completion_message[
+"reasoning_content"
+] = self._transform_reasoning_content(reasoningContentBlocks)
-chat_completion_message["thinking_blocks"] = (
-self._transform_thinking_blocks(reasoningContentBlocks)
-)
+chat_completion_message[
+"thinking_blocks"
+] = self._transform_thinking_blocks(reasoningContentBlocks)
chat_completion_message["content"] = content_str
if json_mode is True and tools is not None and len(tools) == 1:
# to support 'json_schema' logic on bedrock models
@@ -496,9 +496,9 @@ class BedrockLLM(BaseAWSLLM):
content=None,
)
model_response.choices[0].message = _message # type: ignore
-model_response._hidden_params["original_response"] = (
-outputText # allow user to access raw anthropic tool calling response
-)
+model_response._hidden_params[
+"original_response"
+] = outputText # allow user to access raw anthropic tool calling response
if (
_is_function_call is True
and stream is not None
@@ -806,9 +806,9 @@ class BedrockLLM(BaseAWSLLM):
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
if stream is True:
-inference_params["stream"] = (
-True # cohere requires stream = True in inference params
-)
+inference_params[
+"stream"
+] = True # cohere requires stream = True in inference params
data = json.dumps({"prompt": prompt, **inference_params})
elif provider == "anthropic":
if model.startswith("anthropic.claude-3"):
@@ -1205,7 +1205,6 @@ class BedrockLLM(BaseAWSLLM):
def get_response_stream_shape():
global _response_stream_shape_cache
if _response_stream_shape_cache is None:

from botocore.loaders import Loader
from botocore.model import ServiceModel

@@ -1539,7 +1538,6 @@ class AmazonDeepSeekR1StreamDecoder(AWSEventStreamDecoder):
model: str,
sync_stream: bool,
) -> None:

super().__init__(model=model)
from litellm.llms.bedrock.chat.invoke_transformations.amazon_deepseek_transformation import (
AmazonDeepseekR1ResponseIterator,
@@ -225,9 +225,9 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
if stream is True:
-inference_params["stream"] = (
-True # cohere requires stream = True in inference params
-)
+inference_params[
+"stream"
+] = True # cohere requires stream = True in inference params
request_data = {"prompt": prompt, **inference_params}
elif provider == "anthropic":
return litellm.AmazonAnthropicClaude3Config().transform_request(
@@ -311,7 +311,6 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:

try:
completion_response = raw_response.json()
except Exception:
@@ -314,7 +314,6 @@ def get_bedrock_tool_name(response_tool_name: str) -> str:


class BedrockModelInfo(BaseLLMModelInfo):

global_config = AmazonBedrockGlobalConfig()
all_global_regions = global_config.get_all_regions()
@@ -33,9 +33,9 @@ class AmazonTitanMultimodalEmbeddingG1Config:
) -> dict:
for k, v in non_default_params.items():
if k == "dimensions":
-optional_params["embeddingConfig"] = (
-AmazonTitanMultimodalEmbeddingConfig(outputEmbeddingLength=v)
-)
+optional_params[
+"embeddingConfig"
+] = AmazonTitanMultimodalEmbeddingConfig(outputEmbeddingLength=v)
return optional_params

def _transform_request(
@@ -58,7 +58,6 @@ class AmazonTitanMultimodalEmbeddingG1Config:
def _transform_response(
self, response_list: List[dict], model: str
) -> EmbeddingResponse:

total_prompt_tokens = 0
transformed_responses: List[Embedding] = []
for index, response in enumerate(response_list):
@@ -1,12 +1,16 @@
import types
-from typing import List, Optional
+from typing import Any, Dict, List, Optional

from openai.types.image import Image

from litellm.types.llms.bedrock import (
-AmazonNovaCanvasTextToImageRequest, AmazonNovaCanvasTextToImageResponse,
-AmazonNovaCanvasTextToImageParams, AmazonNovaCanvasRequestBase, AmazonNovaCanvasColorGuidedGenerationParams,
+AmazonNovaCanvasColorGuidedGenerationParams,
AmazonNovaCanvasColorGuidedRequest,
+AmazonNovaCanvasImageGenerationConfig,
+AmazonNovaCanvasRequestBase,
+AmazonNovaCanvasTextToImageParams,
+AmazonNovaCanvasTextToImageRequest,
+AmazonNovaCanvasTextToImageResponse,
)
from litellm.types.utils import ImageResponse

@@ -37,8 +41,7 @@ class AmazonNovaCanvasConfig:

@classmethod
def get_supported_openai_params(cls, model: Optional[str] = None) -> List:
-"""
-"""
+""" """
return ["n", "size", "quality"]

@classmethod
@@ -65,18 +68,64 @@ class AmazonNovaCanvasConfig:
image_generation_config = optional_params.pop("imageGenerationConfig", {})
image_generation_config = {**image_generation_config, **optional_params}
if task_type == "TEXT_IMAGE":
-text_to_image_params = image_generation_config.pop("textToImageParams", {})
-text_to_image_params = {"text" :text, **text_to_image_params}
-text_to_image_params = AmazonNovaCanvasTextToImageParams(**text_to_image_params)
-return AmazonNovaCanvasTextToImageRequest(textToImageParams=text_to_image_params, taskType=task_type,
-imageGenerationConfig=image_generation_config)
+text_to_image_params: Dict[str, Any] = image_generation_config.pop(
+"textToImageParams", {}
+)
+text_to_image_params = {"text": text, **text_to_image_params}
+try:
+text_to_image_params_typed = AmazonNovaCanvasTextToImageParams(
+**text_to_image_params # type: ignore
+)
+except Exception as e:
+raise ValueError(
+f"Error transforming text to image params: {e}. Got params: {text_to_image_params}, Expected params: {AmazonNovaCanvasTextToImageParams.__annotations__}"
+)
+
+try:
+image_generation_config_typed = AmazonNovaCanvasImageGenerationConfig(
+**image_generation_config
+)
+except Exception as e:
+raise ValueError(
+f"Error transforming image generation config: {e}. Got params: {image_generation_config}, Expected params: {AmazonNovaCanvasImageGenerationConfig.__annotations__}"
+)
+
+return AmazonNovaCanvasTextToImageRequest(
+textToImageParams=text_to_image_params_typed,
+taskType=task_type,
+imageGenerationConfig=image_generation_config_typed,
+)
if task_type == "COLOR_GUIDED_GENERATION":
-color_guided_generation_params = image_generation_config.pop("colorGuidedGenerationParams", {})
-color_guided_generation_params = {"text": text, **color_guided_generation_params}
-color_guided_generation_params = AmazonNovaCanvasColorGuidedGenerationParams(**color_guided_generation_params)
-return AmazonNovaCanvasColorGuidedRequest(taskType=task_type,
-colorGuidedGenerationParams=color_guided_generation_params,
-imageGenerationConfig=image_generation_config)
+color_guided_generation_params: Dict[
+str, Any
+] = image_generation_config.pop("colorGuidedGenerationParams", {})
+color_guided_generation_params = {
+"text": text,
+**color_guided_generation_params,
+}
+try:
+color_guided_generation_params_typed = AmazonNovaCanvasColorGuidedGenerationParams(
+**color_guided_generation_params # type: ignore
+)
+except Exception as e:
+raise ValueError(
+f"Error transforming color guided generation params: {e}. Got params: {color_guided_generation_params}, Expected params: {AmazonNovaCanvasColorGuidedGenerationParams.__annotations__}"
+)
+
+try:
+image_generation_config_typed = AmazonNovaCanvasImageGenerationConfig(
+**image_generation_config
+)
+except Exception as e:
+raise ValueError(
+f"Error transforming image generation config: {e}. Got params: {image_generation_config}, Expected params: {AmazonNovaCanvasImageGenerationConfig.__annotations__}"
+)
+
+return AmazonNovaCanvasColorGuidedRequest(
+taskType=task_type,
+colorGuidedGenerationParams=color_guided_generation_params_typed,
+imageGenerationConfig=image_generation_config_typed,
+)
raise NotImplementedError(f"Task type {task_type} is not supported")

@classmethod
@@ -87,7 +136,9 @@ class AmazonNovaCanvasConfig:
_size = non_default_params.get("size")
if _size is not None:
width, height = _size.split("x")
-optional_params["width"], optional_params["height"] = int(width), int(height)
+optional_params["width"], optional_params["height"] = int(width), int(
+height
+)
if non_default_params.get("n") is not None:
optional_params["numberOfImages"] = non_default_params.get("n")
if non_default_params.get("quality") is not None:
@ -267,7 +267,11 @@ class BedrockImageGeneration(BaseAWSLLM):
|
||||||
**inference_params,
|
**inference_params,
|
||||||
}
|
}
|
||||||
elif provider == "amazon":
|
elif provider == "amazon":
|
||||||
return dict(litellm.AmazonNovaCanvasConfig.transform_request_body(text=prompt, optional_params=optional_params))
|
return dict(
|
||||||
|
litellm.AmazonNovaCanvasConfig.transform_request_body(
|
||||||
|
text=prompt, optional_params=optional_params
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise BedrockError(
|
raise BedrockError(
|
||||||
status_code=422, message=f"Unsupported model={model}, passed in"
|
status_code=422, message=f"Unsupported model={model}, passed in"
|
||||||
|
@@ -303,9 +307,12 @@ class BedrockImageGeneration(BaseAWSLLM):
         config_class = (
             litellm.AmazonStability3Config
             if litellm.AmazonStability3Config._is_stability_3_model(model=model)
-            else litellm.AmazonNovaCanvasConfig if litellm.AmazonNovaCanvasConfig._is_nova_model(model=model)
-            else litellm.AmazonStabilityConfig
+            else (
+                litellm.AmazonNovaCanvasConfig
+                if litellm.AmazonNovaCanvasConfig._is_nova_model(model=model)
+                else litellm.AmazonStabilityConfig
+            )
         )
         config_class.transform_response_dict_to_openai_response(
             model_response=model_response,
             response_dict=response_dict,
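
The expression being re-wrapped here selects a provider config class by model family via a nested conditional. An equivalent if/elif sketch, with the matching heuristics stubbed out for illustration (they are not litellm's actual checks):

    class Stability3Config:
        # Hypothetical matcher; the real check lives on AmazonStability3Config.
        @staticmethod
        def is_match(model: str) -> bool:
            return "sd3" in model or "stable-image" in model


    class NovaCanvasConfig:
        @staticmethod
        def is_match(model: str) -> bool:
            return "nova-canvas" in model


    class StabilityConfig:
        """Fallback for the remaining Stability models."""


    def pick_config(model: str):
        # Same shape as the nested conditional expression in the hunk above,
        # spelled out as if/elif for readability.
        if Stability3Config.is_match(model):
            return Stability3Config
        if NovaCanvasConfig.is_match(model):
            return NovaCanvasConfig
        return StabilityConfig


    print(pick_config("amazon.nova-canvas-v1:0").__name__)  # NovaCanvasConfig
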
@@ -60,7 +60,6 @@ class BedrockRerankHandler(BaseAWSLLM):
         extra_headers: Optional[dict] = None,
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ) -> RerankResponse:
-
         request_data = RerankRequest(
             model=model,
             query=query,

@@ -29,7 +29,6 @@ from litellm.types.rerank import (
 
 
 class BedrockRerankConfig:
-
     def _transform_sources(
         self, documents: List[Union[str, dict]]
     ) -> List[BedrockRerankSource]:
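
From here through the BaseLLMHTTPHandler hunks, most changes are a single deleted blank line at the top of a block, right after a `def ...:`, `else:`, `try:`, or class header. This looks like the newer formatter style dropping leading blank lines inside blocks; behaviour is unchanged. A before/after sketch:

    def before(url: str, headers: dict) -> tuple:

        # a blank line used to sit between the signature and the body
        return url, headers


    def after(url: str, headers: dict) -> tuple:
        # the newer style starts the body immediately
        return url, headers


    assert before("x", {}) == after("x", {})
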
@@ -314,7 +314,6 @@ class CodestralTextCompletion:
             return _response
         ### SYNC COMPLETION
         else:
-
             response = litellm.module_level_client.post(
                 url=completion_url,
                 headers=headers,

@@ -352,13 +351,11 @@ class CodestralTextCompletion:
         logger_fn=None,
         headers={},
     ) -> TextCompletionResponse:
-
         async_handler = get_async_httpx_client(
             llm_provider=litellm.LlmProviders.TEXT_COMPLETION_CODESTRAL,
             params={"timeout": timeout},
         )
         try:
-
             response = await async_handler.post(
                 api_base, headers=headers, data=json.dumps(data)
             )

@@ -78,7 +78,6 @@ class CodestralTextCompletionConfig(OpenAITextCompletionConfig):
         return optional_params
 
     def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:
-
         text = ""
         is_finished = False
         finish_reason = None
@@ -180,7 +180,6 @@ class CohereChatConfig(BaseConfig):
         litellm_params: dict,
         headers: dict,
     ) -> dict:
-
         ## Load Config
         for k, v in litellm.CohereChatConfig.get_config().items():
             if (

@@ -222,7 +221,6 @@ class CohereChatConfig(BaseConfig):
         api_key: Optional[str] = None,
         json_mode: Optional[bool] = None,
     ) -> ModelResponse:
-
         try:
             raw_response_json = raw_response.json()
             model_response.choices[0].message.content = raw_response_json["text"]  # type: ignore
@@ -56,7 +56,6 @@ async def async_embedding(
     encoding: Callable,
     client: Optional[AsyncHTTPHandler] = None,
 ):
-
     ## LOGGING
     logging_obj.pre_call(
         input=input,

@@ -72,7 +72,6 @@ class CohereEmbeddingConfig:
         return transformed_request
 
     def _calculate_usage(self, input: List[str], encoding: Any, meta: dict) -> Usage:
-
         input_tokens = 0
 
         text_tokens: Optional[int] = meta.get("billed_units", {}).get("input_tokens")

@@ -111,7 +110,6 @@ class CohereEmbeddingConfig:
         encoding: Any,
         input: list,
     ) -> EmbeddingResponse:
-
         response_json = response.json()
         ## LOGGING
         logging_obj.post_call(
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Union
 from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
 from litellm.types.rerank import OptionalRerankParams, RerankRequest
 
+
 class CohereRerankV2Config(CohereRerankConfig):
     """
     Reference: https://docs.cohere.com/v2/reference/rerank
@@ -32,7 +32,6 @@ DEFAULT_TIMEOUT = 600
 
 
 class BaseLLMAIOHTTPHandler:
-
     def __init__(self):
         self.client_session: Optional[aiohttp.ClientSession] = None
 
@@ -110,7 +109,6 @@ class BaseLLMAIOHTTPHandler:
         content: Any = None,
         params: Optional[dict] = None,
     ) -> httpx.Response:
-
         max_retry_on_unprocessable_entity_error = (
             provider_config.max_retry_on_unprocessable_entity_error
         )

@@ -114,7 +114,6 @@ class AsyncHTTPHandler:
         event_hooks: Optional[Mapping[str, List[Callable[..., Any]]]],
         ssl_verify: Optional[VerifyTypes] = None,
     ) -> httpx.AsyncClient:
-
         # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
         # /path/to/certificate.pem
         if ssl_verify is None:
@@ -590,7 +589,6 @@ class HTTPHandler:
         timeout: Optional[Union[float, httpx.Timeout]] = None,
     ):
         try:
-
             if timeout is not None:
                 req = self.client.build_request(
                     "PATCH", url, data=data, json=json, params=params, headers=headers, timeout=timeout  # type: ignore

@@ -609,7 +607,6 @@ class HTTPHandler:
                 llm_provider="litellm-httpx-handler",
             )
         except httpx.HTTPStatusError as e:
-
             if stream is True:
                 setattr(e, "message", mask_sensitive_info(e.response.read()))
                 setattr(e, "text", mask_sensitive_info(e.response.read()))

@@ -635,7 +632,6 @@ class HTTPHandler:
         timeout: Optional[Union[float, httpx.Timeout]] = None,
     ):
         try:
-
             if timeout is not None:
                 req = self.client.build_request(
                     "PUT", url, data=data, json=json, params=params, headers=headers, timeout=timeout  # type: ignore
@@ -41,7 +41,6 @@ else:
 
 
 class BaseLLMHTTPHandler:
-
     async def _make_common_async_call(
         self,
         async_httpx_client: AsyncHTTPHandler,

@@ -109,7 +108,6 @@ class BaseLLMHTTPHandler:
         logging_obj: LiteLLMLoggingObj,
         stream: bool = False,
     ) -> httpx.Response:
-
         max_retry_on_unprocessable_entity_error = (
             provider_config.max_retry_on_unprocessable_entity_error
         )
@@ -599,7 +597,6 @@ class BaseLLMHTTPHandler:
         aembedding: bool = False,
         headers={},
     ) -> EmbeddingResponse:
-
         provider_config = ProviderConfigManager.get_provider_embedding_config(
             model=model, provider=litellm.LlmProviders(custom_llm_provider)
         )

@@ -742,7 +739,6 @@ class BaseLLMHTTPHandler:
         api_base: Optional[str] = None,
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ) -> RerankResponse:
-
         # get config from model, custom llm provider
         headers = provider_config.validate_environment(
             api_key=api_key,

@@ -828,7 +824,6 @@ class BaseLLMHTTPHandler:
         timeout: Optional[Union[float, httpx.Timeout]] = None,
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ) -> RerankResponse:
-
         if client is None or not isinstance(client, AsyncHTTPHandler):
             async_httpx_client = get_async_httpx_client(
                 llm_provider=litellm.LlmProviders(custom_llm_provider)
@@ -16,9 +16,9 @@ class DatabricksBase:
         api_base = api_base or f"{databricks_client.config.host}/serving-endpoints"
 
         if api_key is None:
-            databricks_auth_headers: dict[str, str] = (
-                databricks_client.config.authenticate()
-            )
+            databricks_auth_headers: dict[
+                str, str
+            ] = databricks_client.config.authenticate()
             headers = {**databricks_auth_headers, **headers}
 
         return api_base, headers
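
This hunk, and the DatabricksEmbeddingConfig, Deepgram, and AlephAlpha hunks below, are the same mechanical change: when an annotated assignment overflows the line-length limit, the formatter now breaks inside the subscript of the annotation instead of parenthesizing the right-hand side. Both spellings are equivalent at runtime; a minimal sketch:

    from typing import Dict

    # Old wrapping: parenthesize the right-hand side.
    auth_headers: Dict[str, str] = (
        {"Authorization": "Bearer <token>"}
    )

    # New wrapping: break inside the subscript of the annotation.
    auth_headers_v2: Dict[
        str, str
    ] = {"Authorization": "Bearer <token>"}

    assert auth_headers == auth_headers_v2  # identical values, different formatting
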
@@ -11,9 +11,9 @@ class DatabricksEmbeddingConfig:
     Reference: https://learn.microsoft.com/en-us/azure/databricks/machine-learning/foundation-models/api-reference#--embedding-task
     """
 
-    instruction: Optional[str] = (
-        None  # An optional instruction to pass to the embedding model. BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries
-    )
+    instruction: Optional[
+        str
+    ] = None  # An optional instruction to pass to the embedding model. BGE Authors recommend 'Represent this sentence for searching relevant passages:' for retrieval queries
 
     def __init__(self, instruction: Optional[str] = None) -> None:
         locals_ = locals().copy()
@@ -55,7 +55,6 @@ class ModelResponseIterator:
 
             usage_chunk: Optional[Usage] = getattr(processed_chunk, "usage", None)
             if usage_chunk is not None:
-
                 usage = ChatCompletionUsageBlock(
                     prompt_tokens=usage_chunk.prompt_tokens,
                     completion_tokens=usage_chunk.completion_tokens,
@@ -126,9 +126,9 @@ class DeepgramAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
 
         # Add additional metadata matching OpenAI format
         response["task"] = "transcribe"
-        response["language"] = (
-            "english"  # Deepgram auto-detects but doesn't return language
-        )
+        response[
+            "language"
+        ] = "english"  # Deepgram auto-detects but doesn't return language
         response["duration"] = response_json["metadata"]["duration"]
 
         # Transform words to match OpenAI format
@@ -14,7 +14,6 @@ from ...openai.chat.gpt_transformation import OpenAIGPTConfig
 
 
 class DeepSeekChatConfig(OpenAIGPTConfig):
-
     def _transform_messages(
         self, messages: List[AllMessageValues], model: str
     ) -> List[AllMessageValues]:
@@ -77,9 +77,9 @@ class AlephAlphaConfig:
     - `control_log_additive` (boolean; default value: true): Method of applying control to attention scores.
     """
 
-    maximum_tokens: Optional[int] = (
-        litellm.max_tokens
-    )  # aleph alpha requires max tokens
+    maximum_tokens: Optional[
+        int
+    ] = litellm.max_tokens  # aleph alpha requires max tokens
     minimum_tokens: Optional[int] = None
     echo: Optional[bool] = None
     temperature: Optional[int] = None

Some files were not shown because too many files have changed in this diff.