Add pyright to ci/cd + Fix remaining type-checking errors (#6082)

* fix: fix type-checking errors

* fix: fix additional type-checking errors

* fix: additional type-checking error fixes

* fix: fix additional type-checking errors

* fix: additional type-check fixes

* fix: fix all type-checking errors + add pyright to ci/cd

* fix: fix incorrect import

* ci(config.yml): use mypy on ci/cd

* fix: fix type-checking errors in utils.py

* fix: fix all type-checking errors on main.py

* fix: fix mypy linting errors

* fix(anthropic/cost_calculator.py): fix linting errors

* fix: fix mypy linting errors

* fix: fix linting errors
Krish Dholakia 2024-10-05 17:04:00 -04:00 committed by GitHub
parent f7ce1173f3
commit fac3b2ee42
65 changed files with 619 additions and 522 deletions

View file

@ -315,11 +315,11 @@ jobs:
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install ruff pip install ruff
pip install pylint pip install pylint
pip install pyright
pip install . pip install .
- run: python -c "from litellm import *" || (echo '🚨 import failed, this means you introduced unprotected imports! 🚨'; exit 1) - run: python -c "from litellm import *" || (echo '🚨 import failed, this means you introduced unprotected imports! 🚨'; exit 1)
- run: ruff check ./litellm - run: ruff check ./litellm
build_and_test: build_and_test:
machine: machine:
image: ubuntu-2204:2023.10.1 image: ubuntu-2204:2023.10.1
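
The `from litellm import *` smoke test above guards against modules that import optional dependencies at module scope. A minimal sketch of the protected-import pattern it enforces, with a hypothetical package name standing in for any optional extra:

try:
    from some_optional_sdk import Client  # hypothetical optional dependency
except ImportError:
    # keep `import litellm` working even when the extra is not installed
    Client = None

def get_client():
    if Client is None:
        raise RuntimeError(
            "optional dependency not installed, run `pip install some-optional-sdk`"
        )
    return Client()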

View file

@ -8,7 +8,7 @@ import requests
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache from litellm.caching import DualCache
from typing import Literal, Union from typing import Literal, Union, Optional
import traceback import traceback
@ -26,9 +26,9 @@ from litellm._logging import print_verbose, verbose_logger
class GenericAPILogger: class GenericAPILogger:
# Class variables or attributes # Class variables or attributes
def __init__(self, endpoint=None, headers=None): def __init__(self, endpoint: Optional[str] = None, headers: Optional[dict] = None):
try: try:
if endpoint == None: if endpoint is None:
# check env for "GENERIC_LOGGER_ENDPOINT" # check env for "GENERIC_LOGGER_ENDPOINT"
if os.getenv("GENERIC_LOGGER_ENDPOINT"): if os.getenv("GENERIC_LOGGER_ENDPOINT"):
# Do something with the endpoint # Do something with the endpoint
@ -36,9 +36,15 @@ class GenericAPILogger:
else: else:
# Handle the case when the endpoint is not found in the environment variables # Handle the case when the endpoint is not found in the environment variables
raise ValueError( raise ValueError(
f"endpoint not set for GenericAPILogger, GENERIC_LOGGER_ENDPOINT not found in environment variables" "endpoint not set for GenericAPILogger, GENERIC_LOGGER_ENDPOINT not found in environment variables"
) )
headers = headers or litellm.generic_logger_headers headers = headers or litellm.generic_logger_headers
if endpoint is None:
raise ValueError("endpoint not set for GenericAPILogger")
if headers is None:
raise ValueError("headers not set for GenericAPILogger")
self.endpoint = endpoint self.endpoint = endpoint
self.headers = headers self.headers = headers
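
The added `is None` guards are the pattern pyright needs in order to narrow `Optional[str]` / `Optional[dict]` down to concrete types before assignment; a minimal standalone sketch of the same idiom (function and environment-variable names are illustrative, not from the repo):

import os
from typing import Optional

def resolve_endpoint(endpoint: Optional[str] = None) -> str:
    # fall back to an env var (illustrative name), then fail loudly instead of
    # letting an Optional[str] escape into code that expects str
    endpoint = endpoint or os.getenv("EXAMPLE_LOGGER_ENDPOINT")
    if endpoint is None:
        raise ValueError("endpoint not set")
    return endpoint  # narrowed to str from here on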

View file

@ -48,8 +48,6 @@ class AporiaGuardrail(CustomGuardrail):
) )
self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"] self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"]
self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"] self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"]
self.event_hook: GuardrailEventHooks
super().__init__(**kwargs) super().__init__(**kwargs)
#### CALL HOOKS - proxy only #### #### CALL HOOKS - proxy only ####

View file

@ -84,7 +84,7 @@ class _ENTERPRISE_BlockedUserList(CustomLogger):
) )
cache_key = f"litellm:end_user_id:{user}" cache_key = f"litellm:end_user_id:{user}"
end_user_cache_obj: LiteLLM_EndUserTable = cache.get_cache( end_user_cache_obj: Optional[LiteLLM_EndUserTable] = cache.get_cache( # type: ignore
key=cache_key key=cache_key
) )
if end_user_cache_obj is None and self.prisma_client is not None: if end_user_cache_obj is None and self.prisma_client is not None:

View file

@ -48,7 +48,7 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
# Class variables or attributes # Class variables or attributes
def __init__(self): def __init__(self):
try: try:
from google.cloud import language_v1 from google.cloud import language_v1 # type: ignore
except Exception: except Exception:
raise Exception( raise Exception(
"Missing google.cloud package. Run `pip install --upgrade google-cloud-language`" "Missing google.cloud package. Run `pip install --upgrade google-cloud-language`"
@ -57,8 +57,8 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
# Instantiates a client # Instantiates a client
self.client = language_v1.LanguageServiceClient() self.client = language_v1.LanguageServiceClient()
self.moderate_text_request = language_v1.ModerateTextRequest self.moderate_text_request = language_v1.ModerateTextRequest
self.language_document = language_v1.types.Document self.language_document = language_v1.types.Document # type: ignore
self.document_type = language_v1.types.Document.Type.PLAIN_TEXT self.document_type = language_v1.types.Document.Type.PLAIN_TEXT # type: ignore
default_confidence_threshold = ( default_confidence_threshold = (
litellm.google_moderation_confidence_threshold or 0.8 litellm.google_moderation_confidence_threshold or 0.8

View file

@ -8,6 +8,7 @@
# Thank you users! We ❤️ you! - Krrish & Ishaan # Thank you users! We ❤️ you! - Krrish & Ishaan
import sys, os import sys, os
from collections.abc import Iterable
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
@ -19,11 +20,12 @@ from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.utils import ( from litellm.types.utils import (
ModelResponse, ModelResponse,
EmbeddingResponse, EmbeddingResponse,
ImageResponse, ImageResponse,
StreamingChoices, StreamingChoices,
Choices,
) )
from datetime import datetime from datetime import datetime
import aiohttp, asyncio import aiohttp, asyncio
@ -34,7 +36,10 @@ litellm.set_verbose = True
class _ENTERPRISE_LlamaGuard(CustomLogger): class _ENTERPRISE_LlamaGuard(CustomLogger):
# Class variables or attributes # Class variables or attributes
def __init__(self, model_name: Optional[str] = None): def __init__(self, model_name: Optional[str] = None):
self.model = model_name or litellm.llamaguard_model_name _model = model_name or litellm.llamaguard_model_name
if _model is None:
raise ValueError("model_name not set for LlamaGuard")
self.model = _model
file_path = litellm.llamaguard_unsafe_content_categories file_path = litellm.llamaguard_unsafe_content_categories
data = None data = None
@ -124,7 +129,13 @@ class _ENTERPRISE_LlamaGuard(CustomLogger):
hf_model_name="meta-llama/LlamaGuard-7b", hf_model_name="meta-llama/LlamaGuard-7b",
) )
if "unsafe" in response.choices[0].message.content: if (
isinstance(response, ModelResponse)
and isinstance(response.choices[0], Choices)
and response.choices[0].message.content is not None
and isinstance(response.choices[0].message.content, Iterable)
and "unsafe" in response.choices[0].message.content
):
raise HTTPException( raise HTTPException(
status_code=400, detail={"error": "Violated content safety policy"} status_code=400, detail={"error": "Violated content safety policy"}
) )
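
The isinstance chain replaces the bare `in` check because the completion call is typed as returning a union of response types, so the checker cannot assume `choices[0].message.content` is a string. A hedged sketch of the same narrowing, reusing the types imported above:

from litellm.types.utils import Choices, ModelResponse

def is_flagged(response) -> bool:
    # narrow the union step by step so each attribute access type-checks
    if not isinstance(response, ModelResponse):
        return False
    first_choice = response.choices[0]
    if not isinstance(first_choice, Choices):
        return False
    content = first_choice.message.content
    return content is not None and "unsafe" in content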

View file

@ -8,7 +8,11 @@
## This provides an LLM Guard Integration for content moderation on the proxy ## This provides an LLM Guard Integration for content moderation on the proxy
from typing import Optional, Literal, Union from typing import Optional, Literal, Union
import litellm, traceback, sys, uuid, os import litellm
import traceback
import sys
import uuid
import os
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
@ -21,8 +25,10 @@ from litellm.utils import (
StreamingChoices, StreamingChoices,
) )
from datetime import datetime from datetime import datetime
import aiohttp, asyncio import aiohttp
import asyncio
from litellm.utils import get_formatted_prompt from litellm.utils import get_formatted_prompt
from litellm.secret_managers.main import get_secret_str
litellm.set_verbose = True litellm.set_verbose = True
@ -38,7 +44,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
self.llm_guard_mode = litellm.llm_guard_mode self.llm_guard_mode = litellm.llm_guard_mode
if mock_testing == True: # for testing purposes only if mock_testing == True: # for testing purposes only
return return
self.llm_guard_api_base = litellm.get_secret("LLM_GUARD_API_BASE", None) self.llm_guard_api_base = get_secret_str("LLM_GUARD_API_BASE", None)
if self.llm_guard_api_base is None: if self.llm_guard_api_base is None:
raise Exception("Missing `LLM_GUARD_API_BASE` from environment") raise Exception("Missing `LLM_GUARD_API_BASE` from environment")
elif not self.llm_guard_api_base.endswith("/"): elif not self.llm_guard_api_base.endswith("/"):

View file

@ -51,8 +51,8 @@ class _ENTERPRISE_OpenAI_Moderation(CustomLogger):
"audio_transcription", "audio_transcription",
], ],
): ):
if "messages" in data and isinstance(data["messages"], list):
text = "" text = ""
if "messages" in data and isinstance(data["messages"], list):
for m in data["messages"]: # assume messages is a list for m in data["messages"]: # assume messages is a list
if "content" in m and isinstance(m["content"], str): if "content" in m and isinstance(m["content"], str):
text += m["content"] text += m["content"]
@ -67,7 +67,7 @@ class _ENTERPRISE_OpenAI_Moderation(CustomLogger):
) )
verbose_proxy_logger.debug("Moderation response: %s", moderation_response) verbose_proxy_logger.debug("Moderation response: %s", moderation_response)
if moderation_response.results[0].flagged == True: if moderation_response.results[0].flagged is True:
raise HTTPException( raise HTTPException(
status_code=403, detail={"error": "Violated content safety policy"} status_code=403, detail={"error": "Violated content safety policy"}
) )

View file

@ -6,7 +6,9 @@ import collections
from datetime import datetime from datetime import datetime
async def get_spend_by_tags(start_date=None, end_date=None, prisma_client=None): async def get_spend_by_tags(
prisma_client: PrismaClient, start_date=None, end_date=None
):
response = await prisma_client.db.query_raw( response = await prisma_client.db.query_raw(
""" """
SELECT SELECT

View file

@ -191,7 +191,7 @@ def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
new_startup_nodes.append(ClusterNode(**item)) new_startup_nodes.append(ClusterNode(**item))
redis_kwargs.pop("startup_nodes") redis_kwargs.pop("startup_nodes")
return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs) return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs) # type: ignore
def _init_redis_sentinel(redis_kwargs) -> redis.Redis: def _init_redis_sentinel(redis_kwargs) -> redis.Redis:

View file

@ -1,5 +1,8 @@
import litellm
from typing import Optional, Union from typing import Optional, Union
import litellm
from ..exceptions import UnsupportedParamsError
from ..types.llms.openai import * from ..types.llms.openai import *

View file

@ -10,6 +10,7 @@
import ast import ast
import asyncio import asyncio
import hashlib import hashlib
import inspect
import io import io
import json import json
import logging import logging
@ -245,7 +246,8 @@ class RedisCache(BaseCache):
self.redis_flush_size = redis_flush_size self.redis_flush_size = redis_flush_size
self.redis_version = "Unknown" self.redis_version = "Unknown"
try: try:
self.redis_version = self.redis_client.info()["redis_version"] if not inspect.iscoroutinefunction(self.redis_client):
self.redis_version = self.redis_client.info()["redis_version"] # type: ignore
except Exception: except Exception:
pass pass
@ -266,7 +268,8 @@ class RedisCache(BaseCache):
### SYNC HEALTH PING ### ### SYNC HEALTH PING ###
try: try:
self.redis_client.ping() if hasattr(self.redis_client, "ping"):
self.redis_client.ping() # type: ignore
except Exception as e: except Exception as e:
verbose_logger.error( verbose_logger.error(
"Error connecting to Sync Redis client", extra={"error": str(e)} "Error connecting to Sync Redis client", extra={"error": str(e)}
@ -308,7 +311,7 @@ class RedisCache(BaseCache):
_redis_client = self.redis_client _redis_client = self.redis_client
start_time = time.time() start_time = time.time()
try: try:
result = _redis_client.incr(name=key, amount=value) result: int = _redis_client.incr(name=key, amount=value) # type: ignore
if ttl is not None: if ttl is not None:
# check if key already has ttl, if not -> set ttl # check if key already has ttl, if not -> set ttl
@ -561,7 +564,7 @@ class RedisCache(BaseCache):
f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}" f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
) )
try: try:
await redis_client.sadd(key, *value) await redis_client.sadd(key, *value) # type: ignore
if ttl is not None: if ttl is not None:
_td = timedelta(seconds=ttl) _td = timedelta(seconds=ttl)
await redis_client.expire(key, _td) await redis_client.expire(key, _td)
@ -712,7 +715,7 @@ class RedisCache(BaseCache):
for cache_key in key_list: for cache_key in key_list:
cache_key = self.check_and_fix_namespace(key=cache_key) cache_key = self.check_and_fix_namespace(key=cache_key)
_keys.append(cache_key) _keys.append(cache_key)
results = self.redis_client.mget(keys=_keys) results: List = self.redis_client.mget(keys=_keys) # type: ignore
# Associate the results back with their keys. # Associate the results back with their keys.
# 'results' is a list of values corresponding to the order of keys in 'key_list'. # 'results' is a list of values corresponding to the order of keys in 'key_list'.
@ -842,7 +845,7 @@ class RedisCache(BaseCache):
print_verbose("Pinging Sync Redis Cache") print_verbose("Pinging Sync Redis Cache")
start_time = time.time() start_time = time.time()
try: try:
response = self.redis_client.ping() response: bool = self.redis_client.ping() # type: ignore
print_verbose(f"Redis Cache PING: {response}") print_verbose(f"Redis Cache PING: {response}")
## LOGGING ## ## LOGGING ##
end_time = time.time() end_time = time.time()
@ -911,8 +914,8 @@ class RedisCache(BaseCache):
async with _redis_client as redis_client: async with _redis_client as redis_client:
await redis_client.delete(*keys) await redis_client.delete(*keys)
def client_list(self): def client_list(self) -> List:
client_list = self.redis_client.client_list() client_list: List = self.redis_client.client_list() # type: ignore
return client_list return client_list
def info(self): def info(self):
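
Most of the new `# type: ignore` markers in this cache stem from the client attribute covering both the sync and async Redis clients, which is what the `iscoroutinefunction` guard above is probing. An alternative narrowing sketch, assuming redis-py's separate sync/async client classes:

from typing import Union

import redis
import redis.asyncio as async_redis

RedisClient = Union[redis.Redis, async_redis.Redis]

def sync_ping(client: RedisClient) -> bool:
    # narrow to the sync client before a blocking call, instead of
    # adding `# type: ignore` at every call site
    if isinstance(client, redis.Redis):
        return bool(client.ping())
    raise TypeError("async Redis client passed to a sync code path")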

View file

@ -39,8 +39,8 @@ from litellm.llms.fireworks_ai.cost_calculator import (
) )
from litellm.llms.OpenAI.cost_calculation import cost_per_token as openai_cost_per_token from litellm.llms.OpenAI.cost_calculation import cost_per_token as openai_cost_per_token
from litellm.llms.together_ai.cost_calculator import get_model_params_and_category from litellm.llms.together_ai.cost_calculator import get_model_params_and_category
from litellm.rerank_api.types import RerankResponse
from litellm.types.llms.openai import HttpxBinaryResponseContent from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.rerank import RerankResponse
from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
from litellm.types.utils import PassthroughCallTypes, Usage from litellm.types.utils import PassthroughCallTypes, Usage
from litellm.utils import ( from litellm.utils import (

View file

@ -39,11 +39,11 @@ from litellm.proxy._types import (
VirtualKeyEvent, VirtualKeyEvent,
WebhookEvent, WebhookEvent,
) )
from litellm.types.integrations.slack_alerting import *
from litellm.types.router import LiteLLM_Params from litellm.types.router import LiteLLM_Params
from ..email_templates.templates import * from ..email_templates.templates import *
from .batching_handler import send_to_webhook, squash_payloads from .batching_handler import send_to_webhook, squash_payloads
from .types import *
from .utils import _add_langfuse_trace_id_to_alert, process_slack_alerting_variables from .utils import _add_langfuse_trace_id_to_alert, process_slack_alerting_variables

View file

@ -172,14 +172,14 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
"moderation", "moderation",
"audio_transcription", "audio_transcription",
], ],
): ) -> Any:
pass pass
async def async_post_call_streaming_hook( async def async_post_call_streaming_hook(
self, self,
user_api_key_dict: UserAPIKeyAuth, user_api_key_dict: UserAPIKeyAuth,
response: str, response: str,
): ) -> Any:
pass pass
#### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function #### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function

View file

@ -42,8 +42,10 @@ async def get_all_team_member_emails(team_id: Optional[str] = None) -> list:
) )
_team_member_user_ids: List[str] = [] _team_member_user_ids: List[str] = []
for member in _team_members: for member in _team_members:
if member and isinstance(member, dict) and member.get("user_id") is not None: if member and isinstance(member, dict):
_team_member_user_ids.append(member.get("user_id")) _user_id = member.get("user_id")
if _user_id and isinstance(_user_id, str):
_team_member_user_ids.append(_user_id)
sql_query = """ sql_query = """
SELECT user_email SELECT user_email

View file

@ -149,7 +149,7 @@ class LunaryLogger:
else: else:
error_obj = None error_obj = None
self.lunary_client.track_event( self.lunary_client.track_event( # type: ignore
type, type,
"start", "start",
run_id, run_id,
@ -164,7 +164,7 @@ class LunaryLogger:
params=extra, params=extra,
) )
self.lunary_client.track_event( self.lunary_client.track_event( # type: ignore
type, type,
event, event,
run_id, run_id,

View file

@ -100,16 +100,14 @@ class OpenMeterLogger(CustomLogger):
} }
try: try:
response = self.sync_http_handler.post( self.sync_http_handler.post(
url=_url, url=_url,
data=json.dumps(_data), data=json.dumps(_data),
headers=_headers, headers=_headers,
) )
except httpx.HTTPStatusError as e:
response.raise_for_status() raise Exception(f"OpenMeter logging error: {e.response.text}")
except Exception as e: except Exception as e:
if hasattr(response, "text"):
litellm.print_verbose(f"\nError Message: {response.text}")
raise e raise e
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
@ -128,18 +126,12 @@ class OpenMeterLogger(CustomLogger):
} }
try: try:
response = await self.async_http_handler.post( await self.async_http_handler.post(
url=_url, url=_url,
data=json.dumps(_data), data=json.dumps(_data),
headers=_headers, headers=_headers,
) )
response.raise_for_status()
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
verbose_logger.error( raise Exception(f"OpenMeter logging error: {e.response.text}")
"Failed OpenMeter logging - {}".format(e.response.text)
)
raise e
except Exception as e: except Exception as e:
verbose_logger.error("Failed OpenMeter logging - {}".format(str(e)))
raise e raise e
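
This rewrite leans on the HTTP handler raising `httpx.HTTPStatusError` itself, which avoids referencing a `response` variable that may never have been bound. The equivalent idiom in plain httpx, as an illustrative sketch rather than the project's handler:

import httpx

def post_event(url: str, payload: dict, headers: dict) -> None:
    try:
        resp = httpx.post(url, json=payload, headers=headers)
        resp.raise_for_status()  # converts 4xx/5xx into HTTPStatusError
    except httpx.HTTPStatusError as e:
        # e.response is always set here, unlike a local `resp` on connection errors
        raise Exception(f"OpenMeter logging error: {e.response.text}") from e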

View file

@ -146,14 +146,14 @@ class S3Logger:
import json import json
payload = json.dumps(payload) payload_str = json.dumps(payload)
print_verbose(f"\ns3 Logger - Logging payload = {payload}") print_verbose(f"\ns3 Logger - Logging payload = {payload_str}")
response = self.s3_client.put_object( response = self.s3_client.put_object(
Bucket=self.bucket_name, Bucket=self.bucket_name,
Key=s3_object_key, Key=s3_object_key,
Body=payload, Body=payload_str,
ContentType="application/json", ContentType="application/json",
ContentLanguage="en", ContentLanguage="en",
ContentDisposition=f'inline; filename="{s3_object_download_filename}"', ContentDisposition=f'inline; filename="{s3_object_download_filename}"',

View file

@ -1,6 +1,7 @@
import traceback import traceback
from litellm._logging import verbose_logger
import litellm import litellm
from litellm._logging import verbose_logger
class TraceloopLogger: class TraceloopLogger:
@ -11,14 +12,15 @@ class TraceloopLogger:
def __init__(self): def __init__(self):
try: try:
from traceloop.sdk.tracing.tracing import TracerWrapper from opentelemetry.sdk.trace.export import ConsoleSpanExporter
from traceloop.sdk import Traceloop from traceloop.sdk import Traceloop
from traceloop.sdk.instruments import Instruments from traceloop.sdk.instruments import Instruments
from opentelemetry.sdk.trace.export import ConsoleSpanExporter from traceloop.sdk.tracing.tracing import TracerWrapper
except ModuleNotFoundError as e: except ModuleNotFoundError as e:
verbose_logger.error( verbose_logger.error(
f"Traceloop not installed, try running 'pip install traceloop-sdk' to fix this error: {e}\n{traceback.format_exc()}" f"Traceloop not installed, try running 'pip install traceloop-sdk' to fix this error: {e}\n{traceback.format_exc()}"
) )
raise e
Traceloop.init( Traceloop.init(
app_name="Litellm-Server", app_name="Litellm-Server",
@ -38,8 +40,8 @@ class TraceloopLogger:
status_message=None, status_message=None,
): ):
from opentelemetry import trace from opentelemetry import trace
from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.semconv.ai import SpanAttributes from opentelemetry.semconv.ai import SpanAttributes
from opentelemetry.trace import SpanKind, Status, StatusCode
try: try:
print_verbose( print_verbose(
@ -94,7 +96,7 @@ class TraceloopLogger:
) )
if "temperature" in optional_params: if "temperature" in optional_params:
span.set_attribute( span.set_attribute(
SpanAttributes.LLM_REQUEST_TEMPERATURE, SpanAttributes.LLM_REQUEST_TEMPERATURE, # type: ignore
kwargs.get("temperature"), kwargs.get("temperature"),
) )

View file

@ -32,8 +32,8 @@ from litellm.litellm_core_utils.redact_messages import (
redact_message_input_output_from_logging, redact_message_input_output_from_logging,
) )
from litellm.proxy._types import CommonProxyErrors from litellm.proxy._types import CommonProxyErrors
from litellm.rerank_api.types import RerankResponse
from litellm.types.llms.openai import HttpxBinaryResponseContent from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.rerank import RerankResponse
from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
from litellm.types.utils import ( from litellm.types.utils import (
CallTypes, CallTypes,

View file

@ -1,5 +1,5 @@
import uuid import uuid
from typing import Optional, Union from typing import Any, Optional, Union
import httpx import httpx
from openai import AsyncAzureOpenAI, AzureOpenAI from openai import AsyncAzureOpenAI, AzureOpenAI
@ -24,6 +24,7 @@ class AzureAudioTranscription(AzureChatCompletion):
model: str, model: str,
audio_file: FileTypes, audio_file: FileTypes,
optional_params: dict, optional_params: dict,
logging_obj: Any,
model_response: TranscriptionResponse, model_response: TranscriptionResponse,
timeout: float, timeout: float,
max_retries: int, max_retries: int,
@ -32,9 +33,8 @@ class AzureAudioTranscription(AzureChatCompletion):
api_version: Optional[str] = None, api_version: Optional[str] = None,
client=None, client=None,
azure_ad_token: Optional[str] = None, azure_ad_token: Optional[str] = None,
logging_obj=None,
atranscription: bool = False, atranscription: bool = False,
): ) -> TranscriptionResponse:
data = {"model": model, "file": audio_file, **optional_params} data = {"model": model, "file": audio_file, **optional_params}
# init AzureOpenAI Client # init AzureOpenAI Client
@ -59,7 +59,7 @@ class AzureAudioTranscription(AzureChatCompletion):
azure_client_params["max_retries"] = max_retries azure_client_params["max_retries"] = max_retries
if atranscription is True: if atranscription is True:
return self.async_audio_transcriptions( return self.async_audio_transcriptions( # type: ignore
audio_file=audio_file, audio_file=audio_file,
data=data, data=data,
model_response=model_response, model_response=model_response,
@ -105,7 +105,7 @@ class AzureAudioTranscription(AzureChatCompletion):
original_response=stringified_response, original_response=stringified_response,
) )
hidden_params = {"model": "whisper-1", "custom_llm_provider": "azure"} hidden_params = {"model": "whisper-1", "custom_llm_provider": "azure"}
final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore final_response: TranscriptionResponse = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore
return final_response return final_response
async def async_audio_transcriptions( async def async_audio_transcriptions(
@ -114,12 +114,12 @@ class AzureAudioTranscription(AzureChatCompletion):
data: dict, data: dict,
model_response: TranscriptionResponse, model_response: TranscriptionResponse,
timeout: float, timeout: float,
azure_client_params: dict,
logging_obj: Any,
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_base: Optional[str] = None, api_base: Optional[str] = None,
client=None, client=None,
azure_client_params=None,
max_retries=None, max_retries=None,
logging_obj=None,
): ):
response = None response = None
try: try:

View file

@ -1083,7 +1083,7 @@ class AzureChatCompletion(BaseLLM):
azure_ad_token: Optional[str] = None, azure_ad_token: Optional[str] = None,
client=None, client=None,
aembedding=None, aembedding=None,
): ) -> litellm.EmbeddingResponse:
super().embedding() super().embedding()
if self._client_session is None: if self._client_session is None:
self._client_session = self.create_client_session() self._client_session = self.create_client_session()
@ -1128,7 +1128,7 @@ class AzureChatCompletion(BaseLLM):
) )
if aembedding is True: if aembedding is True:
response = self.aembedding( return self.aembedding( # type: ignore
data=data, data=data,
input=input, input=input,
logging_obj=logging_obj, logging_obj=logging_obj,
@ -1138,7 +1138,6 @@ class AzureChatCompletion(BaseLLM):
timeout=timeout, timeout=timeout,
client=client, client=client,
) )
return response
if client is None: if client is None:
azure_client = AzureOpenAI(**azure_client_params) # type: ignore azure_client = AzureOpenAI(**azure_client_params) # type: ignore
else: else:
@ -1418,7 +1417,7 @@ class AzureChatCompletion(BaseLLM):
logging_obj: LiteLLMLoggingObj, logging_obj: LiteLLMLoggingObj,
client=None, client=None,
timeout=None, timeout=None,
): ) -> litellm.ImageResponse:
response: Optional[dict] = None response: Optional[dict] = None
try: try:
# response = await azure_client.images.generate(**data, timeout=timeout) # response = await azure_client.images.generate(**data, timeout=timeout)
@ -1460,7 +1459,7 @@ class AzureChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
original_response=stringified_response, original_response=stringified_response,
) )
return convert_to_model_response_object( return convert_to_model_response_object( # type: ignore
response_object=stringified_response, response_object=stringified_response,
model_response_object=model_response, model_response_object=model_response,
response_type="image_generation", response_type="image_generation",
@ -1489,7 +1488,7 @@ class AzureChatCompletion(BaseLLM):
azure_ad_token: Optional[str] = None, azure_ad_token: Optional[str] = None,
client=None, client=None,
aimg_generation=None, aimg_generation=None,
): ) -> litellm.ImageResponse:
try: try:
if model and len(model) > 0: if model and len(model) > 0:
model = model model = model
@ -1531,8 +1530,7 @@ class AzureChatCompletion(BaseLLM):
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
if aimg_generation is True: if aimg_generation is True:
response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout) # type: ignore return self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout) # type: ignore
return response
img_gen_api_base = self.create_azure_base_url( img_gen_api_base = self.create_azure_base_url(
azure_client_params=azure_client_params, model=data.get("model", "") azure_client_params=azure_client_params, model=data.get("model", "")
@ -1742,9 +1740,9 @@ class AzureChatCompletion(BaseLLM):
async def ahealth_check( async def ahealth_check(
self, self,
model: Optional[str], model: Optional[str],
api_key: str, api_key: Optional[str],
api_base: str, api_base: str,
api_version: str, api_version: Optional[str],
timeout: float, timeout: float,
mode: str, mode: str,
messages: Optional[list] = None, messages: Optional[list] = None,

View file

@ -77,15 +77,15 @@ class OpenAIAudioTranscription(OpenAIChatCompletion):
model_response: TranscriptionResponse, model_response: TranscriptionResponse,
timeout: float, timeout: float,
max_retries: int, max_retries: int,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str], api_key: Optional[str],
api_base: Optional[str], api_base: Optional[str],
client=None, client=None,
logging_obj=None,
atranscription: bool = False, atranscription: bool = False,
): ) -> TranscriptionResponse:
data = {"model": model, "file": audio_file, **optional_params} data = {"model": model, "file": audio_file, **optional_params}
if atranscription is True: if atranscription is True:
return self.async_audio_transcriptions( return self.async_audio_transcriptions( # type: ignore
audio_file=audio_file, audio_file=audio_file,
data=data, data=data,
model_response=model_response, model_response=model_response,
@ -97,7 +97,7 @@ class OpenAIAudioTranscription(OpenAIChatCompletion):
logging_obj=logging_obj, logging_obj=logging_obj,
) )
openai_client = self._get_openai_client( openai_client: OpenAI = self._get_openai_client( # type: ignore
is_async=False, is_async=False,
api_key=api_key, api_key=api_key,
api_base=api_base, api_base=api_base,
@ -123,7 +123,7 @@ class OpenAIAudioTranscription(OpenAIChatCompletion):
original_response=stringified_response, original_response=stringified_response,
) )
hidden_params = {"model": "whisper-1", "custom_llm_provider": "openai"} hidden_params = {"model": "whisper-1", "custom_llm_provider": "openai"}
final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore final_response: TranscriptionResponse = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore
return final_response return final_response
async def async_audio_transcriptions( async def async_audio_transcriptions(
@ -139,7 +139,7 @@ class OpenAIAudioTranscription(OpenAIChatCompletion):
max_retries=None, max_retries=None,
): ):
try: try:
openai_aclient = self._get_openai_client( openai_aclient: AsyncOpenAI = self._get_openai_client( # type: ignore
is_async=True, is_async=True,
api_key=api_key, api_key=api_key,
api_base=api_base, api_base=api_base,

View file

@ -1167,7 +1167,7 @@ class OpenAIChatCompletion(BaseLLM):
api_base: Optional[str] = None, api_base: Optional[str] = None,
client=None, client=None,
aembedding=None, aembedding=None,
): ) -> litellm.EmbeddingResponse:
super().embedding() super().embedding()
try: try:
model = model model = model
@ -1183,7 +1183,7 @@ class OpenAIChatCompletion(BaseLLM):
) )
if aembedding is True: if aembedding is True:
async_response = self.aembedding( return self.aembedding( # type: ignore
data=data, data=data,
input=input, input=input,
logging_obj=logging_obj, logging_obj=logging_obj,
@ -1194,7 +1194,6 @@ class OpenAIChatCompletion(BaseLLM):
client=client, client=client,
max_retries=max_retries, max_retries=max_retries,
) )
return async_response
openai_client: OpenAI = self._get_openai_client( # type: ignore openai_client: OpenAI = self._get_openai_client( # type: ignore
is_async=False, is_async=False,
@ -1294,7 +1293,7 @@ class OpenAIChatCompletion(BaseLLM):
model_response: Optional[litellm.utils.ImageResponse] = None, model_response: Optional[litellm.utils.ImageResponse] = None,
client=None, client=None,
aimg_generation=None, aimg_generation=None,
): ) -> litellm.ImageResponse:
data = {} data = {}
try: try:
model = model model = model
@ -1304,8 +1303,7 @@ class OpenAIChatCompletion(BaseLLM):
raise OpenAIError(status_code=422, message="max retries must be an int") raise OpenAIError(status_code=422, message="max retries must be an int")
if aimg_generation is True: if aimg_generation is True:
response = self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore return self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
return response
openai_client = self._get_openai_client( openai_client = self._get_openai_client(
is_async=False, is_async=False,
@ -1449,7 +1447,7 @@ class OpenAIChatCompletion(BaseLLM):
async def ahealth_check( async def ahealth_check(
self, self,
model: Optional[str], model: Optional[str],
api_key: str, api_key: Optional[str],
timeout: float, timeout: float,
mode: str, mode: str,
messages: Optional[list] = None, messages: Optional[list] = None,

View file

@ -282,7 +282,6 @@ class AnthropicChatCompletion(BaseLLM):
prompt_tokens = completion_response["usage"]["input_tokens"] prompt_tokens = completion_response["usage"]["input_tokens"]
completion_tokens = completion_response["usage"]["output_tokens"] completion_tokens = completion_response["usage"]["output_tokens"]
_usage = completion_response["usage"] _usage = completion_response["usage"]
total_tokens = prompt_tokens + completion_tokens
cache_creation_input_tokens: int = 0 cache_creation_input_tokens: int = 0
cache_read_input_tokens: int = 0 cache_read_input_tokens: int = 0
@ -290,12 +289,15 @@ class AnthropicChatCompletion(BaseLLM):
model_response.model = model model_response.model = model
if "cache_creation_input_tokens" in _usage: if "cache_creation_input_tokens" in _usage:
cache_creation_input_tokens = _usage["cache_creation_input_tokens"] cache_creation_input_tokens = _usage["cache_creation_input_tokens"]
prompt_tokens += cache_creation_input_tokens
if "cache_read_input_tokens" in _usage: if "cache_read_input_tokens" in _usage:
cache_read_input_tokens = _usage["cache_read_input_tokens"] cache_read_input_tokens = _usage["cache_read_input_tokens"]
prompt_tokens += cache_read_input_tokens
prompt_tokens_details = PromptTokensDetails( prompt_tokens_details = PromptTokensDetails(
cached_tokens=cache_read_input_tokens cached_tokens=cache_read_input_tokens
) )
total_tokens = prompt_tokens + completion_tokens
usage = Usage( usage = Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,

View file

@ -24,16 +24,33 @@ def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
model_info = get_model_info(model=model, custom_llm_provider="anthropic") model_info = get_model_info(model=model, custom_llm_provider="anthropic")
## CALCULATE INPUT COST ## CALCULATE INPUT COST
### Cost of processing (non-cache hit + cache hit) + Cost of cache-writing (cache writing)
prompt_cost = 0.0
### PROCESSING COST
non_cache_hit_tokens = usage.prompt_tokens
cache_hit_tokens = 0
if usage.prompt_tokens_details and usage.prompt_tokens_details.cached_tokens:
cache_hit_tokens = usage.prompt_tokens_details.cached_tokens
non_cache_hit_tokens = non_cache_hit_tokens - cache_hit_tokens
prompt_cost: float = usage["prompt_tokens"] * model_info["input_cost_per_token"] prompt_cost = float(non_cache_hit_tokens) * model_info["input_cost_per_token"]
if model_info.get("cache_creation_input_token_cost") is not None:
_cache_read_input_token_cost = model_info.get("cache_read_input_token_cost")
if (
_cache_read_input_token_cost is not None
and usage.prompt_tokens_details
and usage.prompt_tokens_details.cached_tokens
):
prompt_cost += ( prompt_cost += (
usage._cache_creation_input_tokens # type: ignore float(usage.prompt_tokens_details.cached_tokens)
* model_info["cache_creation_input_token_cost"] * _cache_read_input_token_cost
) )
if model_info.get("cache_read_input_token_cost") is not None:
### CACHE WRITING COST
_cache_creation_input_token_cost = model_info.get("cache_creation_input_token_cost")
if _cache_creation_input_token_cost is not None:
prompt_cost += ( prompt_cost += (
usage._cache_read_input_tokens * model_info["cache_read_input_token_cost"] # type: ignore float(usage._cache_creation_input_tokens) * _cache_creation_input_token_cost
) )
## CALCULATE OUTPUT COST ## CALCULATE OUTPUT COST
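
A worked example of the revised input-cost formula, with made-up per-token prices: 1,000 reported prompt tokens of which 600 were cache reads, plus 200 cache-creation tokens, are billed as 400 tokens at the base input rate, 600 at the cache-read rate, and 200 at the cache-write rate.

# illustrative prices, not real Anthropic pricing
input_cost_per_token = 3e-06
cache_read_input_token_cost = 3e-07
cache_creation_input_token_cost = 3.75e-06

prompt_tokens = 1000         # usage.prompt_tokens
cached_tokens = 600          # usage.prompt_tokens_details.cached_tokens
cache_creation_tokens = 200  # usage._cache_creation_input_tokens

prompt_cost = (
    (prompt_tokens - cached_tokens) * input_cost_per_token      # non-cache-hit processing
    + cached_tokens * cache_read_input_token_cost                # cache hits
    + cache_creation_tokens * cache_creation_input_token_cost    # cache writes
)
# prompt_cost == 400*3e-06 + 600*3e-07 + 200*3.75e-06 == 0.00213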

View file

@ -216,7 +216,7 @@ class AzureAIEmbedding(OpenAIChatCompletion):
api_base: Optional[str] = None, api_base: Optional[str] = None,
client=None, client=None,
aembedding=None, aembedding=None,
): ) -> litellm.EmbeddingResponse:
""" """
- Separate image url from text - Separate image url from text
-> route image url call to `/image/embeddings` -> route image url call to `/image/embeddings`
@ -225,7 +225,7 @@ class AzureAIEmbedding(OpenAIChatCompletion):
assemble result in-order, and return assemble result in-order, and return
""" """
if aembedding is True: if aembedding is True:
return self.async_embedding( return self.async_embedding( # type: ignore
model, model,
input, input,
timeout, timeout,

View file

@ -4,7 +4,7 @@ import httpx
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.cohere.rerank import CohereRerank from litellm.llms.cohere.rerank import CohereRerank
from litellm.rerank_api.types import RerankResponse from litellm.types.rerank import RerankResponse
class AzureAIRerank(CohereRerank): class AzureAIRerank(CohereRerank):

View file

@ -5,7 +5,7 @@ Handles image gen calls to Bedrock's `/invoke` endpoint
import copy import copy
import json import json
import os import os
from typing import List from typing import Any, List
from openai.types.image import Image from openai.types.image import Image
@ -20,8 +20,8 @@ def image_generation(
prompt: str, prompt: str,
model_response: ImageResponse, model_response: ImageResponse,
optional_params: dict, optional_params: dict,
logging_obj: Any,
timeout=None, timeout=None,
logging_obj=None,
aimg_generation=False, aimg_generation=False,
): ):
""" """

View file

@ -13,7 +13,7 @@ from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client, _get_httpx_client,
get_async_httpx_client, get_async_httpx_client,
) )
from litellm.rerank_api.types import RerankRequest, RerankResponse from litellm.types.rerank import RerankRequest, RerankResponse
class CohereRerank(BaseLLM): class CohereRerank(BaseLLM):

View file

@ -40,7 +40,7 @@ class AzureOpenAIFineTuningAPI(BaseLLM):
organization: Optional[str] = None, organization: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
api_version: Optional[str] = None, api_version: Optional[str] = None,
) -> Union[FineTuningJob, Union[Coroutine[Any, Any, FineTuningJob]]]: ) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
get_azure_openai_client( get_azure_openai_client(
api_key=api_key, api_key=api_key,

View file

@ -68,7 +68,7 @@ class OpenAIFineTuningAPI(BaseLLM):
max_retries: Optional[int], max_retries: Optional[int],
organization: Optional[str], organization: Optional[str],
client: Optional[Union[OpenAI, AsyncOpenAI]] = None, client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
) -> Union[FineTuningJob, Union[Coroutine[Any, Any, FineTuningJob]]]: ) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client( openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
api_key=api_key, api_key=api_key,
api_base=api_base, api_base=api_base,

View file

@ -554,7 +554,7 @@ class Huggingface(BaseLLM):
else: else:
prompt = prompt_factory(model=model, messages=messages) prompt = prompt_factory(model=model, messages=messages)
data = { data = {
"inputs": prompt, "inputs": prompt, # type: ignore
"parameters": optional_params, "parameters": optional_params,
"stream": ( # type: ignore "stream": ( # type: ignore
True True
@ -589,7 +589,7 @@ class Huggingface(BaseLLM):
inference_params.pop("details") inference_params.pop("details")
inference_params.pop("return_full_text") inference_params.pop("return_full_text")
data = { data = {
"inputs": prompt, "inputs": prompt, # type: ignore
} }
if task == "text-generation-inference": if task == "text-generation-inference":
data["parameters"] = inference_params data["parameters"] = inference_params

View file

@ -4,7 +4,7 @@ import time
import traceback import traceback
import types import types
from enum import Enum from enum import Enum
from typing import Callable, List, Optional from typing import Any, Callable, List, Optional
import requests # type: ignore import requests # type: ignore
@ -185,8 +185,8 @@ def completion(
def embedding( def embedding(
model: str, model: str,
input: list, input: list,
api_key: Optional[str] = None, api_key: Optional[str],
logging_obj=None, logging_obj: Any,
model_response=None, model_response=None,
encoding=None, encoding=None,
): ):

View file

@ -2,7 +2,7 @@ import json
import os import os
import time import time
from enum import Enum from enum import Enum
from typing import Callable, Optional from typing import Any, Callable, Optional
import requests # type: ignore import requests # type: ignore
@ -124,9 +124,9 @@ def embedding(
model: str, model: str,
input: list, input: list,
model_response: EmbeddingResponse, model_response: EmbeddingResponse,
api_key: Optional[str] = None, api_key: Optional[str],
api_base: Optional[str] = None, api_base: Optional[str],
logging_obj=None, logging_obj: Any,
optional_params=None, optional_params=None,
encoding=None, encoding=None,
): ):

View file

@ -2714,7 +2714,7 @@ def custom_prompt(
final_prompt_value: str = "", final_prompt_value: str = "",
bos_token: str = "", bos_token: str = "",
eos_token: str = "", eos_token: str = "",
): ) -> str:
prompt = bos_token + initial_prompt_value prompt = bos_token + initial_prompt_value
bos_open = True bos_open = True
## a bos token is at the start of a system / human message ## a bos token is at the start of a system / human message

View file

@ -15,7 +15,7 @@ from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client, _get_httpx_client,
get_async_httpx_client, get_async_httpx_client,
) )
from litellm.rerank_api.types import RerankRequest, RerankResponse from litellm.types.rerank import RerankRequest, RerankResponse
class TogetherAIRerank(BaseLLM): class TogetherAIRerank(BaseLLM):

View file

@ -47,7 +47,7 @@ class TritonChatCompletion(BaseLLM):
data: dict, data: dict,
model_response: litellm.utils.EmbeddingResponse, model_response: litellm.utils.EmbeddingResponse,
api_base: str, api_base: str,
logging_obj=None, logging_obj: Any,
api_key: Optional[str] = None, api_key: Optional[str] = None,
) -> EmbeddingResponse: ) -> EmbeddingResponse:
async_handler = AsyncHTTPHandler( async_handler = AsyncHTTPHandler(
@ -93,9 +93,9 @@ class TritonChatCompletion(BaseLLM):
timeout: float, timeout: float,
api_base: str, api_base: str,
model_response: litellm.utils.EmbeddingResponse, model_response: litellm.utils.EmbeddingResponse,
logging_obj: Any,
optional_params: dict,
api_key: Optional[str] = None, api_key: Optional[str] = None,
logging_obj=None,
optional_params=None,
client=None, client=None,
aembedding: bool = False, aembedding: bool = False,
) -> EmbeddingResponse: ) -> EmbeddingResponse:
@ -122,7 +122,7 @@ class TritonChatCompletion(BaseLLM):
) )
if aembedding: if aembedding:
response = await self.aembedding( response = await self.aembedding( # type: ignore
data=data_for_triton, data=data_for_triton,
model_response=model_response, model_response=model_response,
logging_obj=logging_obj, logging_obj=logging_obj,
@ -141,10 +141,10 @@ class TritonChatCompletion(BaseLLM):
messages: List[dict], messages: List[dict],
timeout: float, timeout: float,
api_base: str, api_base: str,
logging_obj: Any,
optional_params: dict,
model_response: ModelResponse, model_response: ModelResponse,
api_key: Optional[str] = None, api_key: Optional[str] = None,
logging_obj=None,
optional_params=None,
client=None, client=None,
stream: Optional[bool] = False, stream: Optional[bool] = False,
acompletion: bool = False, acompletion: bool = False,
@ -239,11 +239,13 @@ class TritonChatCompletion(BaseLLM):
else: else:
handler = HTTPHandler() handler = HTTPHandler()
if stream: if stream:
return self._handle_stream( return self._handle_stream( # type: ignore
handler, api_base, json_data_for_triton, model, logging_obj handler, api_base, json_data_for_triton, model, logging_obj
) )
else: else:
response = handler.post(url=api_base, data=json_data_for_triton, headers=headers) response = handler.post(
url=api_base, data=json_data_for_triton, headers=headers
)
return self._handle_response( return self._handle_response(
response, model_response, logging_obj, type_of_model=type_of_model response, model_response, logging_obj, type_of_model=type_of_model
) )
@ -261,7 +263,7 @@ class TritonChatCompletion(BaseLLM):
) -> ModelResponse: ) -> ModelResponse:
handler = AsyncHTTPHandler() handler = AsyncHTTPHandler()
if stream: if stream:
return self._ahandle_stream( return self._ahandle_stream( # type: ignore
handler, api_base, data_for_triton, model, logging_obj handler, api_base, data_for_triton, model, logging_obj
) )
else: else:

View file

@ -2,7 +2,7 @@ from typing import List, Literal, Tuple
import httpx import httpx
from litellm import supports_system_messages, supports_response_schema, verbose_logger from litellm import supports_response_schema, supports_system_messages, verbose_logger
from litellm.types.llms.vertex_ai import PartType from litellm.types.llms.vertex_ai import PartType
@ -67,6 +67,8 @@ def _get_vertex_url(
vertex_location: Optional[str], vertex_location: Optional[str],
vertex_api_version: Literal["v1", "v1beta1"], vertex_api_version: Literal["v1", "v1beta1"],
) -> Tuple[str, str]: ) -> Tuple[str, str]:
url: Optional[str] = None
endpoint: Optional[str] = None
if mode == "chat": if mode == "chat":
### SET RUNTIME ENDPOINT ### ### SET RUNTIME ENDPOINT ###
endpoint = "generateContent" endpoint = "generateContent"
@ -88,6 +90,8 @@ def _get_vertex_url(
endpoint = "predict" endpoint = "predict"
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}" url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
if not url or not endpoint:
raise ValueError(f"Unable to get vertex url/endpoinit for mode: {mode}")
return url, endpoint return url, endpoint

View file

@ -3,7 +3,7 @@ Google AI Studio /batchEmbedContents Embeddings Endpoint
""" """
import json import json
from typing import List, Literal, Optional, Union from typing import Any, List, Literal, Optional, Union
import httpx import httpx
@ -31,9 +31,9 @@ class GoogleBatchEmbeddings(VertexLLM):
model_response: EmbeddingResponse, model_response: EmbeddingResponse,
custom_llm_provider: Literal["gemini", "vertex_ai"], custom_llm_provider: Literal["gemini", "vertex_ai"],
optional_params: dict, optional_params: dict,
logging_obj: Any,
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_base: Optional[str] = None, api_base: Optional[str] = None,
logging_obj=None,
encoding=None, encoding=None,
vertex_project=None, vertex_project=None,
vertex_location=None, vertex_location=None,

View file

@ -43,17 +43,17 @@ class VertexImageGeneration(VertexLLM):
vertex_location: Optional[str], vertex_location: Optional[str],
vertex_credentials: Optional[str], vertex_credentials: Optional[str],
model_response: litellm.ImageResponse, model_response: litellm.ImageResponse,
logging_obj: Any,
model: Optional[ model: Optional[
str str
] = "imagegeneration", # vertex ai uses imagegeneration as the default model ] = "imagegeneration", # vertex ai uses imagegeneration as the default model
client: Optional[Any] = None, client: Optional[Any] = None,
optional_params: Optional[dict] = None, optional_params: Optional[dict] = None,
timeout: Optional[int] = None, timeout: Optional[int] = None,
logging_obj=None,
aimg_generation=False, aimg_generation=False,
): ) -> litellm.ImageResponse:
if aimg_generation is True: if aimg_generation is True:
return self.aimage_generation( return self.aimage_generation( # type: ignore
prompt=prompt, prompt=prompt,
vertex_project=vertex_project, vertex_project=vertex_project,
vertex_location=vertex_location, vertex_location=vertex_location,
@ -138,13 +138,13 @@ class VertexImageGeneration(VertexLLM):
vertex_location: Optional[str], vertex_location: Optional[str],
vertex_credentials: Optional[str], vertex_credentials: Optional[str],
model_response: litellm.ImageResponse, model_response: litellm.ImageResponse,
logging_obj: Any,
model: Optional[ model: Optional[
str str
] = "imagegeneration", # vertex ai uses imagegeneration as the default model ] = "imagegeneration", # vertex ai uses imagegeneration as the default model
client: Optional[AsyncHTTPHandler] = None, client: Optional[AsyncHTTPHandler] = None,
optional_params: Optional[dict] = None, optional_params: Optional[dict] = None,
timeout: Optional[int] = None, timeout: Optional[int] = None,
logging_obj=None,
): ):
response = None response = None
if client is None: if client is None:

View file

@ -48,7 +48,7 @@ class VertexMultimodalEmbedding(VertexLLM):
aembedding=False, aembedding=False,
timeout=300, timeout=300,
client=None, client=None,
): ) -> litellm.EmbeddingResponse:
_auth_header, vertex_project = self._ensure_access_token( _auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials, credentials=vertex_credentials,
@ -121,7 +121,7 @@ class VertexMultimodalEmbedding(VertexLLM):
) )
if aembedding is True: if aembedding is True:
return self.async_multimodal_embedding( return self.async_multimodal_embedding( # type: ignore
model=model, model=model,
api_base=url, api_base=url,
data=request_data, data=request_data,

View file

@ -62,7 +62,7 @@ class VertexTextToSpeechAPI(VertexLLM):
_is_async: Optional[bool] = False, _is_async: Optional[bool] = False,
optional_params: Optional[dict] = None, optional_params: Optional[dict] = None,
kwargs: Optional[dict] = None, kwargs: Optional[dict] = None,
): ) -> HttpxBinaryResponseContent:
import base64 import base64
####### Authenticate with Vertex AI ######## ####### Authenticate with Vertex AI ########
@ -145,7 +145,7 @@ class VertexTextToSpeechAPI(VertexLLM):
########## End of logging ############ ########## End of logging ############
####### Send the request ################### ####### Send the request ###################
if _is_async is True: if _is_async is True:
return self.async_audio_speech( return self.async_audio_speech( # type:ignore
logging_obj=logging_obj, url=url, headers=headers, request=request logging_obj=logging_obj, url=url, headers=headers, request=request
) )
sync_handler = _get_httpx_client() sync_handler = _get_httpx_client()

View file

@ -52,9 +52,9 @@ class VertexEmbedding(VertexBase):
vertex_credentials: Optional[str] = None, vertex_credentials: Optional[str] = None,
gemini_api_key: Optional[str] = None, gemini_api_key: Optional[str] = None,
extra_headers: Optional[dict] = None, extra_headers: Optional[dict] = None,
): ) -> litellm.EmbeddingResponse:
if aembedding is True: if aembedding is True:
return self.async_embedding( return self.async_embedding( # type: ignore
model=model, model=model,
input=input, input=input,
logging_obj=logging_obj, logging_obj=logging_obj,

View file

@ -25,13 +25,8 @@ import requests # type: ignore
import litellm import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.utils import ( from litellm.secret_managers.main import get_secret_str
EmbeddingResponse, from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason
ModelResponse,
Usage,
get_secret,
map_finish_reason,
)
from .base import BaseLLM from .base import BaseLLM
from .prompt_templates import factory as ptf from .prompt_templates import factory as ptf
@ -184,7 +179,7 @@ class IBMWatsonXAIConfig:
] ]
def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict): def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict) -> str:
# handle anthropic prompts and amazon titan prompts # handle anthropic prompts and amazon titan prompts
if model in custom_prompt_dict: if model in custom_prompt_dict:
# check if the model has a registered custom prompt # check if the model has a registered custom prompt
@ -200,14 +195,10 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
eos_token=model_prompt_dict.get("eos_token", ""), eos_token=model_prompt_dict.get("eos_token", ""),
) )
return prompt return prompt
elif provider == "ibm":
prompt = ptf.prompt_factory(
model=model, messages=messages, custom_llm_provider="watsonx"
)
elif provider == "ibm-mistralai": elif provider == "ibm-mistralai":
prompt = ptf.mistral_instruct_pt(messages=messages) prompt = ptf.mistral_instruct_pt(messages=messages)
else: else:
prompt = ptf.prompt_factory( prompt: str = ptf.prompt_factory( # type: ignore
model=model, messages=messages, custom_llm_provider="watsonx" model=model, messages=messages, custom_llm_provider="watsonx"
) )
return prompt return prompt
@ -327,37 +318,37 @@ class IBMWatsonXAI(BaseLLM):
# Load auth variables from environment variables # Load auth variables from environment variables
if url is None: if url is None:
url = ( url = (
get_secret("WATSONX_API_BASE") # consistent with 'AZURE_API_BASE' get_secret_str("WATSONX_API_BASE") # consistent with 'AZURE_API_BASE'
or get_secret("WATSONX_URL") or get_secret_str("WATSONX_URL")
or get_secret("WX_URL") or get_secret_str("WX_URL")
or get_secret("WML_URL") or get_secret_str("WML_URL")
) )
if api_key is None: if api_key is None:
api_key = ( api_key = (
get_secret("WATSONX_APIKEY") get_secret_str("WATSONX_APIKEY")
or get_secret("WATSONX_API_KEY") or get_secret_str("WATSONX_API_KEY")
or get_secret("WX_API_KEY") or get_secret_str("WX_API_KEY")
) )
if token is None: if token is None:
token = get_secret("WATSONX_TOKEN") or get_secret("WX_TOKEN") token = get_secret_str("WATSONX_TOKEN") or get_secret_str("WX_TOKEN")
if project_id is None: if project_id is None:
project_id = ( project_id = (
get_secret("WATSONX_PROJECT_ID") get_secret_str("WATSONX_PROJECT_ID")
or get_secret("WX_PROJECT_ID") or get_secret_str("WX_PROJECT_ID")
or get_secret("PROJECT_ID") or get_secret_str("PROJECT_ID")
) )
if region_name is None: if region_name is None:
region_name = ( region_name = (
get_secret("WATSONX_REGION") get_secret_str("WATSONX_REGION")
or get_secret("WX_REGION") or get_secret_str("WX_REGION")
or get_secret("REGION") or get_secret_str("REGION")
) )
if space_id is None: if space_id is None:
space_id = ( space_id = (
get_secret("WATSONX_DEPLOYMENT_SPACE_ID") get_secret_str("WATSONX_DEPLOYMENT_SPACE_ID")
or get_secret("WATSONX_SPACE_ID") or get_secret_str("WATSONX_SPACE_ID")
or get_secret("WX_SPACE_ID") or get_secret_str("WX_SPACE_ID")
or get_secret("SPACE_ID") or get_secret_str("SPACE_ID")
) )
# credentials parsing # credentials parsing
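For context, every fallback chain above goes through get_secret_str (imported at the top of this file). A minimal standalone sketch of how one of those settings could be resolved, assuming get_secret_str takes the variable name and returns Optional[str]; the helper name is hypothetical and not part of this change:

from typing import Optional

from litellm.secret_managers.main import get_secret_str

def resolve_watsonx_url(explicit_url: Optional[str] = None) -> str:
    # same precedence as the credential loading above
    url = (
        explicit_url
        or get_secret_str("WATSONX_API_BASE")
        or get_secret_str("WATSONX_URL")
        or get_secret_str("WX_URL")
        or get_secret_str("WML_URL")
    )
    if url is None:
        raise ValueError("No watsonx URL found in arguments or environment")
    return url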
@ -446,8 +437,8 @@ class IBMWatsonXAI(BaseLLM):
model_response: ModelResponse, model_response: ModelResponse,
print_verbose: Callable, print_verbose: Callable,
encoding, encoding,
logging_obj, logging_obj: Any,
optional_params=None, optional_params: dict,
acompletion=None, acompletion=None,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
@ -592,13 +583,13 @@ class IBMWatsonXAI(BaseLLM):
model: str, model: str,
input: Union[list, str], input: Union[list, str],
model_response: litellm.EmbeddingResponse, model_response: litellm.EmbeddingResponse,
api_key: Optional[str] = None, api_key: Optional[str],
logging_obj=None, logging_obj: Any,
optional_params=None, optional_params: dict,
encoding=None, encoding=None,
print_verbose=None, print_verbose=None,
aembedding=None, aembedding=None,
): ) -> litellm.EmbeddingResponse:
""" """
Send a text embedding request to the IBM Watsonx.ai API. Send a text embedding request to the IBM Watsonx.ai API.
""" """
@ -657,7 +648,7 @@ class IBMWatsonXAI(BaseLLM):
try: try:
if aembedding is True: if aembedding is True:
return handle_aembedding(req_params) return handle_aembedding(req_params) # type: ignore
else: else:
return handle_embedding(req_params) return handle_embedding(req_params)
except WatsonXAIError as e: except WatsonXAIError as e:
@ -669,7 +660,7 @@ class IBMWatsonXAI(BaseLLM):
headers = {} headers = {}
headers["Content-Type"] = "application/x-www-form-urlencoded" headers["Content-Type"] = "application/x-www-form-urlencoded"
if api_key is None: if api_key is None:
api_key = get_secret("WX_API_KEY") or get_secret("WATSONX_API_KEY") api_key = get_secret_str("WX_API_KEY") or get_secret_str("WATSONX_API_KEY")
if api_key is None: if api_key is None:
raise ValueError("API key is required") raise ValueError("API key is required")
headers["Accept"] = "application/json" headers["Accept"] = "application/json"
@ -812,22 +803,29 @@ class RequestManager:
request_params["data"] = json.dumps(request_params.pop("json", {})) request_params["data"] = json.dumps(request_params.pop("json", {}))
method = request_params.pop("method") method = request_params.pop("method")
retries = 0 retries = 0
resp: Optional[httpx.Response] = None
while retries < 3: while retries < 3:
if method.upper() == "POST": if method.upper() == "POST":
resp = await self.async_handler.post(**request_params) resp = await self.async_handler.post(**request_params)
else: else:
resp = await self.async_handler.get(**request_params) resp = await self.async_handler.get(**request_params)
if resp.status_code in [429, 503, 504, 520]: if resp is not None and resp.status_code in [429, 503, 504, 520]:
# to handle rate limiting and service unavailable errors # to handle rate limiting and service unavailable errors
# see: ibm_watsonx_ai.foundation_models.inference.base_model_inference.BaseModelInference._send_inference_payload # see: ibm_watsonx_ai.foundation_models.inference.base_model_inference.BaseModelInference._send_inference_payload
await asyncio.sleep(2**retries) await asyncio.sleep(2**retries)
retries += 1 retries += 1
else: else:
break break
if resp is None:
raise WatsonXAIError(
status_code=500,
message="No response from the server",
)
if resp.is_error: if resp.is_error:
error_reason = getattr(resp, "reason", "")
raise WatsonXAIError( raise WatsonXAIError(
status_code=resp.status_code, status_code=resp.status_code,
message=f"Error {resp.status_code} ({resp.reason}): {resp.text}", message=f"Error {resp.status_code} ({error_reason}): {resp.text}",
) )
yield resp yield resp
# await async_handler.close() # await async_handler.close()
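As a standalone illustration of the retry behaviour added above (retry on 429/503/504/520 with a 2**retries sleep, then surface the last response), a minimal sketch against httpx directly; the URL and payload are placeholders, not part of this change:

import asyncio

import httpx

RETRYABLE_STATUS_CODES = {429, 503, 504, 520}

async def post_with_backoff(url: str, payload: dict, max_retries: int = 3) -> httpx.Response:
    async with httpx.AsyncClient() as client:
        resp = await client.post(url, json=payload)
        retries = 0
        while resp.status_code in RETRYABLE_STATUS_CODES and retries < max_retries:
            await asyncio.sleep(2**retries)  # exponential backoff, as in RequestManager above
            retries += 1
            resp = await client.post(url, json=payload)
        return resp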


@ -19,7 +19,8 @@ import threading
import time import time
import traceback import traceback
import uuid import uuid
from concurrent.futures import ThreadPoolExecutor from concurrent import futures
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from copy import deepcopy from copy import deepcopy
from functools import partial from functools import partial
from typing import Any, Callable, Dict, List, Literal, Mapping, Optional, Type, Union from typing import Any, Callable, Dict, List, Literal, Mapping, Optional, Type, Union
@ -647,7 +648,7 @@ def mock_completion(
@client @client
def completion( def completion( # type: ignore
model: str, model: str,
# Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
messages: List = [], messages: List = [],
@ -2940,16 +2941,16 @@ def completion_with_retries(*args, **kwargs):
num_retries = kwargs.pop("num_retries", 3) num_retries = kwargs.pop("num_retries", 3)
retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop("retry_strategy", "constant_retry") # type: ignore retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop("retry_strategy", "constant_retry") # type: ignore
original_function = kwargs.pop("original_function", completion) original_function = kwargs.pop("original_function", completion)
if retry_strategy == "constant_retry": if retry_strategy == "exponential_backoff_retry":
retryer = tenacity.Retrying(
stop=tenacity.stop_after_attempt(num_retries), reraise=True
)
elif retry_strategy == "exponential_backoff_retry":
retryer = tenacity.Retrying( retryer = tenacity.Retrying(
wait=tenacity.wait_exponential(multiplier=1, max=10), wait=tenacity.wait_exponential(multiplier=1, max=10),
stop=tenacity.stop_after_attempt(num_retries), stop=tenacity.stop_after_attempt(num_retries),
reraise=True, reraise=True,
) )
else:
retryer = tenacity.Retrying(
stop=tenacity.stop_after_attempt(num_retries), reraise=True
)
return retryer(original_function, *args, **kwargs) return retryer(original_function, *args, **kwargs)
@ -2968,16 +2969,16 @@ async def acompletion_with_retries(*args, **kwargs):
num_retries = kwargs.pop("num_retries", 3) num_retries = kwargs.pop("num_retries", 3)
retry_strategy = kwargs.pop("retry_strategy", "constant_retry") retry_strategy = kwargs.pop("retry_strategy", "constant_retry")
original_function = kwargs.pop("original_function", completion) original_function = kwargs.pop("original_function", completion)
if retry_strategy == "constant_retry": if retry_strategy == "exponential_backoff_retry":
retryer = tenacity.Retrying(
stop=tenacity.stop_after_attempt(num_retries), reraise=True
)
elif retry_strategy == "exponential_backoff_retry":
retryer = tenacity.Retrying( retryer = tenacity.Retrying(
wait=tenacity.wait_exponential(multiplier=1, max=10), wait=tenacity.wait_exponential(multiplier=1, max=10),
stop=tenacity.stop_after_attempt(num_retries), stop=tenacity.stop_after_attempt(num_retries),
reraise=True, reraise=True,
) )
else:
retryer = tenacity.Retrying(
stop=tenacity.stop_after_attempt(num_retries), reraise=True
)
return await retryer(original_function, *args, **kwargs) return await retryer(original_function, *args, **kwargs)
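A usage sketch for the reordered retry helpers above, assuming completion_with_retries is exported at the package top level like completion; the model name and message are placeholders:

import litellm

# anything other than "exponential_backoff_retry" now falls back to constant retries
response = litellm.completion_with_retries(
    model="gpt-3.5-turbo",  # placeholder
    messages=[{"role": "user", "content": "hello"}],
    num_retries=3,
    retry_strategy="exponential_backoff_retry",
)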
@ -3045,7 +3046,7 @@ def batch_completion(
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
n=n, n=n,
stream=stream, stream=stream or False,
stop=stop, stop=stop,
max_tokens=max_tokens, max_tokens=max_tokens,
presence_penalty=presence_penalty, presence_penalty=presence_penalty,
@ -3124,7 +3125,7 @@ def batch_completion_models(*args, **kwargs):
models = kwargs["models"] models = kwargs["models"]
kwargs.pop("models") kwargs.pop("models")
futures = {} futures = {}
with concurrent.futures.ThreadPoolExecutor(max_workers=len(models)) as executor: with ThreadPoolExecutor(max_workers=len(models)) as executor:
for model in models: for model in models:
futures[model] = executor.submit( futures[model] = executor.submit(
completion, *args, model=model, **kwargs completion, *args, model=model, **kwargs
@ -3141,9 +3142,7 @@ def batch_completion_models(*args, **kwargs):
kwargs.pop("model_list") kwargs.pop("model_list")
nested_kwargs = kwargs.pop("kwargs", {}) nested_kwargs = kwargs.pop("kwargs", {})
futures = {} futures = {}
with concurrent.futures.ThreadPoolExecutor( with ThreadPoolExecutor(max_workers=len(deployments)) as executor:
max_workers=len(deployments)
) as executor:
for deployment in deployments: for deployment in deployments:
for key in kwargs.keys(): for key in kwargs.keys():
if ( if (
@ -3156,9 +3155,7 @@ def batch_completion_models(*args, **kwargs):
while futures: while futures:
# wait for the first returned future # wait for the first returned future
print_verbose("\n\n waiting for next result\n\n") print_verbose("\n\n waiting for next result\n\n")
done, _ = concurrent.futures.wait( done, _ = wait(futures.values(), return_when=FIRST_COMPLETED)
futures.values(), return_when=concurrent.futures.FIRST_COMPLETED
)
print_verbose(f"done list\n{done}") print_verbose(f"done list\n{done}")
for future in done: for future in done:
try: try:
@ -3214,6 +3211,8 @@ def batch_completion_models_all_responses(*args, **kwargs):
if "models" in kwargs: if "models" in kwargs:
models = kwargs["models"] models = kwargs["models"]
kwargs.pop("models") kwargs.pop("models")
else:
raise Exception("'models' param not in kwargs")
responses = [] responses = []
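The module-level imports swapped in above (ThreadPoolExecutor, wait, FIRST_COMPLETED) are the standard concurrent.futures API; a minimal standalone sketch of the same "take whichever call answers first" pattern, with a hypothetical helper name:

from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait

def first_result(fns):
    # fns: zero-argument callables; return the result of whichever finishes first
    with ThreadPoolExecutor(max_workers=len(fns)) as executor:
        futures = [executor.submit(fn) for fn in fns]
        done, _ = wait(futures, return_when=FIRST_COMPLETED)
        return next(iter(done)).result()  # the executor still waits for the rest on exit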
@ -3256,6 +3255,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
model=model, api_base=kwargs.get("api_base", None) model=model, api_base=kwargs.get("api_base", None)
) )
response: Optional[EmbeddingResponse] = None
if ( if (
custom_llm_provider == "openai" custom_llm_provider == "openai"
or custom_llm_provider == "azure" or custom_llm_provider == "azure"
@ -3294,12 +3294,21 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
elif isinstance(init_response, EmbeddingResponse): ## CACHING SCENARIO elif isinstance(init_response, EmbeddingResponse): ## CACHING SCENARIO
response = init_response response = init_response
elif asyncio.iscoroutine(init_response): elif asyncio.iscoroutine(init_response):
response = await init_response response = await init_response # type: ignore
else: else:
# Call the synchronous function using run_in_executor # Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context) response = await loop.run_in_executor(None, func_with_context)
if response is not None and hasattr(response, "_hidden_params"): if (
response is not None
and isinstance(response, EmbeddingResponse)
and hasattr(response, "_hidden_params")
):
response._hidden_params["custom_llm_provider"] = custom_llm_provider response._hidden_params["custom_llm_provider"] = custom_llm_provider
if response is None:
raise ValueError(
"Unable to get Embedding Response. Please pass a valid llm_provider."
)
return response return response
except Exception as e: except Exception as e:
custom_llm_provider = custom_llm_provider or "openai" custom_llm_provider = custom_llm_provider or "openai"
@ -3329,7 +3338,6 @@ def embedding(
user: Optional[str] = None, user: Optional[str] = None,
custom_llm_provider=None, custom_llm_provider=None,
litellm_call_id=None, litellm_call_id=None,
litellm_logging_obj=None,
logger_fn=None, logger_fn=None,
**kwargs, **kwargs,
) -> EmbeddingResponse: ) -> EmbeddingResponse:
@ -3362,6 +3370,7 @@ def embedding(
client = kwargs.pop("client", None) client = kwargs.pop("client", None)
rpm = kwargs.pop("rpm", None) rpm = kwargs.pop("rpm", None)
tpm = kwargs.pop("tpm", None) tpm = kwargs.pop("tpm", None)
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
cooldown_time = kwargs.get("cooldown_time", None) cooldown_time = kwargs.get("cooldown_time", None)
max_parallel_requests = kwargs.pop("max_parallel_requests", None) max_parallel_requests = kwargs.pop("max_parallel_requests", None)
model_info = kwargs.get("model_info", None) model_info = kwargs.get("model_info", None)
@ -3491,7 +3500,7 @@ def embedding(
} }
) )
try: try:
response = None response: Optional[EmbeddingResponse] = None
logging: Logging = litellm_logging_obj # type: ignore logging: Logging = litellm_logging_obj # type: ignore
logging.update_environment_variables( logging.update_environment_variables(
model=model, model=model,
@ -3691,7 +3700,7 @@ def embedding(
raise ValueError( raise ValueError(
"api_base is required for triton. Please pass `api_base`" "api_base is required for triton. Please pass `api_base`"
) )
response = triton_chat_completions.embedding( response = triton_chat_completions.embedding( # type: ignore
model=model, model=model,
input=input, input=input,
api_base=api_base, api_base=api_base,
@ -3783,6 +3792,7 @@ def embedding(
timeout=timeout, timeout=timeout,
aembedding=aembedding, aembedding=aembedding,
print_verbose=print_verbose, print_verbose=print_verbose,
api_key=api_key,
) )
elif custom_llm_provider == "oobabooga": elif custom_llm_provider == "oobabooga":
response = oobabooga.embedding( response = oobabooga.embedding(
@ -3793,14 +3803,16 @@ def embedding(
logging_obj=logging, logging_obj=logging,
optional_params=optional_params, optional_params=optional_params,
model_response=EmbeddingResponse(), model_response=EmbeddingResponse(),
api_key=api_key,
) )
elif custom_llm_provider == "ollama": elif custom_llm_provider == "ollama":
api_base = ( api_base = (
litellm.api_base litellm.api_base
or api_base or api_base
or get_secret("OLLAMA_API_BASE") or get_secret_str("OLLAMA_API_BASE")
or "http://localhost:11434" or "http://localhost:11434"
) # type: ignore ) # type: ignore
if isinstance(input, str): if isinstance(input, str):
input = [input] input = [input]
if not all(isinstance(item, str) for item in input): if not all(isinstance(item, str) for item in input):
@ -3881,13 +3893,13 @@ def embedding(
api_key = ( api_key = (
api_key api_key
or litellm.api_key or litellm.api_key
or get_secret("XINFERENCE_API_KEY") or get_secret_str("XINFERENCE_API_KEY")
or "stub-xinference-key" or "stub-xinference-key"
) # xinference does not need an api key, pass a stub key if user did not set one ) # xinference does not need an api key, pass a stub key if user did not set one
api_base = ( api_base = (
api_base api_base
or litellm.api_base or litellm.api_base
or get_secret("XINFERENCE_API_BASE") or get_secret_str("XINFERENCE_API_BASE")
or "http://127.0.0.1:9997/v1" or "http://127.0.0.1:9997/v1"
) )
response = openai_chat_completions.embedding( response = openai_chat_completions.embedding(
@ -3911,19 +3923,20 @@ def embedding(
optional_params=optional_params, optional_params=optional_params,
model_response=EmbeddingResponse(), model_response=EmbeddingResponse(),
aembedding=aembedding, aembedding=aembedding,
api_key=api_key,
) )
elif custom_llm_provider == "azure_ai": elif custom_llm_provider == "azure_ai":
api_base = ( api_base = (
api_base # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there api_base # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
or litellm.api_base or litellm.api_base
or get_secret("AZURE_AI_API_BASE") or get_secret_str("AZURE_AI_API_BASE")
) )
# set API KEY # set API KEY
api_key = ( api_key = (
api_key api_key
or litellm.api_key # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there or litellm.api_key # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there
or litellm.openai_key or litellm.openai_key
or get_secret("AZURE_AI_API_KEY") or get_secret_str("AZURE_AI_API_KEY")
) )
## EMBEDDING CALL ## EMBEDDING CALL
@ -3944,10 +3957,14 @@ def embedding(
raise ValueError(f"No valid embedding model args passed in - {args}") raise ValueError(f"No valid embedding model args passed in - {args}")
if response is not None and hasattr(response, "_hidden_params"): if response is not None and hasattr(response, "_hidden_params"):
response._hidden_params["custom_llm_provider"] = custom_llm_provider response._hidden_params["custom_llm_provider"] = custom_llm_provider
if response is None:
args = locals()
raise ValueError(f"No valid embedding model args passed in - {args}")
return response return response
except Exception as e: except Exception as e:
## LOGGING ## LOGGING
logging.post_call( litellm_logging_obj.post_call(
input=input, input=input,
api_key=api_key, api_key=api_key,
original_response=str(e), original_response=str(e),
@ -4018,7 +4035,11 @@ async def atext_completion(
else: else:
# Call the synchronous function using run_in_executor # Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context) response = await loop.run_in_executor(None, func_with_context)
if kwargs.get("stream", False) is True: # return an async generator if (
kwargs.get("stream", False) is True
or isinstance(response, TextCompletionStreamWrapper)
or isinstance(response, CustomStreamWrapper)
): # return an async generator
return TextCompletionStreamWrapper( return TextCompletionStreamWrapper(
completion_stream=_async_streaming( completion_stream=_async_streaming(
response=response, response=response,
@ -4153,9 +4174,10 @@ def text_completion(
Your example of how to use this function goes here. Your example of how to use this function goes here.
""" """
if "engine" in kwargs: if "engine" in kwargs:
if model is None: _engine = kwargs["engine"]
if model is None and isinstance(_engine, str):
# only use engine when model not passed # only use engine when model not passed
model = kwargs["engine"] model = _engine
kwargs.pop("engine") kwargs.pop("engine")
text_completion_response = TextCompletionResponse() text_completion_response = TextCompletionResponse()
@ -4223,7 +4245,7 @@ def text_completion(
def process_prompt(i, individual_prompt): def process_prompt(i, individual_prompt):
decoded_prompt = tokenizer.decode(individual_prompt) decoded_prompt = tokenizer.decode(individual_prompt)
all_params = {**kwargs, **optional_params} all_params = {**kwargs, **optional_params}
response = text_completion( response: TextCompletionResponse = text_completion( # type: ignore
model=model, model=model,
prompt=decoded_prompt, prompt=decoded_prompt,
num_retries=3, # ensure this does not fail for the batch num_retries=3, # ensure this does not fail for the batch
@ -4292,6 +4314,8 @@ def text_completion(
model = "text-completion-openai/" + _model model = "text-completion-openai/" + _model
optional_params.pop("custom_llm_provider", None) optional_params.pop("custom_llm_provider", None)
if model is None:
raise ValueError("model is not set. Set either via 'model' or 'engine' param.")
kwargs["text_completion"] = True kwargs["text_completion"] = True
response = completion( response = completion(
model=model, model=model,
@ -4302,7 +4326,11 @@ def text_completion(
) )
if kwargs.get("acompletion", False) is True: if kwargs.get("acompletion", False) is True:
return response return response
if stream is True or kwargs.get("stream", False) is True: if (
stream is True
or kwargs.get("stream", False) is True
or isinstance(response, CustomStreamWrapper)
):
response = TextCompletionStreamWrapper( response = TextCompletionStreamWrapper(
completion_stream=response, completion_stream=response,
model=model, model=model,
@ -4310,6 +4338,8 @@ def text_completion(
custom_llm_provider=custom_llm_provider, custom_llm_provider=custom_llm_provider,
) )
return response return response
elif isinstance(response, TextCompletionStreamWrapper):
return response
transformed_logprobs = None transformed_logprobs = None
# only supported for TGI models # only supported for TGI models
try: try:
@ -4424,7 +4454,10 @@ def moderation(
): ):
# only supports open ai for now # only supports open ai for now
api_key = ( api_key = (
api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY") api_key
or litellm.api_key
or litellm.openai_key
or get_secret_str("OPENAI_API_KEY")
) )
openai_client = kwargs.get("client", None) openai_client = kwargs.get("client", None)
@ -4433,7 +4466,10 @@ def moderation(
api_key=api_key, api_key=api_key,
) )
if model is not None:
response = openai_client.moderations.create(input=input, model=model) response = openai_client.moderations.create(input=input, model=model)
else:
response = openai_client.moderations.create(input=input)
return response return response
@ -4441,20 +4477,30 @@ def moderation(
async def amoderation( async def amoderation(
input: str, model: Optional[str] = None, api_key: Optional[str] = None, **kwargs input: str, model: Optional[str] = None, api_key: Optional[str] = None, **kwargs
): ):
from openai import AsyncOpenAI
# only supports open ai for now # only supports open ai for now
api_key = ( api_key = (
api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY") api_key
or litellm.api_key
or litellm.openai_key
or get_secret_str("OPENAI_API_KEY")
) )
openai_client = kwargs.get("client", None) openai_client = kwargs.get("client", None)
if openai_client is None: if openai_client is None or not isinstance(openai_client, AsyncOpenAI):
# call helper to get OpenAI client # call helper to get OpenAI client
# _get_openai_client maintains in-memory caching logic for OpenAI clients # _get_openai_client maintains in-memory caching logic for OpenAI clients
openai_client = openai_chat_completions._get_openai_client( _openai_client: AsyncOpenAI = openai_chat_completions._get_openai_client( # type: ignore
is_async=True, is_async=True,
api_key=api_key, api_key=api_key,
) )
else:
_openai_client = openai_client
if model is not None:
response = await openai_client.moderations.create(input=input, model=model) response = await openai_client.moderations.create(input=input, model=model)
else:
response = await openai_client.moderations.create(input=input)
return response return response
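A usage sketch for the updated moderation helpers, which now only forward model when one is supplied and otherwise let OpenAI pick its default; this assumes both helpers are exported at the package top level and that OPENAI_API_KEY is configured, and the input text is a placeholder:

import asyncio

import litellm

report = litellm.moderation(input="some user supplied text")                 # sync, default model
report = asyncio.run(litellm.amoderation(input="some user supplied text"))   # async variant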
@ -4497,7 +4543,7 @@ async def aimage_generation(*args, **kwargs) -> ImageResponse:
init_response = ImageResponse(**init_response) init_response = ImageResponse(**init_response)
response = init_response response = init_response
elif asyncio.iscoroutine(init_response): elif asyncio.iscoroutine(init_response):
response = await init_response response = await init_response # type: ignore
else: else:
# Call the synchronous function using run_in_executor # Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context) response = await loop.run_in_executor(None, func_with_context)
@ -4527,7 +4573,6 @@ def image_generation(
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_base: Optional[str] = None, api_base: Optional[str] = None,
api_version: Optional[str] = None, api_version: Optional[str] = None,
litellm_logging_obj=None,
custom_llm_provider=None, custom_llm_provider=None,
**kwargs, **kwargs,
) -> ImageResponse: ) -> ImageResponse:
@ -4543,9 +4588,10 @@ def image_generation(
proxy_server_request = kwargs.get("proxy_server_request", None) proxy_server_request = kwargs.get("proxy_server_request", None)
model_info = kwargs.get("model_info", None) model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", {}) metadata = kwargs.get("metadata", {})
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
client = kwargs.get("client", None) client = kwargs.get("client", None)
model_response = litellm.utils.ImageResponse() model_response: ImageResponse = litellm.utils.ImageResponse()
if model is not None or custom_llm_provider is not None: if model is not None or custom_llm_provider is not None:
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
else: else:
@ -4651,25 +4697,27 @@ def image_generation(
if custom_llm_provider == "azure": if custom_llm_provider == "azure":
# azure configs # azure configs
api_type = get_secret("AZURE_API_TYPE") or "azure" api_type = get_secret_str("AZURE_API_TYPE") or "azure"
api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")
api_version = ( api_version = (
api_version or litellm.api_version or get_secret("AZURE_API_VERSION") api_version
or litellm.api_version
or get_secret_str("AZURE_API_VERSION")
) )
api_key = ( api_key = (
api_key api_key
or litellm.api_key or litellm.api_key
or litellm.azure_key or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY") or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY") or get_secret_str("AZURE_API_KEY")
) )
azure_ad_token = optional_params.pop("azure_ad_token", None) or get_secret( azure_ad_token = optional_params.pop(
"AZURE_AD_TOKEN" "azure_ad_token", None
) ) or get_secret_str("AZURE_AD_TOKEN")
model_response = azure_chat_completions.image_generation( model_response = azure_chat_completions.image_generation(
model=model, model=model,
@ -4714,18 +4762,18 @@ def image_generation(
optional_params.pop("vertex_project", None) optional_params.pop("vertex_project", None)
or optional_params.pop("vertex_ai_project", None) or optional_params.pop("vertex_ai_project", None)
or litellm.vertex_project or litellm.vertex_project
or get_secret("VERTEXAI_PROJECT") or get_secret_str("VERTEXAI_PROJECT")
) )
vertex_ai_location = ( vertex_ai_location = (
optional_params.pop("vertex_location", None) optional_params.pop("vertex_location", None)
or optional_params.pop("vertex_ai_location", None) or optional_params.pop("vertex_ai_location", None)
or litellm.vertex_location or litellm.vertex_location
or get_secret("VERTEXAI_LOCATION") or get_secret_str("VERTEXAI_LOCATION")
) )
vertex_credentials = ( vertex_credentials = (
optional_params.pop("vertex_credentials", None) optional_params.pop("vertex_credentials", None)
or optional_params.pop("vertex_ai_credentials", None) or optional_params.pop("vertex_ai_credentials", None)
or get_secret("VERTEXAI_CREDENTIALS") or get_secret_str("VERTEXAI_CREDENTIALS")
) )
model_response = vertex_image_generation.image_generation( model_response = vertex_image_generation.image_generation(
model=model, model=model,
@ -4786,7 +4834,7 @@ async def atranscription(*args, **kwargs) -> TranscriptionResponse:
elif isinstance(init_response, TranscriptionResponse): ## CACHING SCENARIO elif isinstance(init_response, TranscriptionResponse): ## CACHING SCENARIO
response = init_response response = init_response
elif asyncio.iscoroutine(init_response): elif asyncio.iscoroutine(init_response):
response = await init_response response = await init_response # type: ignore
else: else:
# Call the synchronous function using run_in_executor # Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context) response = await loop.run_in_executor(None, func_with_context)
@ -4820,7 +4868,6 @@ def transcription(
api_base: Optional[str] = None, api_base: Optional[str] = None,
api_version: Optional[str] = None, api_version: Optional[str] = None,
max_retries: Optional[int] = None, max_retries: Optional[int] = None,
litellm_logging_obj: Optional[LiteLLMLoggingObj] = None,
custom_llm_provider=None, custom_llm_provider=None,
**kwargs, **kwargs,
) -> TranscriptionResponse: ) -> TranscriptionResponse:
@ -4830,6 +4877,7 @@ def transcription(
Allows router to load balance between them Allows router to load balance between them
""" """
atranscription = kwargs.get("atranscription", False) atranscription = kwargs.get("atranscription", False)
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
kwargs.get("litellm_call_id", None) kwargs.get("litellm_call_id", None)
kwargs.get("logger_fn", None) kwargs.get("logger_fn", None)
kwargs.get("proxy_server_request", None) kwargs.get("proxy_server_request", None)
@ -4869,22 +4917,17 @@ def transcription(
custom_llm_provider=custom_llm_provider, custom_llm_provider=custom_llm_provider,
drop_params=drop_params, drop_params=drop_params,
) )
# optional_params = {
# "language": language,
# "prompt": prompt,
# "response_format": response_format,
# "temperature": None, # openai defaults this to 0
# }
response: Optional[TranscriptionResponse] = None
if custom_llm_provider == "azure": if custom_llm_provider == "azure":
# azure configs # azure configs
api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")
api_version = ( api_version = (
api_version or litellm.api_version or get_secret("AZURE_API_VERSION") api_version or litellm.api_version or get_secret_str("AZURE_API_VERSION")
) )
azure_ad_token = kwargs.pop("azure_ad_token", None) or get_secret( azure_ad_token = kwargs.pop("azure_ad_token", None) or get_secret_str(
"AZURE_AD_TOKEN" "AZURE_AD_TOKEN"
) )
@ -4892,8 +4935,8 @@ def transcription(
api_key api_key
or litellm.api_key or litellm.api_key
or litellm.azure_key or litellm.azure_key
or get_secret("AZURE_API_KEY") or get_secret_str("AZURE_API_KEY")
) # type: ignore )
response = azure_audio_transcriptions.audio_transcriptions( response = azure_audio_transcriptions.audio_transcriptions(
model=model, model=model,
@ -4942,6 +4985,9 @@ def transcription(
api_base=api_base, api_base=api_base,
api_key=api_key, api_key=api_key,
) )
if response is None:
raise ValueError("Unmapped provider passed in. Unable to get the response.")
return response return response
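A usage sketch for transcription after this change, where litellm_logging_obj is now read from kwargs instead of being a named parameter; the file handling, model name, and result attribute are assumptions for illustration only:

import litellm

with open("sample.mp3", "rb") as audio_file:  # placeholder audio file
    result = litellm.transcription(model="whisper-1", file=audio_file)
print(result.text)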
@ -5149,15 +5195,16 @@ def speech(
vertex_ai_project = ( vertex_ai_project = (
generic_optional_params.vertex_project generic_optional_params.vertex_project
or litellm.vertex_project or litellm.vertex_project
or get_secret("VERTEXAI_PROJECT") or get_secret_str("VERTEXAI_PROJECT")
) )
vertex_ai_location = ( vertex_ai_location = (
generic_optional_params.vertex_location generic_optional_params.vertex_location
or litellm.vertex_location or litellm.vertex_location
or get_secret("VERTEXAI_LOCATION") or get_secret_str("VERTEXAI_LOCATION")
) )
vertex_credentials = generic_optional_params.vertex_credentials or get_secret( vertex_credentials = (
"VERTEXAI_CREDENTIALS" generic_optional_params.vertex_credentials
or get_secret_str("VERTEXAI_CREDENTIALS")
) )
if voice is not None and not isinstance(voice, dict): if voice is not None and not isinstance(voice, dict):
@ -5234,20 +5281,25 @@ async def ahealth_check(
if custom_llm_provider == "azure": if custom_llm_provider == "azure":
api_key = ( api_key = (
model_params.get("api_key") model_params.get("api_key")
or get_secret("AZURE_API_KEY") or get_secret_str("AZURE_API_KEY")
or get_secret("AZURE_OPENAI_API_KEY") or get_secret_str("AZURE_OPENAI_API_KEY")
) )
api_base = ( api_base: Optional[str] = (
model_params.get("api_base") model_params.get("api_base")
or get_secret("AZURE_API_BASE") or get_secret_str("AZURE_API_BASE")
or get_secret("AZURE_OPENAI_API_BASE") or get_secret_str("AZURE_OPENAI_API_BASE")
)
if api_base is None:
raise ValueError(
"Azure API Base cannot be None. Set via 'AZURE_API_BASE' in env var or `.completion(..., api_base=..)`"
) )
api_version = ( api_version = (
model_params.get("api_version") model_params.get("api_version")
or get_secret("AZURE_API_VERSION") or get_secret_str("AZURE_API_VERSION")
or get_secret("AZURE_OPENAI_API_VERSION") or get_secret_str("AZURE_OPENAI_API_VERSION")
) )
timeout = ( timeout = (
@ -5273,7 +5325,7 @@ async def ahealth_check(
custom_llm_provider == "openai" custom_llm_provider == "openai"
or custom_llm_provider == "text-completion-openai" or custom_llm_provider == "text-completion-openai"
): ):
api_key = model_params.get("api_key") or get_secret("OPENAI_API_KEY") api_key = model_params.get("api_key") or get_secret_str("OPENAI_API_KEY")
organization = model_params.get("organization") organization = model_params.get("organization")
timeout = ( timeout = (
@ -5282,7 +5334,7 @@ async def ahealth_check(
or default_timeout or default_timeout
) )
api_base = model_params.get("api_base") or get_secret("OPENAI_API_BASE") api_base = model_params.get("api_base") or get_secret_str("OPENAI_API_BASE")
if custom_llm_provider == "text-completion-openai": if custom_llm_provider == "text-completion-openai":
mode = "completion" mode = "completion"
@ -5377,7 +5429,9 @@ def config_completion(**kwargs):
) )
def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List] = None): def stream_chunk_builder_text_completion(
chunks: list, messages: Optional[List] = None
) -> TextCompletionResponse:
id = chunks[0]["id"] id = chunks[0]["id"]
object = chunks[0]["object"] object = chunks[0]["object"]
created = chunks[0]["created"] created = chunks[0]["created"]
@ -5446,7 +5500,7 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]
response["usage"]["total_tokens"] = ( response["usage"]["total_tokens"] = (
response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"] response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
) )
return response return TextCompletionResponse(**response)
def stream_chunk_builder( def stream_chunk_builder(


@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict, Extra, Field, Json, model_validator from pydantic import BaseModel, ConfigDict, Extra, Field, Json, model_validator
from typing_extensions import Annotated, TypedDict from typing_extensions import Annotated, TypedDict
from litellm.integrations.SlackAlerting.types import AlertType from litellm.types.integrations.slack_alerting import AlertType
from litellm.types.router import RouterErrors, UpdateRouterConfig from litellm.types.router import RouterErrors, UpdateRouterConfig
from litellm.types.utils import ProviderField, StandardCallbackDynamicParams from litellm.types.utils import ProviderField, StandardCallbackDynamicParams

View file

@ -1,12 +1,13 @@
from typing import Optional
from fastapi import Depends, Request, APIRouter
from fastapi import HTTPException
import copy import copy
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request
import litellm import litellm
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.caching import RedisCache
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
router = APIRouter( router = APIRouter(
prefix="/cache", prefix="/cache",
tags=["caching"], tags=["caching"],
@ -21,9 +22,9 @@ async def cache_ping():
""" """
Endpoint for checking if cache can be pinged Endpoint for checking if cache can be pinged
""" """
try:
litellm_cache_params = {} litellm_cache_params = {}
specific_cache_params = {} specific_cache_params = {}
try:
if litellm.cache is None: if litellm.cache is None:
raise HTTPException( raise HTTPException(
@ -135,7 +136,9 @@ async def cache_redis_info():
raise HTTPException( raise HTTPException(
status_code=503, detail="Cache not initialized. litellm.cache is None" status_code=503, detail="Cache not initialized. litellm.cache is None"
) )
if litellm.cache.type == "redis": if litellm.cache.type == "redis" and isinstance(
litellm.cache.cache, RedisCache
):
client_list = litellm.cache.cache.client_list() client_list = litellm.cache.cache.client_list()
redis_info = litellm.cache.cache.info() redis_info = litellm.cache.cache.info()
num_clients = len(client_list) num_clients = len(client_list)
@ -177,7 +180,9 @@ async def cache_flushall():
raise HTTPException( raise HTTPException(
status_code=503, detail="Cache not initialized. litellm.cache is None" status_code=503, detail="Cache not initialized. litellm.cache is None"
) )
if litellm.cache.type == "redis": if litellm.cache.type == "redis" and isinstance(
litellm.cache.cache, RedisCache
):
litellm.cache.cache.flushall() litellm.cache.cache.flushall()
return { return {
"status": "success", "status": "success",


@ -118,13 +118,13 @@ async def create_fine_tuning_job(
version, version,
) )
data = fine_tuning_request.model_dump(exclude_none=True)
try: try:
if premium_user is not True: if premium_user is not True:
raise ValueError( raise ValueError(
f"Only premium users can use this endpoint + {CommonProxyErrors.not_premium_user.value}" f"Only premium users can use this endpoint + {CommonProxyErrors.not_premium_user.value}"
) )
# Convert Pydantic model to dict # Convert Pydantic model to dict
data = fine_tuning_request.model_dump(exclude_none=True)
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
"Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)), "Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)),
@ -146,6 +146,7 @@ async def create_fine_tuning_job(
) )
# add llm_provider_config to data # add llm_provider_config to data
if llm_provider_config is not None:
data.update(llm_provider_config) data.update(llm_provider_config)
response = await litellm.acreate_fine_tuning_job(**data) response = await litellm.acreate_fine_tuning_job(**data)
@ -262,6 +263,7 @@ async def list_fine_tuning_jobs(
custom_llm_provider=custom_llm_provider custom_llm_provider=custom_llm_provider
) )
if llm_provider_config is not None:
data.update(llm_provider_config) data.update(llm_provider_config)
response = await litellm.alist_fine_tuning_jobs( response = await litellm.alist_fine_tuning_jobs(
@ -378,6 +380,7 @@ async def retrieve_fine_tuning_job(
custom_llm_provider=custom_llm_provider custom_llm_provider=custom_llm_provider
) )
if llm_provider_config is not None:
data.update(llm_provider_config) data.update(llm_provider_config)
response = await litellm.acancel_fine_tuning_job( response = await litellm.acancel_fine_tuning_job(


@ -227,8 +227,8 @@ async def send_management_endpoint_alert(
- An internal user is created, updated, or deleted - An internal user is created, updated, or deleted
- A team is created, updated, or deleted - A team is created, updated, or deleted
""" """
from litellm.integrations.SlackAlerting.types import AlertType
from litellm.proxy.proxy_server import premium_user, proxy_logging_obj from litellm.proxy.proxy_server import premium_user, proxy_logging_obj
from litellm.types.integrations.slack_alerting import AlertType
if premium_user is not True: if premium_user is not True:
return return


@ -114,10 +114,7 @@ from litellm import (
from litellm._logging import verbose_proxy_logger, verbose_router_logger from litellm._logging import verbose_proxy_logger, verbose_router_logger
from litellm.caching import DualCache, RedisCache from litellm.caching import DualCache, RedisCache
from litellm.exceptions import RejectedRequestError from litellm.exceptions import RejectedRequestError
from litellm.integrations.SlackAlerting.slack_alerting import ( from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
SlackAlerting,
SlackAlertingArgs,
)
from litellm.litellm_core_utils.core_helpers import ( from litellm.litellm_core_utils.core_helpers import (
_get_parent_otel_span_from_kwargs, _get_parent_otel_span_from_kwargs,
get_litellm_metadata_from_kwargs, get_litellm_metadata_from_kwargs,
@ -249,6 +246,7 @@ from litellm.secret_managers.aws_secret_manager import (
) )
from litellm.secret_managers.google_kms import load_google_kms from litellm.secret_managers.google_kms import load_google_kms
from litellm.secret_managers.main import get_secret, get_secret_str, str_to_bool from litellm.secret_managers.main import get_secret, get_secret_str, str_to_bool
from litellm.types.integrations.slack_alerting import SlackAlertingArgs
from litellm.types.llms.anthropic import ( from litellm.types.llms.anthropic import (
AnthropicMessagesRequest, AnthropicMessagesRequest,
AnthropicResponse, AnthropicResponse,


@ -54,7 +54,6 @@ from litellm.exceptions import RejectedRequestError
from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
from litellm.integrations.SlackAlerting.types import DEFAULT_ALERT_TYPES
from litellm.integrations.SlackAlerting.utils import _add_langfuse_trace_id_to_alert from litellm.integrations.SlackAlerting.utils import _add_langfuse_trace_id_to_alert
from litellm.litellm_core_utils.core_helpers import ( from litellm.litellm_core_utils.core_helpers import (
_get_parent_otel_span_from_kwargs, _get_parent_otel_span_from_kwargs,
@ -85,6 +84,7 @@ from litellm.proxy.hooks.parallel_request_limiter import (
_PROXY_MaxParallelRequestsHandler, _PROXY_MaxParallelRequestsHandler,
) )
from litellm.secret_managers.main import str_to_bool from litellm.secret_managers.main import str_to_bool
from litellm.types.integrations.slack_alerting import DEFAULT_ALERT_TYPES
from litellm.types.utils import CallTypes, LoggedLiteLLMParams from litellm.types.utils import CallTypes, LoggedLiteLLMParams
if TYPE_CHECKING: if TYPE_CHECKING:


@ -208,7 +208,7 @@ async def vertex_proxy_route(
request, request,
fastapi_response, fastapi_response,
user_api_key_dict, user_api_key_dict,
stream=is_streaming_request, stream=is_streaming_request, # type: ignore
) )
return received_value return received_value


@ -10,11 +10,10 @@ from litellm.llms.azure_ai.rerank import AzureAIRerank
from litellm.llms.cohere.rerank import CohereRerank from litellm.llms.cohere.rerank import CohereRerank
from litellm.llms.together_ai.rerank import TogetherAIRerank from litellm.llms.together_ai.rerank import TogetherAIRerank
from litellm.secret_managers.main import get_secret from litellm.secret_managers.main import get_secret
from litellm.types.rerank import RerankRequest, RerankResponse
from litellm.types.router import * from litellm.types.router import *
from litellm.utils import client, exception_type, supports_httpx_timeout from litellm.utils import client, exception_type, supports_httpx_timeout
from .types import RerankRequest, RerankResponse
####### ENVIRONMENT VARIABLES ################### ####### ENVIRONMENT VARIABLES ###################
# Initialize any necessary instances or variables here # Initialize any necessary instances or variables here
cohere_rerank = CohereRerank() cohere_rerank = CohereRerank()


@ -7,6 +7,7 @@ import httpx
import openai import openai
import litellm import litellm
from litellm import get_secret, get_secret_str
from litellm._logging import verbose_router_logger from litellm._logging import verbose_router_logger
from litellm.llms.AzureOpenAI.azure import get_azure_ad_token_from_oidc from litellm.llms.AzureOpenAI.azure import get_azure_ad_token_from_oidc
from litellm.secret_managers.get_azure_ad_token_provider import ( from litellm.secret_managers.get_azure_ad_token_provider import (
@ -111,17 +112,17 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
api_key = litellm_params.get("api_key") or default_api_key api_key = litellm_params.get("api_key") or default_api_key
if api_key and isinstance(api_key, str) and api_key.startswith("os.environ/"): if api_key and isinstance(api_key, str) and api_key.startswith("os.environ/"):
api_key_env_name = api_key.replace("os.environ/", "") api_key_env_name = api_key.replace("os.environ/", "")
api_key = litellm.get_secret(api_key_env_name) api_key = get_secret_str(api_key_env_name)
litellm_params["api_key"] = api_key litellm_params["api_key"] = api_key
api_base = litellm_params.get("api_base") api_base = litellm_params.get("api_base")
base_url = litellm_params.get("base_url") base_url: Optional[str] = litellm_params.get("base_url")
api_base = ( api_base = (
api_base or base_url or default_api_base api_base or base_url or default_api_base
) # allow users to pass in `api_base` or `base_url` for azure ) # allow users to pass in `api_base` or `base_url` for azure
if api_base and api_base.startswith("os.environ/"): if api_base and api_base.startswith("os.environ/"):
api_base_env_name = api_base.replace("os.environ/", "") api_base_env_name = api_base.replace("os.environ/", "")
api_base = litellm.get_secret(api_base_env_name) api_base = get_secret_str(api_base_env_name)
litellm_params["api_base"] = api_base litellm_params["api_base"] = api_base
## AZURE AI STUDIO MISTRAL CHECK ## ## AZURE AI STUDIO MISTRAL CHECK ##
@ -147,33 +148,37 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
api_version = litellm_params.get("api_version") api_version = litellm_params.get("api_version")
if api_version and api_version.startswith("os.environ/"): if api_version and api_version.startswith("os.environ/"):
api_version_env_name = api_version.replace("os.environ/", "") api_version_env_name = api_version.replace("os.environ/", "")
api_version = litellm.get_secret(api_version_env_name) api_version = get_secret_str(api_version_env_name)
litellm_params["api_version"] = api_version litellm_params["api_version"] = api_version
timeout = litellm_params.pop("timeout", None) or litellm.request_timeout timeout: Optional[float] = (
litellm_params.pop("timeout", None) or litellm.request_timeout
)
if isinstance(timeout, str) and timeout.startswith("os.environ/"): if isinstance(timeout, str) and timeout.startswith("os.environ/"):
timeout_env_name = timeout.replace("os.environ/", "") timeout_env_name = timeout.replace("os.environ/", "")
timeout = litellm.get_secret(timeout_env_name) timeout = get_secret(timeout_env_name) # type: ignore
litellm_params["timeout"] = timeout litellm_params["timeout"] = timeout
stream_timeout = litellm_params.pop( stream_timeout: Optional[float] = litellm_params.pop(
"stream_timeout", timeout "stream_timeout", timeout
) # if no stream_timeout is set, default to timeout ) # if no stream_timeout is set, default to timeout
if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"): if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"):
stream_timeout_env_name = stream_timeout.replace("os.environ/", "") stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
stream_timeout = litellm.get_secret(stream_timeout_env_name) stream_timeout = get_secret(stream_timeout_env_name) # type: ignore
litellm_params["stream_timeout"] = stream_timeout litellm_params["stream_timeout"] = stream_timeout
max_retries = litellm_params.pop("max_retries", 0) # router handles retry logic max_retries: Optional[int] = litellm_params.pop(
"max_retries", 0
) # router handles retry logic
if isinstance(max_retries, str) and max_retries.startswith("os.environ/"): if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
max_retries_env_name = max_retries.replace("os.environ/", "") max_retries_env_name = max_retries.replace("os.environ/", "")
max_retries = litellm.get_secret(max_retries_env_name) max_retries = get_secret(max_retries_env_name) # type: ignore
litellm_params["max_retries"] = max_retries litellm_params["max_retries"] = max_retries
organization = litellm_params.get("organization", None) organization = litellm_params.get("organization", None)
if isinstance(organization, str) and organization.startswith("os.environ/"): if isinstance(organization, str) and organization.startswith("os.environ/"):
organization_env_name = organization.replace("os.environ/", "") organization_env_name = organization.replace("os.environ/", "")
organization = litellm.get_secret(organization_env_name) organization = get_secret_str(organization_env_name)
litellm_params["organization"] = organization litellm_params["organization"] = organization
azure_ad_token_provider: Optional[Callable[[], str]] = None azure_ad_token_provider: Optional[Callable[[], str]] = None
if litellm_params.get("tenant_id"): if litellm_params.get("tenant_id"):
@ -227,8 +232,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
azure_ad_token_provider=azure_ad_token_provider, azure_ad_token_provider=azure_ad_token_provider,
base_url=api_base, base_url=api_base,
api_version=api_version, api_version=api_version,
timeout=timeout, timeout=timeout, # type: ignore
max_retries=max_retries, max_retries=max_retries, # type: ignore
http_client=httpx.AsyncClient( http_client=httpx.AsyncClient(
limits=httpx.Limits( limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100 max_connections=1000, max_keepalive_connections=100
@ -253,8 +258,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
azure_ad_token_provider=azure_ad_token_provider, azure_ad_token_provider=azure_ad_token_provider,
base_url=api_base, base_url=api_base,
api_version=api_version, api_version=api_version,
timeout=timeout, timeout=timeout, # type: ignore
max_retries=max_retries, max_retries=max_retries, # type: ignore
http_client=httpx.Client( http_client=httpx.Client(
limits=httpx.Limits( limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100 max_connections=1000, max_keepalive_connections=100
@ -276,8 +281,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
azure_ad_token_provider=azure_ad_token_provider, azure_ad_token_provider=azure_ad_token_provider,
base_url=api_base, base_url=api_base,
api_version=api_version, api_version=api_version,
timeout=stream_timeout, timeout=stream_timeout, # type: ignore
max_retries=max_retries, max_retries=max_retries, # type: ignore
http_client=httpx.AsyncClient( http_client=httpx.AsyncClient(
limits=httpx.Limits( limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100 max_connections=1000, max_keepalive_connections=100
@ -302,8 +307,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
azure_ad_token_provider=azure_ad_token_provider, azure_ad_token_provider=azure_ad_token_provider,
base_url=api_base, base_url=api_base,
api_version=api_version, api_version=api_version,
timeout=stream_timeout, timeout=stream_timeout, # type: ignore
max_retries=max_retries, max_retries=max_retries, # type: ignore
http_client=httpx.Client( http_client=httpx.Client(
limits=httpx.Limits( limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100 max_connections=1000, max_keepalive_connections=100
@ -350,8 +355,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
cache_key = f"{model_id}_async_client" cache_key = f"{model_id}_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore _client = openai.AsyncAzureOpenAI( # type: ignore
**azure_client_params, **azure_client_params,
timeout=timeout, timeout=timeout, # type: ignore
max_retries=max_retries, max_retries=max_retries, # type: ignore
http_client=httpx.AsyncClient( http_client=httpx.AsyncClient(
limits=httpx.Limits( limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100 max_connections=1000, max_keepalive_connections=100
@ -371,8 +376,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
cache_key = f"{model_id}_client" cache_key = f"{model_id}_client"
_client = openai.AzureOpenAI( # type: ignore _client = openai.AzureOpenAI( # type: ignore
**azure_client_params, **azure_client_params,
timeout=timeout, timeout=timeout, # type: ignore
max_retries=max_retries, max_retries=max_retries, # type: ignore
http_client=httpx.Client( http_client=httpx.Client(
limits=httpx.Limits( limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100 max_connections=1000, max_keepalive_connections=100
@ -391,8 +396,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
cache_key = f"{model_id}_stream_async_client" cache_key = f"{model_id}_stream_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore _client = openai.AsyncAzureOpenAI( # type: ignore
**azure_client_params, **azure_client_params,
timeout=stream_timeout, timeout=stream_timeout, # type: ignore
max_retries=max_retries, max_retries=max_retries, # type: ignore
http_client=httpx.AsyncClient( http_client=httpx.AsyncClient(
limits=httpx.Limits( limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100 max_connections=1000, max_keepalive_connections=100
@ -413,8 +418,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
cache_key = f"{model_id}_stream_client" cache_key = f"{model_id}_stream_client"
_client = openai.AzureOpenAI( # type: ignore _client = openai.AzureOpenAI( # type: ignore
**azure_client_params, **azure_client_params,
timeout=stream_timeout, timeout=stream_timeout, # type: ignore
max_retries=max_retries, max_retries=max_retries, # type: ignore
http_client=httpx.Client( http_client=httpx.Client(
limits=httpx.Limits( limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100 max_connections=1000, max_keepalive_connections=100
@ -441,8 +446,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
_client = openai.AsyncOpenAI( # type: ignore _client = openai.AsyncOpenAI( # type: ignore
api_key=api_key, api_key=api_key,
base_url=api_base, base_url=api_base,
timeout=timeout, timeout=timeout, # type: ignore
max_retries=max_retries, max_retries=max_retries, # type: ignore
organization=organization, organization=organization,
http_client=httpx.AsyncClient( http_client=httpx.AsyncClient(
limits=httpx.Limits( limits=httpx.Limits(
@ -465,8 +470,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
_client = openai.OpenAI( # type: ignore _client = openai.OpenAI( # type: ignore
api_key=api_key, api_key=api_key,
base_url=api_base, base_url=api_base,
timeout=timeout, timeout=timeout, # type: ignore
max_retries=max_retries, max_retries=max_retries, # type: ignore
organization=organization, organization=organization,
http_client=httpx.Client( http_client=httpx.Client(
limits=httpx.Limits( limits=httpx.Limits(
@ -487,8 +492,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
_client = openai.AsyncOpenAI( # type: ignore _client = openai.AsyncOpenAI( # type: ignore
api_key=api_key, api_key=api_key,
base_url=api_base, base_url=api_base,
timeout=stream_timeout, timeout=stream_timeout, # type: ignore
max_retries=max_retries, max_retries=max_retries, # type: ignore
organization=organization, organization=organization,
http_client=httpx.AsyncClient( http_client=httpx.AsyncClient(
limits=httpx.Limits( limits=httpx.Limits(
@ -512,8 +517,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
_client = openai.OpenAI( # type: ignore _client = openai.OpenAI( # type: ignore
api_key=api_key, api_key=api_key,
base_url=api_base, base_url=api_base,
timeout=stream_timeout, timeout=stream_timeout, # type: ignore
max_retries=max_retries, max_retries=max_retries, # type: ignore
organization=organization, organization=organization,
http_client=httpx.Client( http_client=httpx.Client(
limits=httpx.Limits( limits=httpx.Limits(
@ -542,20 +547,29 @@ def get_azure_ad_token_from_entrata_id(
verbose_router_logger.debug("Getting Azure AD Token from Entrata ID") verbose_router_logger.debug("Getting Azure AD Token from Entrata ID")
if tenant_id.startswith("os.environ/"): if tenant_id.startswith("os.environ/"):
tenant_id = litellm.get_secret(tenant_id) _tenant_id = get_secret_str(tenant_id)
else:
_tenant_id = tenant_id
if client_id.startswith("os.environ/"): if client_id.startswith("os.environ/"):
client_id = litellm.get_secret(client_id) _client_id = get_secret_str(client_id)
else:
_client_id = client_id
if client_secret.startswith("os.environ/"): if client_secret.startswith("os.environ/"):
client_secret = litellm.get_secret(client_secret) _client_secret = get_secret_str(client_secret)
else:
_client_secret = client_secret
verbose_router_logger.debug( verbose_router_logger.debug(
"tenant_id %s, client_id %s, client_secret %s", "tenant_id %s, client_id %s, client_secret %s",
tenant_id, _tenant_id,
client_id, _client_id,
client_secret, _client_secret,
) )
credential = ClientSecretCredential(tenant_id, client_id, client_secret) if _tenant_id is None or _client_id is None or _client_secret is None:
raise ValueError("tenant_id, client_id, and client_secret must be provided")
credential = ClientSecretCredential(_tenant_id, _client_id, _client_secret)
verbose_router_logger.debug("credential %s", credential) verbose_router_logger.debug("credential %s", credential)


@ -2,6 +2,8 @@ import asyncio
import traceback import traceback
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from litellm.types.integrations.slack_alerting import AlertType
if TYPE_CHECKING: if TYPE_CHECKING:
from litellm.router import Router as _Router from litellm.router import Router as _Router
@ -49,5 +51,6 @@ async def send_llm_exception_alert(
await litellm_router_instance.slack_alerting_logger.send_alert( await litellm_router_instance.slack_alerting_logger.send_alert(
message=f"LLM API call failed: `{exception_str}`", message=f"LLM API call failed: `{exception_str}`",
level="High", level="High",
alert_type="llm_exceptions", alert_type=AlertType.llm_exceptions,
alerting_metadata={},
) )
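The string literal becomes an enum member so the `alert_type` parameter can be typed. A self-contained sketch of the idea; the real `AlertType` lives in `litellm.types.integrations.slack_alerting`, and the members below are illustrative:

```python
from enum import Enum


class AlertType(str, Enum):
    # illustrative members; the real enum is defined in
    # litellm/types/integrations/slack_alerting.py
    llm_exceptions = "llm_exceptions"
    llm_too_slow = "llm_too_slow"


def send_alert(message: str, level: str, alert_type: AlertType, alerting_metadata: dict) -> None:
    # A typed parameter rejects arbitrary strings at type-check time.
    print(f"[{level}] {alert_type.value}: {message}")


send_alert("LLM API call failed", "High", AlertType.llm_exceptions, alerting_metadata={})
```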


@ -792,7 +792,7 @@ class EmbeddingResponse(OpenAIObject):
model: Optional[str] = None model: Optional[str] = None
"""The model used for embedding.""" """The model used for embedding."""
data: Optional[List] = None data: List
"""The actual embedding value""" """The actual embedding value"""
object: Literal["list"] object: Literal["list"]
@ -803,6 +803,7 @@ class EmbeddingResponse(OpenAIObject):
_hidden_params: dict = {} _hidden_params: dict = {}
_response_headers: Optional[Dict] = None _response_headers: Optional[Dict] = None
_response_ms: Optional[float] = None
def __init__( def __init__(
self, self,
@ -822,7 +823,7 @@ class EmbeddingResponse(OpenAIObject):
if data: if data:
data = data data = data
else: else:
data = None data = []
if usage: if usage:
usage = usage usage = usage
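With `data` typed as a plain list defaulting to `[]`, callers can iterate without a `None` check. A rough sketch of the effect, using a plain pydantic model in place of `OpenAIObject`:

```python
from typing import Any, List, Optional

from pydantic import BaseModel


class EmbeddingResponseSketch(BaseModel):
    model: Optional[str] = None
    data: List[Any] = []  # was Optional[List] = None
    object: str = "list"


resp = EmbeddingResponseSketch()
for item in resp.data:  # safe even when empty; no `if resp.data is not None` needed
    print(item)
print(len(resp.data))  # 0
```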


@ -322,6 +322,9 @@ def function_setup(
original_function: str, rules_obj, start_time, *args, **kwargs original_function: str, rules_obj, start_time, *args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc. ): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
### NOTICES ### ### NOTICES ###
from litellm import Logging as LiteLLMLogging
from litellm.litellm_core_utils.litellm_logging import set_callbacks
if litellm.set_verbose is True: if litellm.set_verbose is True:
verbose_logger.warning( verbose_logger.warning(
"`litellm.set_verbose` is deprecated. Please set `os.environ['LITELLM_LOG'] = 'DEBUG'` for debug logs." "`litellm.set_verbose` is deprecated. Please set `os.environ['LITELLM_LOG'] = 'DEBUG'` for debug logs."
@ -333,7 +336,7 @@ def function_setup(
custom_llm_setup() custom_llm_setup()
## LOGGING SETUP ## LOGGING SETUP
function_id = kwargs["id"] if "id" in kwargs else None function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None
if len(litellm.callbacks) > 0: if len(litellm.callbacks) > 0:
for callback in litellm.callbacks: for callback in litellm.callbacks:
@ -375,9 +378,7 @@ def function_setup(
+ litellm.failure_callback + litellm.failure_callback
) )
) )
litellm.litellm_core_utils.litellm_logging.set_callbacks( set_callbacks(callback_list=callback_list, function_id=function_id)
callback_list=callback_list, function_id=function_id
)
## ASYNC CALLBACKS ## ASYNC CALLBACKS
if len(litellm.input_callback) > 0: if len(litellm.input_callback) > 0:
removed_async_items = [] removed_async_items = []
@ -560,12 +561,12 @@ def function_setup(
else: else:
messages = "default-message-value" messages = "default-message-value"
stream = True if "stream" in kwargs and kwargs["stream"] is True else False stream = True if "stream" in kwargs and kwargs["stream"] is True else False
logging_obj = litellm.litellm_core_utils.litellm_logging.Logging( logging_obj = LiteLLMLogging(
model=model, model=model,
messages=messages, messages=messages,
stream=stream, stream=stream,
litellm_call_id=kwargs["litellm_call_id"], litellm_call_id=kwargs["litellm_call_id"],
function_id=function_id, function_id=function_id or "",
call_type=call_type, call_type=call_type,
start_time=start_time, start_time=start_time,
dynamic_success_callbacks=dynamic_success_callbacks, dynamic_success_callbacks=dynamic_success_callbacks,
@ -655,10 +656,8 @@ def client(original_function):
json_response_format = optional_params[ json_response_format = optional_params[
"response_format" "response_format"
] ]
elif ( elif _parsing._completions.is_basemodel_type(
_parsing._completions.is_basemodel_type( optional_params["response_format"] # type: ignore
optional_params["response_format"]
)
): ):
json_response_format = ( json_response_format = (
type_to_response_format_param( type_to_response_format_param(
@ -827,6 +826,7 @@ def client(original_function):
print_verbose("INSIDE CHECKING CACHE") print_verbose("INSIDE CHECKING CACHE")
if ( if (
litellm.cache is not None litellm.cache is not None
and litellm.cache.supported_call_types is not None
and str(original_function.__name__) and str(original_function.__name__)
in litellm.cache.supported_call_types in litellm.cache.supported_call_types
): ):
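`supported_call_types` is Optional, so the membership test needs a preceding `None` check; otherwise pyright flags `in` on a possibly-`None` value. A standalone sketch of the guard:

```python
from typing import List, Optional


class CacheSketch:
    # stand-in for litellm.cache; supported_call_types may be unset
    def __init__(self, supported_call_types: Optional[List[str]] = None):
        self.supported_call_types = supported_call_types


cache: Optional[CacheSketch] = CacheSketch(["completion", "embedding"])
function_name = "completion"

if (
    cache is not None
    and cache.supported_call_types is not None
    and function_name in cache.supported_call_types
):
    print("cache applies to this call type")
```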
@ -879,7 +879,7 @@ def client(original_function):
dynamic_api_key, dynamic_api_key,
api_base, api_base,
) = litellm.get_llm_provider( ) = litellm.get_llm_provider(
model=model, model=model or "",
custom_llm_provider=kwargs.get( custom_llm_provider=kwargs.get(
"custom_llm_provider", None "custom_llm_provider", None
), ),
@ -949,6 +949,8 @@ def client(original_function):
base_model=base_model, base_model=base_model,
messages=messages, messages=messages,
user_max_tokens=user_max_tokens, user_max_tokens=user_max_tokens,
buffer_num=None,
buffer_perc=None,
) )
kwargs["max_tokens"] = modified_max_tokens kwargs["max_tokens"] = modified_max_tokens
except Exception as e: except Exception as e:
@ -990,6 +992,7 @@ def client(original_function):
# [OPTIONAL] ADD TO CACHE # [OPTIONAL] ADD TO CACHE
if ( if (
litellm.cache is not None litellm.cache is not None
and litellm.cache.supported_call_types is not None
and str(original_function.__name__) and str(original_function.__name__)
in litellm.cache.supported_call_types in litellm.cache.supported_call_types
) and (kwargs.get("cache", {}).get("no-store", False) is not True): ) and (kwargs.get("cache", {}).get("no-store", False) is not True):
@ -1006,7 +1009,7 @@ def client(original_function):
"id", None "id", None
) )
result._hidden_params["api_base"] = get_api_base( result._hidden_params["api_base"] = get_api_base(
model=model, model=model or "",
optional_params=getattr(logging_obj, "optional_params", {}), optional_params=getattr(logging_obj, "optional_params", {}),
) )
result._hidden_params["response_cost"] = ( result._hidden_params["response_cost"] = (
@ -1053,7 +1056,7 @@ def client(original_function):
and not _is_litellm_router_call and not _is_litellm_router_call
): ):
if len(args) > 0: if len(args) > 0:
args[0] = context_window_fallback_dict[model] args[0] = context_window_fallback_dict[model] # type: ignore
else: else:
kwargs["model"] = context_window_fallback_dict[model] kwargs["model"] = context_window_fallback_dict[model]
return original_function(*args, **kwargs) return original_function(*args, **kwargs)
@ -1065,12 +1068,6 @@ def client(original_function):
logging_obj.failure_handler( logging_obj.failure_handler(
e, traceback_exception, start_time, end_time e, traceback_exception, start_time, end_time
) # DO NOT MAKE THREADED - router retry fallback relies on this! ) # DO NOT MAKE THREADED - router retry fallback relies on this!
if hasattr(e, "message"):
if (
liteDebuggerClient
and liteDebuggerClient.dashboard_url is not None
): # make it easy to get to the debugger logs if you've initialized it
e.message += f"\n Check the log in your dashboard - {liteDebuggerClient.dashboard_url}"
raise e raise e
@wraps(original_function) @wraps(original_function)
@ -1126,6 +1123,7 @@ def client(original_function):
print_verbose("INSIDE CHECKING CACHE") print_verbose("INSIDE CHECKING CACHE")
if ( if (
litellm.cache is not None litellm.cache is not None
and litellm.cache.supported_call_types is not None
and str(original_function.__name__) and str(original_function.__name__)
in litellm.cache.supported_call_types in litellm.cache.supported_call_types
): ):
@ -1287,7 +1285,11 @@ def client(original_function):
args=(cached_result, start_time, end_time, cache_hit), args=(cached_result, start_time, end_time, cache_hit),
).start() ).start()
cache_key = kwargs.get("preset_cache_key", None) cache_key = kwargs.get("preset_cache_key", None)
cached_result._hidden_params["cache_key"] = cache_key if (
isinstance(cached_result, BaseModel)
or isinstance(cached_result, CustomStreamWrapper)
) and hasattr(cached_result, "_hidden_params"):
cached_result._hidden_params["cache_key"] = cache_key # type: ignore
return cached_result return cached_result
elif ( elif (
call_type == CallTypes.aembedding.value call_type == CallTypes.aembedding.value
@ -1447,6 +1449,7 @@ def client(original_function):
# [OPTIONAL] ADD TO CACHE # [OPTIONAL] ADD TO CACHE
if ( if (
(litellm.cache is not None) (litellm.cache is not None)
and litellm.cache.supported_call_types is not None
and ( and (
str(original_function.__name__) str(original_function.__name__)
in litellm.cache.supported_call_types in litellm.cache.supported_call_types
@ -1504,11 +1507,12 @@ def client(original_function):
if ( if (
isinstance(result, EmbeddingResponse) isinstance(result, EmbeddingResponse)
and final_embedding_cached_response is not None and final_embedding_cached_response is not None
and final_embedding_cached_response.data is not None
): ):
idx = 0 idx = 0
final_data_list = [] final_data_list = []
for item in final_embedding_cached_response.data: for item in final_embedding_cached_response.data:
if item is None: if item is None and result.data is not None:
final_data_list.append(result.data[idx]) final_data_list.append(result.data[idx])
idx += 1 idx += 1
else: else:
@ -1575,7 +1579,7 @@ def client(original_function):
and model in context_window_fallback_dict and model in context_window_fallback_dict
): ):
if len(args) > 0: if len(args) > 0:
args[0] = context_window_fallback_dict[model] args[0] = context_window_fallback_dict[model] # type: ignore
else: else:
kwargs["model"] = context_window_fallback_dict[model] kwargs["model"] = context_window_fallback_dict[model]
return await original_function(*args, **kwargs) return await original_function(*args, **kwargs)
@ -2945,13 +2949,19 @@ def get_optional_params(
response_format=non_default_params["response_format"] response_format=non_default_params["response_format"]
) )
# # clean out 'additionalProperties = False'. Causes vertexai/gemini OpenAI API Schema errors - https://github.com/langchain-ai/langchainjs/issues/5240 # # clean out 'additionalProperties = False'. Causes vertexai/gemini OpenAI API Schema errors - https://github.com/langchain-ai/langchainjs/issues/5240
if non_default_params["response_format"].get("json_schema", {}).get( if (
"schema" non_default_params["response_format"] is not None
) is not None and custom_llm_provider in [ and non_default_params["response_format"]
.get("json_schema", {})
.get("schema")
is not None
and custom_llm_provider
in [
"gemini", "gemini",
"vertex_ai", "vertex_ai",
"vertex_ai_beta", "vertex_ai_beta",
]: ]
):
old_schema = copy.deepcopy( old_schema = copy.deepcopy(
non_default_params["response_format"] non_default_params["response_format"]
.get("json_schema", {}) .get("json_schema", {})
@ -3754,7 +3764,11 @@ def get_optional_params(
non_default_params=non_default_params, non_default_params=non_default_params,
optional_params=optional_params, optional_params=optional_params,
model=model, model=model,
drop_params=drop_params, drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
) )
elif custom_llm_provider == "openrouter": elif custom_llm_provider == "openrouter":
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
@ -3863,7 +3877,11 @@ def get_optional_params(
non_default_params=non_default_params, non_default_params=non_default_params,
optional_params=optional_params, optional_params=optional_params,
model=model, model=model,
drop_params=drop_params, drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
) )
elif custom_llm_provider == "azure": elif custom_llm_provider == "azure":
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
@ -4889,7 +4907,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
try: try:
split_model, custom_llm_provider, _, _ = get_llm_provider(model=model) split_model, custom_llm_provider, _, _ = get_llm_provider(model=model)
except Exception: except Exception:
pass split_model = model
combined_model_name = model combined_model_name = model
stripped_model_name = _strip_model_name(model=model) stripped_model_name = _strip_model_name(model=model)
combined_stripped_model_name = stripped_model_name combined_stripped_model_name = stripped_model_name
@ -5865,6 +5883,8 @@ def convert_to_model_response_object(
for idx, choice in enumerate(response_object["choices"]): for idx, choice in enumerate(response_object["choices"]):
## HANDLE JSON MODE - anthropic returns single function call] ## HANDLE JSON MODE - anthropic returns single function call]
tool_calls = choice["message"].get("tool_calls", None) tool_calls = choice["message"].get("tool_calls", None)
message: Optional[Message] = None
finish_reason: Optional[str] = None
if ( if (
convert_tool_call_to_json_mode convert_tool_call_to_json_mode
and tool_calls is not None and tool_calls is not None
@ -5877,7 +5897,7 @@ def convert_to_model_response_object(
if json_mode_content_str is not None: if json_mode_content_str is not None:
message = litellm.Message(content=json_mode_content_str) message = litellm.Message(content=json_mode_content_str)
finish_reason = "stop" finish_reason = "stop"
else: if message is None:
message = Message( message = Message(
content=choice["message"].get("content", None), content=choice["message"].get("content", None),
role=choice["message"]["role"] or "assistant", role=choice["message"]["role"] or "assistant",
@ -6066,7 +6086,7 @@ def valid_model(model):
model in litellm.open_ai_chat_completion_models model in litellm.open_ai_chat_completion_models
or model in litellm.open_ai_text_completion_models or model in litellm.open_ai_text_completion_models
): ):
openai.Model.retrieve(model) openai.models.retrieve(model)
else: else:
messages = [{"role": "user", "content": "Hello World"}] messages = [{"role": "user", "content": "Hello World"}]
litellm.completion(model=model, messages=messages) litellm.completion(model=model, messages=messages)
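`openai.Model.retrieve` is the pre-1.0 API; the hunk moves to the 1.x resource path. An equivalent call with an explicit client, as a sketch (assumes `OPENAI_API_KEY` is set and a real network request is acceptable; the model name is illustrative):

```python
import openai

client = openai.OpenAI()  # reads OPENAI_API_KEY from the environment
model_obj = client.models.retrieve("gpt-4o-mini")
print(model_obj.id)
```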
@ -6386,8 +6406,8 @@ class CustomStreamWrapper:
self, self,
completion_stream, completion_stream,
model, model,
custom_llm_provider=None, logging_obj: Any,
logging_obj=None, custom_llm_provider: Optional[str] = None,
stream_options=None, stream_options=None,
make_call: Optional[Callable] = None, make_call: Optional[Callable] = None,
_response_headers: Optional[dict] = None, _response_headers: Optional[dict] = None,
@ -6633,36 +6653,6 @@ class CustomStreamWrapper:
"completion_tokens": completion_tokens, "completion_tokens": completion_tokens,
} }
def handle_together_ai_chunk(self, chunk):
chunk = chunk.decode("utf-8")
text = ""
is_finished = False
finish_reason = None
if "text" in chunk:
text_index = chunk.find('"text":"') # this checks if text: exists
text_start = text_index + len('"text":"')
text_end = chunk.find('"}', text_start)
if text_index != -1 and text_end != -1:
extracted_text = chunk[text_start:text_end]
text = extracted_text
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
}
elif "[DONE]" in chunk:
return {"text": text, "is_finished": True, "finish_reason": "stop"}
elif "error" in chunk:
raise litellm.together_ai.TogetherAIError(
status_code=422, message=f"{str(chunk)}"
)
else:
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
}
def handle_predibase_chunk(self, chunk): def handle_predibase_chunk(self, chunk):
try: try:
if not isinstance(chunk, str): if not isinstance(chunk, str):
@ -7264,11 +7254,16 @@ class CustomStreamWrapper:
try: try:
if isinstance(chunk, dict): if isinstance(chunk, dict):
parsed_response = chunk parsed_response = chunk
if isinstance(chunk, (str, bytes)): elif isinstance(chunk, (str, bytes)):
if isinstance(chunk, bytes): if isinstance(chunk, bytes):
parsed_response = chunk.decode("utf-8") parsed_response = chunk.decode("utf-8")
else: else:
parsed_response = chunk parsed_response = chunk
else:
raise ValueError("Unable to parse streaming chunk")
if isinstance(parsed_response, dict):
data_json = parsed_response
else:
data_json = json.loads(parsed_response) data_json = json.loads(parsed_response)
text = ( text = (
data_json.get("outputs", "")[0] data_json.get("outputs", "")[0]
@ -7331,8 +7326,7 @@ class CustomStreamWrapper:
if ( if (
len(model_response.choices) > 0 len(model_response.choices) > 0
and hasattr(model_response.choices[0], "delta") and getattr(model_response.choices[0], "delta") is not None
and model_response.choices[0].delta is not None
): ):
# do nothing, if object instantiated # do nothing, if object instantiated
pass pass
@ -7350,7 +7344,7 @@ class CustomStreamWrapper:
is_empty = False is_empty = False
return is_empty return is_empty
def chunk_creator(self, chunk): def chunk_creator(self, chunk): # type: ignore
model_response = self.model_response_creator() model_response = self.model_response_creator()
response_obj = {} response_obj = {}
try: try:
@ -7422,11 +7416,6 @@ class CustomStreamWrapper:
completion_obj["content"] = response_obj["text"] completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]: if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"] self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider and self.custom_llm_provider == "together_ai":
response_obj = self.handle_together_ai_chunk(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider and self.custom_llm_provider == "huggingface": elif self.custom_llm_provider and self.custom_llm_provider == "huggingface":
response_obj = self.handle_huggingface_chunk(chunk) response_obj = self.handle_huggingface_chunk(chunk)
completion_obj["content"] = response_obj["text"] completion_obj["content"] = response_obj["text"]
@ -7475,51 +7464,6 @@ class CustomStreamWrapper:
if self.sent_first_chunk is False: if self.sent_first_chunk is False:
raise Exception("An unknown error occurred with the stream") raise Exception("An unknown error occurred with the stream")
self.received_finish_reason = "stop" self.received_finish_reason = "stop"
elif self.custom_llm_provider == "gemini":
if hasattr(chunk, "parts") is True:
try:
if len(chunk.parts) > 0:
completion_obj["content"] = chunk.parts[0].text
if len(chunk.parts) > 0 and hasattr(
chunk.parts[0], "finish_reason"
):
self.received_finish_reason = chunk.parts[
0
].finish_reason.name
except Exception:
if chunk.parts[0].finish_reason.name == "SAFETY":
raise Exception(
f"The response was blocked by VertexAI. {str(chunk)}"
)
else:
completion_obj["content"] = str(chunk)
elif self.custom_llm_provider and (
self.custom_llm_provider == "vertex_ai_beta"
):
from litellm.types.utils import (
GenericStreamingChunk as UtilsStreamingChunk,
)
if self.received_finish_reason is not None:
raise StopIteration
response_obj: UtilsStreamingChunk = chunk
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
if (
self.stream_options
and self.stream_options.get("include_usage", False) is True
and response_obj["usage"] is not None
):
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"]["prompt_tokens"],
completion_tokens=response_obj["usage"]["completion_tokens"],
total_tokens=response_obj["usage"]["total_tokens"],
)
if "tool_use" in response_obj and response_obj["tool_use"] is not None:
completion_obj["tool_calls"] = [response_obj["tool_use"]]
elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"): elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"):
import proto # type: ignore import proto # type: ignore
@ -7624,53 +7568,7 @@ class CustomStreamWrapper:
completion_obj["content"] = response_obj["text"] completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]: if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"] self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "bedrock":
from litellm.types.llms.bedrock import GenericStreamingChunk
if self.received_finish_reason is not None:
raise StopIteration
response_obj: GenericStreamingChunk = chunk
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
if (
self.stream_options
and self.stream_options.get("include_usage", False) is True
and response_obj["usage"] is not None
):
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"]["inputTokens"],
completion_tokens=response_obj["usage"]["outputTokens"],
total_tokens=response_obj["usage"]["totalTokens"],
)
if "tool_use" in response_obj and response_obj["tool_use"] is not None:
completion_obj["tool_calls"] = [response_obj["tool_use"]]
elif self.custom_llm_provider == "sagemaker":
from litellm.types.llms.bedrock import GenericStreamingChunk
if self.received_finish_reason is not None:
raise StopIteration
response_obj: GenericStreamingChunk = chunk
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
if (
self.stream_options
and self.stream_options.get("include_usage", False) is True
and response_obj["usage"] is not None
):
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"]["inputTokens"],
completion_tokens=response_obj["usage"]["outputTokens"],
total_tokens=response_obj["usage"]["totalTokens"],
)
if "tool_use" in response_obj and response_obj["tool_use"] is not None:
completion_obj["tool_calls"] = [response_obj["tool_use"]]
elif self.custom_llm_provider == "petals": elif self.custom_llm_provider == "petals":
if len(self.completion_stream) == 0: if len(self.completion_stream) == 0:
if self.received_finish_reason is not None: if self.received_finish_reason is not None:
@ -8181,9 +8079,11 @@ class CustomStreamWrapper:
target=self.run_success_logging_in_thread, target=self.run_success_logging_in_thread,
args=(response, cache_hit), args=(response, cache_hit),
).start() # log response ).start() # log response
self.response_uptil_now += ( choice = response.choices[0]
response.choices[0].delta.get("content", "") or "" if isinstance(choice, StreamingChoices):
) self.response_uptil_now += choice.delta.get("content", "") or ""
else:
self.response_uptil_now += ""
self.rules.post_call_rules( self.rules.post_call_rules(
input=self.response_uptil_now, model=self.model input=self.response_uptil_now, model=self.model
) )
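Only streaming choices carry a `delta`, so the accumulator narrows the choice type before reading it. A stand-alone sketch with stand-in classes (the real types are `Choices` / `StreamingChoices` from `litellm.types.utils`):

```python
from typing import Union


class Choices:
    # stand-in for the non-streaming choice type; no `delta` attribute
    message = {"content": "full message"}


class StreamingChoices:
    # stand-in for the streaming choice type
    delta = {"content": "partial text"}


def accumulate(so_far: str, choice: Union[Choices, StreamingChoices]) -> str:
    if isinstance(choice, StreamingChoices):
        return so_far + (choice.delta.get("content", "") or "")
    return so_far  # nothing to append for non-streaming choices


print(accumulate("", StreamingChoices()))
```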
@ -8223,8 +8123,11 @@ class CustomStreamWrapper:
) )
response = self.model_response_creator() response = self.model_response_creator()
if complete_streaming_response is not None: if complete_streaming_response is not None:
response.usage = complete_streaming_response.usage setattr(
response._hidden_params["usage"] = complete_streaming_response.usage # type: ignore response,
"usage",
getattr(complete_streaming_response, "usage"),
)
## LOGGING ## LOGGING
threading.Thread( threading.Thread(
target=self.logging_obj.success_handler, target=self.logging_obj.success_handler,
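`usage` is not a declared field on the wrapper's response object, so a direct assignment is reported by the type checker, while `setattr`/`getattr` performs the same write dynamically. A toy illustration of the workaround:

```python
class ResponseSketch:
    """Stand-in for the object returned by model_response_creator()."""


class CompleteStreamingResponseSketch:
    usage = {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}


response = ResponseSketch()
complete_streaming_response = CompleteStreamingResponseSketch()

# `response.usage = ...` would be flagged: `usage` is not declared on ResponseSketch
setattr(response, "usage", getattr(complete_streaming_response, "usage"))
print(getattr(response, "usage"))
```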
@ -8349,9 +8252,11 @@ class CustomStreamWrapper:
processed_chunk, cache_hit=cache_hit processed_chunk, cache_hit=cache_hit
) )
) )
self.response_uptil_now += ( choice = processed_chunk.choices[0]
processed_chunk.choices[0].delta.get("content", "") or "" if isinstance(choice, StreamingChoices):
) self.response_uptil_now += choice.delta.get("content", "") or ""
else:
self.response_uptil_now += ""
self.rules.post_call_rules( self.rules.post_call_rules(
input=self.response_uptil_now, model=self.model input=self.response_uptil_now, model=self.model
) )
@ -8401,9 +8306,13 @@ class CustomStreamWrapper:
) )
) )
choice = processed_chunk.choices[0]
if isinstance(choice, StreamingChoices):
self.response_uptil_now += ( self.response_uptil_now += (
processed_chunk.choices[0].delta.get("content", "") or "" choice.delta.get("content", "") or ""
) )
else:
self.response_uptil_now += ""
self.rules.post_call_rules( self.rules.post_call_rules(
input=self.response_uptil_now, model=self.model input=self.response_uptil_now, model=self.model
) )
@ -8423,7 +8332,11 @@ class CustomStreamWrapper:
) )
response = self.model_response_creator() response = self.model_response_creator()
if complete_streaming_response is not None: if complete_streaming_response is not None:
setattr(response, "usage", complete_streaming_response.usage) setattr(
response,
"usage",
getattr(complete_streaming_response, "usage"),
)
## LOGGING ## LOGGING
threading.Thread( threading.Thread(
target=self.logging_obj.success_handler, target=self.logging_obj.success_handler,
@ -8464,7 +8377,11 @@ class CustomStreamWrapper:
) )
response = self.model_response_creator() response = self.model_response_creator()
if complete_streaming_response is not None: if complete_streaming_response is not None:
response.usage = complete_streaming_response.usage setattr(
response,
"usage",
getattr(complete_streaming_response, "usage"),
)
## LOGGING ## LOGGING
threading.Thread( threading.Thread(
target=self.logging_obj.success_handler, target=self.logging_obj.success_handler,
@ -8898,7 +8815,7 @@ def trim_messages(
if len(tool_messages): if len(tool_messages):
messages = messages[: -len(tool_messages)] messages = messages[: -len(tool_messages)]
current_tokens = token_counter(model=model, messages=messages) current_tokens = token_counter(model=model or "", messages=messages)
print_verbose(f"Current tokens: {current_tokens}, max tokens: {max_tokens}") print_verbose(f"Current tokens: {current_tokens}, max tokens: {max_tokens}")
# Do nothing if current tokens under messages # Do nothing if current tokens under messages
@ -8909,6 +8826,7 @@ def trim_messages(
print_verbose( print_verbose(
f"Need to trim input messages: {messages}, current_tokens{current_tokens}, max_tokens: {max_tokens}" f"Need to trim input messages: {messages}, current_tokens{current_tokens}, max_tokens: {max_tokens}"
) )
system_message_event: Optional[dict] = None
if system_message: if system_message:
system_message_event, max_tokens = process_system_message( system_message_event, max_tokens = process_system_message(
system_message=system_message, max_tokens=max_tokens, model=model system_message=system_message, max_tokens=max_tokens, model=model
@ -8926,7 +8844,7 @@ def trim_messages(
) )
# Add system message to the beginning of the final messages # Add system message to the beginning of the final messages
if system_message: if system_message_event:
final_messages = [system_message_event] + final_messages final_messages = [system_message_event] + final_messages
if len(tool_messages) > 0: if len(tool_messages) > 0:
@ -9214,6 +9132,8 @@ def is_cached_message(message: AllMessageValues) -> bool:
Follows the anthropic format {"cache_control": {"type": "ephemeral"}} Follows the anthropic format {"cache_control": {"type": "ephemeral"}}
""" """
if "content" not in message:
return False
if message["content"] is None or isinstance(message["content"], str): if message["content"] is None or isinstance(message["content"], str):
return False return False
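The added guard handles messages that omit a `content` key entirely. A self-contained sketch of the whole check; the final loop over content blocks is an assumption about the rest of the helper, which this hunk does not show:

```python
def is_cached_message(message: dict) -> bool:
    if "content" not in message:  # the guard added by this hunk
        return False
    content = message["content"]
    if content is None or isinstance(content, str):
        return False
    # Assumed remainder: look for the anthropic-style cache_control block.
    return any(
        isinstance(block, dict)
        and block.get("cache_control", {}).get("type") == "ephemeral"
        for block in content
    )


msg = {
    "role": "user",
    "content": [
        {"type": "text", "text": "reuse me", "cache_control": {"type": "ephemeral"}}
    ],
}
print(is_cached_message(msg))               # True
print(is_cached_message({"role": "tool"}))  # False - no "content" key
```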


@ -1,6 +1,7 @@
{ {
"ignore": [], "ignore": [],
"exclude": ["**/node_modules", "**/__pycache__", "litellm/tests", "litellm/main.py", "litellm/utils.py", "litellm/types/utils.py"], "exclude": ["**/node_modules", "**/__pycache__", "litellm/types/utils.py"],
"reportMissingImports": false "reportMissingImports": false,
"reportPrivateImportUsage": false
} }


@ -14,7 +14,7 @@ from typing import Optional
import httpx import httpx
from litellm.integrations.SlackAlerting.types import AlertType from litellm.types.integrations.slack_alerting import AlertType
# import logging # import logging
# logging.basicConfig(level=logging.DEBUG) # logging.basicConfig(level=logging.DEBUG)


@ -1188,13 +1188,36 @@ def test_completion_cost_anthropic_prompt_caching():
system_fingerprint=None, system_fingerprint=None,
usage=Usage( usage=Usage(
completion_tokens=10, completion_tokens=10,
prompt_tokens=14, prompt_tokens=114,
total_tokens=24, total_tokens=124,
prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
cache_creation_input_tokens=100, cache_creation_input_tokens=100,
cache_read_input_tokens=0, cache_read_input_tokens=0,
), ),
) )
cost_1 = completion_cost(model=model, completion_response=response_1)
_model_info = litellm.get_model_info(
model="claude-3-5-sonnet-20240620", custom_llm_provider="anthropic"
)
expected_cost = (
(
response_1.usage.prompt_tokens
- response_1.usage.prompt_tokens_details.cached_tokens
)
* _model_info["input_cost_per_token"]
+ response_1.usage.prompt_tokens_details.cached_tokens
* _model_info["cache_read_input_token_cost"]
+ response_1.usage.cache_creation_input_tokens
* _model_info["cache_creation_input_token_cost"]
+ response_1.usage.completion_tokens * _model_info["output_cost_per_token"]
) # Cost of processing (non-cache hit + cache hit) + Cost of cache-writing (cache writing)
assert round(expected_cost, 5) == round(cost_1, 5)
print(f"expected_cost: {expected_cost}, cost_1: {cost_1}")
## READ FROM CACHE ## (LESS EXPENSIVE) ## READ FROM CACHE ## (LESS EXPENSIVE)
response_2 = ModelResponse( response_2 = ModelResponse(
id="chatcmpl-3f427194-0840-4d08-b571-56bfe38a5424", id="chatcmpl-3f427194-0840-4d08-b571-56bfe38a5424",
@ -1216,14 +1239,14 @@ def test_completion_cost_anthropic_prompt_caching():
system_fingerprint=None, system_fingerprint=None,
usage=Usage( usage=Usage(
completion_tokens=10, completion_tokens=10,
prompt_tokens=14, prompt_tokens=114,
total_tokens=24, total_tokens=134,
prompt_tokens_details=PromptTokensDetails(cached_tokens=100),
cache_creation_input_tokens=0, cache_creation_input_tokens=0,
cache_read_input_tokens=100, cache_read_input_tokens=100,
), ),
) )
cost_1 = completion_cost(model=model, completion_response=response_1)
cost_2 = completion_cost(model=model, completion_response=response_2) cost_2 = completion_cost(model=model, completion_response=response_2)
assert cost_1 > cost_2 assert cost_1 > cost_2
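A worked example of the `expected_cost` formula used in this test, with illustrative per-token rates (not taken from litellm's model map):

```python
# Illustrative rates only: $3/M input, $15/M output, $0.30/M cache reads,
# $3.75/M cache writes.
input_cost = 3e-6
output_cost = 15e-6
cache_read_cost = 0.3e-6
cache_creation_cost = 3.75e-6

prompt_tokens = 114          # includes any cache-hit tokens
cached_tokens = 0            # first call: nothing read from the cache yet
cache_creation_tokens = 100  # first call writes 100 tokens to the cache
completion_tokens = 10

# Mirrors the formula above: non-cache-hit input + cache reads + cache writes + output
expected_cost = (
    (prompt_tokens - cached_tokens) * input_cost
    + cached_tokens * cache_read_cost
    + cache_creation_tokens * cache_creation_cost
    + completion_tokens * output_cost
)
print(f"{expected_cost:.6f}")  # 0.000867
```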


@ -10,12 +10,40 @@ import litellm
import pytest import pytest
def _usage_format_tests(usage: litellm.Usage):
"""
OpenAI prompt caching
- prompt_tokens = sum of non-cache hit tokens + cache-hit tokens
- total_tokens = prompt_tokens + completion_tokens
Example
```
"usage": {
"prompt_tokens": 2006,
"completion_tokens": 300,
"total_tokens": 2306,
"prompt_tokens_details": {
"cached_tokens": 1920
},
"completion_tokens_details": {
"reasoning_tokens": 0
}
# ANTHROPIC_ONLY #
"cache_creation_input_tokens": 0
}
```
"""
assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens
assert usage.prompt_tokens > usage.prompt_tokens_details.cached_tokens
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model", "model",
[ [
"anthropic/claude-3-5-sonnet-20240620", "anthropic/claude-3-5-sonnet-20240620",
"openai/gpt-4o", # "openai/gpt-4o",
"deepseek/deepseek-chat", # "deepseek/deepseek-chat",
], ],
) )
def test_prompt_caching_model(model): def test_prompt_caching_model(model):
@ -66,9 +94,13 @@ def test_prompt_caching_model(model):
max_tokens=10, max_tokens=10,
) )
_usage_format_tests(response.usage)
print("response=", response) print("response=", response)
print("response.usage=", response.usage) print("response.usage=", response.usage)
_usage_format_tests(response.usage)
assert "prompt_tokens_details" in response.usage assert "prompt_tokens_details" in response.usage
assert response.usage.prompt_tokens_details.cached_tokens > 0 assert response.usage.prompt_tokens_details.cached_tokens > 0
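The same invariants that `_usage_format_tests` checks, applied to the raw usage payload shown in its docstring:

```python
usage = {
    "prompt_tokens": 2006,
    "completion_tokens": 300,
    "total_tokens": 2306,
    "prompt_tokens_details": {"cached_tokens": 1920},
}

# prompt_tokens already includes the cache-hit tokens, so totals add up and the
# cached count can never exceed the prompt count.
assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"]
assert usage["prompt_tokens"] > usage["prompt_tokens_details"]["cached_tokens"]
print("usage format ok")
```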