Add pyright to ci/cd + Fix remaining type-checking errors (#6082)

* fix: fix type-checking errors

* fix: fix additional type-checking errors

* fix: additional type-checking error fixes

* fix: fix additional type-checking errors

* fix: additional type-check fixes

* fix: fix all type-checking errors + add pyright to ci/cd

* fix: fix incorrect import

* ci(config.yml): use mypy on ci/cd

* fix: fix type-checking errors in utils.py

* fix: fix all type-checking errors on main.py

* fix: fix mypy linting errors

* fix(anthropic/cost_calculator.py): fix linting errors

* fix: fix mypy linting errors

* fix: fix linting errors
Krish Dholakia 2024-10-05 17:04:00 -04:00 committed by GitHub
parent f7ce1173f3
commit fac3b2ee42
65 changed files with 619 additions and 522 deletions

View file

@ -315,10 +315,10 @@ jobs:
python -m pip install --upgrade pip
pip install ruff
pip install pylint
pip install pyright
pip install .
- run: python -c "from litellm import *" || (echo '🚨 import failed, this means you introduced unprotected imports! 🚨'; exit 1)
- run: ruff check ./litellm
build_and_test:
machine:
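
The hunk above adds pyright to the tools installed in this job; the `run` step that actually invokes it is not part of this excerpt. A minimal local-reproduction sketch, assuming pyright is pointed at the `litellm` package the same way ruff is:

import subprocess
import sys

def run_checks() -> int:
    """Run the same checks the CI job installs (ruff, pyright) against ./litellm."""
    checks = [
        ["ruff", "check", "./litellm"],  # mirrors the `ruff check ./litellm` step above
        ["pyright", "./litellm"],        # assumed invocation; the pyright run step is not shown in this hunk
    ]
    for cmd in checks:
        result = subprocess.run(cmd)
        if result.returncode != 0:
            return result.returncode
    return 0

if __name__ == "__main__":
    sys.exit(run_checks())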

View file

@ -8,7 +8,7 @@ import requests
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from typing import Literal, Union
from typing import Literal, Union, Optional
import traceback
@ -26,9 +26,9 @@ from litellm._logging import print_verbose, verbose_logger
class GenericAPILogger:
# Class variables or attributes
def __init__(self, endpoint=None, headers=None):
def __init__(self, endpoint: Optional[str] = None, headers: Optional[dict] = None):
try:
if endpoint == None:
if endpoint is None:
# check env for "GENERIC_LOGGER_ENDPOINT"
if os.getenv("GENERIC_LOGGER_ENDPOINT"):
# Do something with the endpoint
@ -36,9 +36,15 @@ class GenericAPILogger:
else:
# Handle the case when the endpoint is not found in the environment variables
raise ValueError(
f"endpoint not set for GenericAPILogger, GENERIC_LOGGER_ENDPOINT not found in environment variables"
"endpoint not set for GenericAPILogger, GENERIC_LOGGER_ENDPOINT not found in environment variables"
)
headers = headers or litellm.generic_logger_headers
if endpoint is None:
raise ValueError("endpoint not set for GenericAPILogger")
if headers is None:
raise ValueError("headers not set for GenericAPILogger")
self.endpoint = endpoint
self.headers = headers
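
The pattern in this hunk recurs throughout the PR: accept Optional arguments, fall back to env/config values, then raise before assignment so pyright can treat the attributes as non-Optional. A condensed sketch, with the class name and the fallback to litellm.generic_logger_headers simplified:

import os
from typing import Optional

class _SketchLogger:
    """Condensed illustration of the Optional-narrowing above (not the real GenericAPILogger)."""

    def __init__(self, endpoint: Optional[str] = None, headers: Optional[dict] = None):
        endpoint = endpoint or os.getenv("GENERIC_LOGGER_ENDPOINT")
        headers = headers or {}
        if endpoint is None:
            # raising here lets the type checker narrow self.endpoint to `str`, not `Optional[str]`
            raise ValueError("endpoint not set for GenericAPILogger")
        self.endpoint: str = endpoint
        self.headers: dict = headers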

View file

@ -48,8 +48,6 @@ class AporiaGuardrail(CustomGuardrail):
)
self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"]
self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"]
self.event_hook: GuardrailEventHooks
super().__init__(**kwargs)
#### CALL HOOKS - proxy only ####

View file

@ -84,7 +84,7 @@ class _ENTERPRISE_BlockedUserList(CustomLogger):
)
cache_key = f"litellm:end_user_id:{user}"
end_user_cache_obj: LiteLLM_EndUserTable = cache.get_cache(
end_user_cache_obj: Optional[LiteLLM_EndUserTable] = cache.get_cache( # type: ignore
key=cache_key
)
if end_user_cache_obj is None and self.prisma_client is not None:

View file

@ -48,7 +48,7 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
# Class variables or attributes
def __init__(self):
try:
from google.cloud import language_v1
from google.cloud import language_v1 # type: ignore
except Exception:
raise Exception(
"Missing google.cloud package. Run `pip install --upgrade google-cloud-language`"
@ -57,8 +57,8 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
# Instantiates a client
self.client = language_v1.LanguageServiceClient()
self.moderate_text_request = language_v1.ModerateTextRequest
self.language_document = language_v1.types.Document
self.document_type = language_v1.types.Document.Type.PLAIN_TEXT
self.language_document = language_v1.types.Document # type: ignore
self.document_type = language_v1.types.Document.Type.PLAIN_TEXT # type: ignore
default_confidence_threshold = (
litellm.google_moderation_confidence_threshold or 0.8

View file

@ -8,6 +8,7 @@
# Thank you users! We ❤️ you! - Krrish & Ishaan
import sys, os
from collections.abc import Iterable
sys.path.insert(
0, os.path.abspath("../..")
@ -19,11 +20,12 @@ from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
from litellm.utils import (
from litellm.types.utils import (
ModelResponse,
EmbeddingResponse,
ImageResponse,
StreamingChoices,
Choices,
)
from datetime import datetime
import aiohttp, asyncio
@ -34,7 +36,10 @@ litellm.set_verbose = True
class _ENTERPRISE_LlamaGuard(CustomLogger):
# Class variables or attributes
def __init__(self, model_name: Optional[str] = None):
self.model = model_name or litellm.llamaguard_model_name
_model = model_name or litellm.llamaguard_model_name
if _model is None:
raise ValueError("model_name not set for LlamaGuard")
self.model = _model
file_path = litellm.llamaguard_unsafe_content_categories
data = None
@ -124,7 +129,13 @@ class _ENTERPRISE_LlamaGuard(CustomLogger):
hf_model_name="meta-llama/LlamaGuard-7b",
)
if "unsafe" in response.choices[0].message.content:
if (
isinstance(response, ModelResponse)
and isinstance(response.choices[0], Choices)
and response.choices[0].message.content is not None
and isinstance(response.choices[0].message.content, Iterable)
and "unsafe" in response.choices[0].message.content
):
raise HTTPException(
status_code=400, detail={"error": "Violated content safety policy"}
)
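
litellm.completion can return a streaming wrapper as well as a ModelResponse, so the content check above is wrapped in isinstance guards before indexing into choices. A compressed sketch of the same narrowing (the Iterable check is simplified to a plain str check here):

from litellm.types.utils import Choices, ModelResponse

def _is_unsafe(response) -> bool:
    """Sketch of the isinstance chain above: only index into choices/message once the
    checker knows this is a non-streaming ModelResponse with string content."""
    return (
        isinstance(response, ModelResponse)
        and isinstance(response.choices[0], Choices)
        and isinstance(response.choices[0].message.content, str)
        and "unsafe" in response.choices[0].message.content
    )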

View file

@ -8,7 +8,11 @@
## This provides an LLM Guard Integration for content moderation on the proxy
from typing import Optional, Literal, Union
import litellm, traceback, sys, uuid, os
import litellm
import traceback
import sys
import uuid
import os
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
@ -21,8 +25,10 @@ from litellm.utils import (
StreamingChoices,
)
from datetime import datetime
import aiohttp, asyncio
import aiohttp
import asyncio
from litellm.utils import get_formatted_prompt
from litellm.secret_managers.main import get_secret_str
litellm.set_verbose = True
@ -38,7 +44,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
self.llm_guard_mode = litellm.llm_guard_mode
if mock_testing == True: # for testing purposes only
return
self.llm_guard_api_base = litellm.get_secret("LLM_GUARD_API_BASE", None)
self.llm_guard_api_base = get_secret_str("LLM_GUARD_API_BASE", None)
if self.llm_guard_api_base is None:
raise Exception("Missing `LLM_GUARD_API_BASE` from environment")
elif not self.llm_guard_api_base.endswith("/"):

View file

@ -51,8 +51,8 @@ class _ENTERPRISE_OpenAI_Moderation(CustomLogger):
"audio_transcription",
],
):
text = ""
if "messages" in data and isinstance(data["messages"], list):
text = ""
for m in data["messages"]: # assume messages is a list
if "content" in m and isinstance(m["content"], str):
text += m["content"]
@ -67,7 +67,7 @@ class _ENTERPRISE_OpenAI_Moderation(CustomLogger):
)
verbose_proxy_logger.debug("Moderation response: %s", moderation_response)
if moderation_response.results[0].flagged == True:
if moderation_response.results[0].flagged is True:
raise HTTPException(
status_code=403, detail={"error": "Violated content safety policy"}
)

View file

@ -6,7 +6,9 @@ import collections
from datetime import datetime
async def get_spend_by_tags(start_date=None, end_date=None, prisma_client=None):
async def get_spend_by_tags(
prisma_client: PrismaClient, start_date=None, end_date=None
):
response = await prisma_client.db.query_raw(
"""
SELECT

View file

@ -191,7 +191,7 @@ def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
new_startup_nodes.append(ClusterNode(**item))
redis_kwargs.pop("startup_nodes")
return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs)
return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs) # type: ignore
def _init_redis_sentinel(redis_kwargs) -> redis.Redis:

View file

@ -1,5 +1,8 @@
import litellm
from typing import Optional, Union
import litellm
from ..exceptions import UnsupportedParamsError
from ..types.llms.openai import *

View file

@ -10,6 +10,7 @@
import ast
import asyncio
import hashlib
import inspect
import io
import json
import logging
@ -245,7 +246,8 @@ class RedisCache(BaseCache):
self.redis_flush_size = redis_flush_size
self.redis_version = "Unknown"
try:
self.redis_version = self.redis_client.info()["redis_version"]
if not inspect.iscoroutinefunction(self.redis_client):
self.redis_version = self.redis_client.info()["redis_version"] # type: ignore
except Exception:
pass
@ -266,7 +268,8 @@ class RedisCache(BaseCache):
### SYNC HEALTH PING ###
try:
self.redis_client.ping()
if hasattr(self.redis_client, "ping"):
self.redis_client.ping() # type: ignore
except Exception as e:
verbose_logger.error(
"Error connecting to Sync Redis client", extra={"error": str(e)}
@ -308,7 +311,7 @@ class RedisCache(BaseCache):
_redis_client = self.redis_client
start_time = time.time()
try:
result = _redis_client.incr(name=key, amount=value)
result: int = _redis_client.incr(name=key, amount=value) # type: ignore
if ttl is not None:
# check if key already has ttl, if not -> set ttl
@ -561,7 +564,7 @@ class RedisCache(BaseCache):
f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
)
try:
await redis_client.sadd(key, *value)
await redis_client.sadd(key, *value) # type: ignore
if ttl is not None:
_td = timedelta(seconds=ttl)
await redis_client.expire(key, _td)
@ -712,7 +715,7 @@ class RedisCache(BaseCache):
for cache_key in key_list:
cache_key = self.check_and_fix_namespace(key=cache_key)
_keys.append(cache_key)
results = self.redis_client.mget(keys=_keys)
results: List = self.redis_client.mget(keys=_keys) # type: ignore
# Associate the results back with their keys.
# 'results' is a list of values corresponding to the order of keys in 'key_list'.
@ -842,7 +845,7 @@ class RedisCache(BaseCache):
print_verbose("Pinging Sync Redis Cache")
start_time = time.time()
try:
response = self.redis_client.ping()
response: bool = self.redis_client.ping() # type: ignore
print_verbose(f"Redis Cache PING: {response}")
## LOGGING ##
end_time = time.time()
@ -911,8 +914,8 @@ class RedisCache(BaseCache):
async with _redis_client as redis_client:
await redis_client.delete(*keys)
def client_list(self):
client_list = self.redis_client.client_list()
def client_list(self) -> List:
client_list: List = self.redis_client.client_list() # type: ignore
return client_list
def info(self):

View file

@ -39,8 +39,8 @@ from litellm.llms.fireworks_ai.cost_calculator import (
)
from litellm.llms.OpenAI.cost_calculation import cost_per_token as openai_cost_per_token
from litellm.llms.together_ai.cost_calculator import get_model_params_and_category
from litellm.rerank_api.types import RerankResponse
from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.rerank import RerankResponse
from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
from litellm.types.utils import PassthroughCallTypes, Usage
from litellm.utils import (

View file

@ -39,11 +39,11 @@ from litellm.proxy._types import (
VirtualKeyEvent,
WebhookEvent,
)
from litellm.types.integrations.slack_alerting import *
from litellm.types.router import LiteLLM_Params
from ..email_templates.templates import *
from .batching_handler import send_to_webhook, squash_payloads
from .types import *
from .utils import _add_langfuse_trace_id_to_alert, process_slack_alerting_variables

View file

@ -172,14 +172,14 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
"moderation",
"audio_transcription",
],
):
) -> Any:
pass
async def async_post_call_streaming_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
response: str,
):
) -> Any:
pass
#### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function

View file

@ -42,8 +42,10 @@ async def get_all_team_member_emails(team_id: Optional[str] = None) -> list:
)
_team_member_user_ids: List[str] = []
for member in _team_members:
if member and isinstance(member, dict) and member.get("user_id") is not None:
_team_member_user_ids.append(member.get("user_id"))
if member and isinstance(member, dict):
_user_id = member.get("user_id")
if _user_id and isinstance(_user_id, str):
_team_member_user_ids.append(_user_id)
sql_query = """
SELECT user_email

View file

@ -149,7 +149,7 @@ class LunaryLogger:
else:
error_obj = None
self.lunary_client.track_event(
self.lunary_client.track_event( # type: ignore
type,
"start",
run_id,
@ -164,7 +164,7 @@ class LunaryLogger:
params=extra,
)
self.lunary_client.track_event(
self.lunary_client.track_event( # type: ignore
type,
event,
run_id,

View file

@ -100,16 +100,14 @@ class OpenMeterLogger(CustomLogger):
}
try:
response = self.sync_http_handler.post(
self.sync_http_handler.post(
url=_url,
data=json.dumps(_data),
headers=_headers,
)
response.raise_for_status()
except httpx.HTTPStatusError as e:
raise Exception(f"OpenMeter logging error: {e.response.text}")
except Exception as e:
if hasattr(response, "text"):
litellm.print_verbose(f"\nError Message: {response.text}")
raise e
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
@ -128,18 +126,12 @@ class OpenMeterLogger(CustomLogger):
}
try:
response = await self.async_http_handler.post(
await self.async_http_handler.post(
url=_url,
data=json.dumps(_data),
headers=_headers,
)
response.raise_for_status()
except httpx.HTTPStatusError as e:
verbose_logger.error(
"Failed OpenMeter logging - {}".format(e.response.text)
)
raise e
raise Exception(f"OpenMeter logging error: {e.response.text}")
except Exception as e:
verbose_logger.error("Failed OpenMeter logging - {}".format(str(e)))
raise e
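
Before this change, the except block read `response.text` even though `response` may never have been bound (the post itself can raise), which is the kind of error pyright flags. A stand-alone sketch of the corrected shape, using httpx directly rather than litellm's HTTPHandler:

import json
import httpx

def _post_event(url: str, data: dict, headers: dict) -> None:
    """Sketch of the error handling above: raise_for_status() inside the try, and the
    response body surfaced from the HTTPStatusError rather than an unbound variable."""
    try:
        response = httpx.post(url, content=json.dumps(data), headers=headers)
        response.raise_for_status()
    except httpx.HTTPStatusError as e:
        raise Exception(f"OpenMeter logging error: {e.response.text}")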

View file

@ -146,14 +146,14 @@ class S3Logger:
import json
payload = json.dumps(payload)
payload_str = json.dumps(payload)
print_verbose(f"\ns3 Logger - Logging payload = {payload}")
print_verbose(f"\ns3 Logger - Logging payload = {payload_str}")
response = self.s3_client.put_object(
Bucket=self.bucket_name,
Key=s3_object_key,
Body=payload,
Body=payload_str,
ContentType="application/json",
ContentLanguage="en",
ContentDisposition=f'inline; filename="{s3_object_download_filename}"',

View file

@ -1,6 +1,7 @@
import traceback
from litellm._logging import verbose_logger
import litellm
from litellm._logging import verbose_logger
class TraceloopLogger:
@ -11,14 +12,15 @@ class TraceloopLogger:
def __init__(self):
try:
from traceloop.sdk.tracing.tracing import TracerWrapper
from opentelemetry.sdk.trace.export import ConsoleSpanExporter
from traceloop.sdk import Traceloop
from traceloop.sdk.instruments import Instruments
from opentelemetry.sdk.trace.export import ConsoleSpanExporter
from traceloop.sdk.tracing.tracing import TracerWrapper
except ModuleNotFoundError as e:
verbose_logger.error(
f"Traceloop not installed, try running 'pip install traceloop-sdk' to fix this error: {e}\n{traceback.format_exc()}"
)
raise e
Traceloop.init(
app_name="Litellm-Server",
@ -38,8 +40,8 @@ class TraceloopLogger:
status_message=None,
):
from opentelemetry import trace
from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.semconv.ai import SpanAttributes
from opentelemetry.trace import SpanKind, Status, StatusCode
try:
print_verbose(
@ -94,7 +96,7 @@ class TraceloopLogger:
)
if "temperature" in optional_params:
span.set_attribute(
SpanAttributes.LLM_REQUEST_TEMPERATURE,
SpanAttributes.LLM_REQUEST_TEMPERATURE, # type: ignore
kwargs.get("temperature"),
)

View file

@ -32,8 +32,8 @@ from litellm.litellm_core_utils.redact_messages import (
redact_message_input_output_from_logging,
)
from litellm.proxy._types import CommonProxyErrors
from litellm.rerank_api.types import RerankResponse
from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.rerank import RerankResponse
from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
from litellm.types.utils import (
CallTypes,

View file

@ -1,5 +1,5 @@
import uuid
from typing import Optional, Union
from typing import Any, Optional, Union
import httpx
from openai import AsyncAzureOpenAI, AzureOpenAI
@ -24,6 +24,7 @@ class AzureAudioTranscription(AzureChatCompletion):
model: str,
audio_file: FileTypes,
optional_params: dict,
logging_obj: Any,
model_response: TranscriptionResponse,
timeout: float,
max_retries: int,
@ -32,9 +33,8 @@ class AzureAudioTranscription(AzureChatCompletion):
api_version: Optional[str] = None,
client=None,
azure_ad_token: Optional[str] = None,
logging_obj=None,
atranscription: bool = False,
):
) -> TranscriptionResponse:
data = {"model": model, "file": audio_file, **optional_params}
# init AzureOpenAI Client
@ -59,7 +59,7 @@ class AzureAudioTranscription(AzureChatCompletion):
azure_client_params["max_retries"] = max_retries
if atranscription is True:
return self.async_audio_transcriptions(
return self.async_audio_transcriptions( # type: ignore
audio_file=audio_file,
data=data,
model_response=model_response,
@ -105,7 +105,7 @@ class AzureAudioTranscription(AzureChatCompletion):
original_response=stringified_response,
)
hidden_params = {"model": "whisper-1", "custom_llm_provider": "azure"}
final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore
final_response: TranscriptionResponse = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore
return final_response
async def async_audio_transcriptions(
@ -114,12 +114,12 @@ class AzureAudioTranscription(AzureChatCompletion):
data: dict,
model_response: TranscriptionResponse,
timeout: float,
azure_client_params: dict,
logging_obj: Any,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client=None,
azure_client_params=None,
max_retries=None,
logging_obj=None,
):
response = None
try:

View file

@ -1083,7 +1083,7 @@ class AzureChatCompletion(BaseLLM):
azure_ad_token: Optional[str] = None,
client=None,
aembedding=None,
):
) -> litellm.EmbeddingResponse:
super().embedding()
if self._client_session is None:
self._client_session = self.create_client_session()
@ -1128,7 +1128,7 @@ class AzureChatCompletion(BaseLLM):
)
if aembedding is True:
response = self.aembedding(
return self.aembedding( # type: ignore
data=data,
input=input,
logging_obj=logging_obj,
@ -1138,7 +1138,6 @@ class AzureChatCompletion(BaseLLM):
timeout=timeout,
client=client,
)
return response
if client is None:
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
else:
@ -1418,7 +1417,7 @@ class AzureChatCompletion(BaseLLM):
logging_obj: LiteLLMLoggingObj,
client=None,
timeout=None,
):
) -> litellm.ImageResponse:
response: Optional[dict] = None
try:
# response = await azure_client.images.generate(**data, timeout=timeout)
@ -1460,7 +1459,7 @@ class AzureChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
return convert_to_model_response_object(
return convert_to_model_response_object( # type: ignore
response_object=stringified_response,
model_response_object=model_response,
response_type="image_generation",
@ -1489,7 +1488,7 @@ class AzureChatCompletion(BaseLLM):
azure_ad_token: Optional[str] = None,
client=None,
aimg_generation=None,
):
) -> litellm.ImageResponse:
try:
if model and len(model) > 0:
model = model
@ -1531,8 +1530,7 @@ class AzureChatCompletion(BaseLLM):
azure_client_params["azure_ad_token"] = azure_ad_token
if aimg_generation is True:
response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout) # type: ignore
return response
return self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout) # type: ignore
img_gen_api_base = self.create_azure_base_url(
azure_client_params=azure_client_params, model=data.get("model", "")
@ -1742,9 +1740,9 @@ class AzureChatCompletion(BaseLLM):
async def ahealth_check(
self,
model: Optional[str],
api_key: str,
api_key: Optional[str],
api_base: str,
api_version: str,
api_version: Optional[str],
timeout: float,
mode: str,
messages: Optional[list] = None,

View file

@ -77,15 +77,15 @@ class OpenAIAudioTranscription(OpenAIChatCompletion):
model_response: TranscriptionResponse,
timeout: float,
max_retries: int,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str],
api_base: Optional[str],
client=None,
logging_obj=None,
atranscription: bool = False,
):
) -> TranscriptionResponse:
data = {"model": model, "file": audio_file, **optional_params}
if atranscription is True:
return self.async_audio_transcriptions(
return self.async_audio_transcriptions( # type: ignore
audio_file=audio_file,
data=data,
model_response=model_response,
@ -97,7 +97,7 @@ class OpenAIAudioTranscription(OpenAIChatCompletion):
logging_obj=logging_obj,
)
openai_client = self._get_openai_client(
openai_client: OpenAI = self._get_openai_client( # type: ignore
is_async=False,
api_key=api_key,
api_base=api_base,
@ -123,7 +123,7 @@ class OpenAIAudioTranscription(OpenAIChatCompletion):
original_response=stringified_response,
)
hidden_params = {"model": "whisper-1", "custom_llm_provider": "openai"}
final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore
final_response: TranscriptionResponse = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore
return final_response
async def async_audio_transcriptions(
@ -139,7 +139,7 @@ class OpenAIAudioTranscription(OpenAIChatCompletion):
max_retries=None,
):
try:
openai_aclient = self._get_openai_client(
openai_aclient: AsyncOpenAI = self._get_openai_client( # type: ignore
is_async=True,
api_key=api_key,
api_base=api_base,

View file

@ -1167,7 +1167,7 @@ class OpenAIChatCompletion(BaseLLM):
api_base: Optional[str] = None,
client=None,
aembedding=None,
):
) -> litellm.EmbeddingResponse:
super().embedding()
try:
model = model
@ -1183,7 +1183,7 @@ class OpenAIChatCompletion(BaseLLM):
)
if aembedding is True:
async_response = self.aembedding(
return self.aembedding( # type: ignore
data=data,
input=input,
logging_obj=logging_obj,
@ -1194,7 +1194,6 @@ class OpenAIChatCompletion(BaseLLM):
client=client,
max_retries=max_retries,
)
return async_response
openai_client: OpenAI = self._get_openai_client( # type: ignore
is_async=False,
@ -1294,7 +1293,7 @@ class OpenAIChatCompletion(BaseLLM):
model_response: Optional[litellm.utils.ImageResponse] = None,
client=None,
aimg_generation=None,
):
) -> litellm.ImageResponse:
data = {}
try:
model = model
@ -1304,8 +1303,7 @@ class OpenAIChatCompletion(BaseLLM):
raise OpenAIError(status_code=422, message="max retries must be an int")
if aimg_generation is True:
response = self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
return response
return self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
openai_client = self._get_openai_client(
is_async=False,
@ -1449,7 +1447,7 @@ class OpenAIChatCompletion(BaseLLM):
async def ahealth_check(
self,
model: Optional[str],
api_key: str,
api_key: Optional[str],
timeout: float,
mode: str,
messages: Optional[list] = None,

View file

@ -282,7 +282,6 @@ class AnthropicChatCompletion(BaseLLM):
prompt_tokens = completion_response["usage"]["input_tokens"]
completion_tokens = completion_response["usage"]["output_tokens"]
_usage = completion_response["usage"]
total_tokens = prompt_tokens + completion_tokens
cache_creation_input_tokens: int = 0
cache_read_input_tokens: int = 0
@ -290,12 +289,15 @@ class AnthropicChatCompletion(BaseLLM):
model_response.model = model
if "cache_creation_input_tokens" in _usage:
cache_creation_input_tokens = _usage["cache_creation_input_tokens"]
prompt_tokens += cache_creation_input_tokens
if "cache_read_input_tokens" in _usage:
cache_read_input_tokens = _usage["cache_read_input_tokens"]
prompt_tokens += cache_read_input_tokens
prompt_tokens_details = PromptTokensDetails(
cached_tokens=cache_read_input_tokens
)
total_tokens = prompt_tokens + completion_tokens
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,

View file

@ -24,16 +24,33 @@ def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
model_info = get_model_info(model=model, custom_llm_provider="anthropic")
## CALCULATE INPUT COST
### Cost of processing (non-cache hit + cache hit) + Cost of cache-writing (cache writing)
prompt_cost = 0.0
### PROCESSING COST
non_cache_hit_tokens = usage.prompt_tokens
cache_hit_tokens = 0
if usage.prompt_tokens_details and usage.prompt_tokens_details.cached_tokens:
cache_hit_tokens = usage.prompt_tokens_details.cached_tokens
non_cache_hit_tokens = non_cache_hit_tokens - cache_hit_tokens
prompt_cost: float = usage["prompt_tokens"] * model_info["input_cost_per_token"]
if model_info.get("cache_creation_input_token_cost") is not None:
prompt_cost = float(non_cache_hit_tokens) * model_info["input_cost_per_token"]
_cache_read_input_token_cost = model_info.get("cache_read_input_token_cost")
if (
_cache_read_input_token_cost is not None
and usage.prompt_tokens_details
and usage.prompt_tokens_details.cached_tokens
):
prompt_cost += (
usage._cache_creation_input_tokens # type: ignore
* model_info["cache_creation_input_token_cost"]
float(usage.prompt_tokens_details.cached_tokens)
* _cache_read_input_token_cost
)
if model_info.get("cache_read_input_token_cost") is not None:
### CACHE WRITING COST
_cache_creation_input_token_cost = model_info.get("cache_creation_input_token_cost")
if _cache_creation_input_token_cost is not None:
prompt_cost += (
usage._cache_read_input_tokens * model_info["cache_read_input_token_cost"] # type: ignore
float(usage._cache_creation_input_tokens) * _cache_creation_input_token_cost
)
## CALCULATE OUTPUT COST
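
A hedged worked example of the input-cost split introduced here; all prices and token counts are made up, and the model_info lookups are collapsed into plain variables:

# Illustrative numbers only; real prices come from get_model_info(model=..., custom_llm_provider="anthropic")
input_cost_per_token = 3e-06
cache_read_input_token_cost = 3e-07
cache_creation_input_token_cost = 3.75e-06

prompt_tokens = 180         # usage.prompt_tokens
cached_tokens = 50          # usage.prompt_tokens_details.cached_tokens (cache hits)
cache_creation_tokens = 30  # usage._cache_creation_input_tokens (cache writes)

non_cache_hit_tokens = prompt_tokens - cached_tokens                      # 130
prompt_cost = non_cache_hit_tokens * input_cost_per_token                 # processing cost
prompt_cost += cached_tokens * cache_read_input_token_cost                # cache-read cost
prompt_cost += cache_creation_tokens * cache_creation_input_token_cost    # cache-write cost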

View file

@ -216,7 +216,7 @@ class AzureAIEmbedding(OpenAIChatCompletion):
api_base: Optional[str] = None,
client=None,
aembedding=None,
):
) -> litellm.EmbeddingResponse:
"""
- Separate image url from text
-> route image url call to `/image/embeddings`
@ -225,7 +225,7 @@ class AzureAIEmbedding(OpenAIChatCompletion):
assemble result in-order, and return
"""
if aembedding is True:
return self.async_embedding(
return self.async_embedding( # type: ignore
model,
input,
timeout,

View file

@ -4,7 +4,7 @@ import httpx
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.cohere.rerank import CohereRerank
from litellm.rerank_api.types import RerankResponse
from litellm.types.rerank import RerankResponse
class AzureAIRerank(CohereRerank):

View file

@ -5,7 +5,7 @@ Handles image gen calls to Bedrock's `/invoke` endpoint
import copy
import json
import os
from typing import List
from typing import Any, List
from openai.types.image import Image
@ -20,8 +20,8 @@ def image_generation(
prompt: str,
model_response: ImageResponse,
optional_params: dict,
logging_obj: Any,
timeout=None,
logging_obj=None,
aimg_generation=False,
):
"""

View file

@ -13,7 +13,7 @@ from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
from litellm.rerank_api.types import RerankRequest, RerankResponse
from litellm.types.rerank import RerankRequest, RerankResponse
class CohereRerank(BaseLLM):

View file

@ -40,7 +40,7 @@ class AzureOpenAIFineTuningAPI(BaseLLM):
organization: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
api_version: Optional[str] = None,
) -> Union[FineTuningJob, Union[Coroutine[Any, Any, FineTuningJob]]]:
) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
get_azure_openai_client(
api_key=api_key,

View file

@ -68,7 +68,7 @@ class OpenAIFineTuningAPI(BaseLLM):
max_retries: Optional[int],
organization: Optional[str],
client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
) -> Union[FineTuningJob, Union[Coroutine[Any, Any, FineTuningJob]]]:
) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
api_key=api_key,
api_base=api_base,

View file

@ -554,7 +554,7 @@ class Huggingface(BaseLLM):
else:
prompt = prompt_factory(model=model, messages=messages)
data = {
"inputs": prompt,
"inputs": prompt, # type: ignore
"parameters": optional_params,
"stream": ( # type: ignore
True
@ -589,7 +589,7 @@ class Huggingface(BaseLLM):
inference_params.pop("details")
inference_params.pop("return_full_text")
data = {
"inputs": prompt,
"inputs": prompt, # type: ignore
}
if task == "text-generation-inference":
data["parameters"] = inference_params

View file

@ -4,7 +4,7 @@ import time
import traceback
import types
from enum import Enum
from typing import Callable, List, Optional
from typing import Any, Callable, List, Optional
import requests # type: ignore
@ -185,8 +185,8 @@ def completion(
def embedding(
model: str,
input: list,
api_key: Optional[str] = None,
logging_obj=None,
api_key: Optional[str],
logging_obj: Any,
model_response=None,
encoding=None,
):

View file

@ -2,7 +2,7 @@ import json
import os
import time
from enum import Enum
from typing import Callable, Optional
from typing import Any, Callable, Optional
import requests # type: ignore
@ -124,9 +124,9 @@ def embedding(
model: str,
input: list,
model_response: EmbeddingResponse,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
logging_obj=None,
api_key: Optional[str],
api_base: Optional[str],
logging_obj: Any,
optional_params=None,
encoding=None,
):

View file

@ -2714,7 +2714,7 @@ def custom_prompt(
final_prompt_value: str = "",
bos_token: str = "",
eos_token: str = "",
):
) -> str:
prompt = bos_token + initial_prompt_value
bos_open = True
## a bos token is at the start of a system / human message

View file

@ -15,7 +15,7 @@ from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
from litellm.rerank_api.types import RerankRequest, RerankResponse
from litellm.types.rerank import RerankRequest, RerankResponse
class TogetherAIRerank(BaseLLM):

View file

@ -47,7 +47,7 @@ class TritonChatCompletion(BaseLLM):
data: dict,
model_response: litellm.utils.EmbeddingResponse,
api_base: str,
logging_obj=None,
logging_obj: Any,
api_key: Optional[str] = None,
) -> EmbeddingResponse:
async_handler = AsyncHTTPHandler(
@ -93,9 +93,9 @@ class TritonChatCompletion(BaseLLM):
timeout: float,
api_base: str,
model_response: litellm.utils.EmbeddingResponse,
logging_obj: Any,
optional_params: dict,
api_key: Optional[str] = None,
logging_obj=None,
optional_params=None,
client=None,
aembedding: bool = False,
) -> EmbeddingResponse:
@ -122,7 +122,7 @@ class TritonChatCompletion(BaseLLM):
)
if aembedding:
response = await self.aembedding(
response = await self.aembedding( # type: ignore
data=data_for_triton,
model_response=model_response,
logging_obj=logging_obj,
@ -141,10 +141,10 @@ class TritonChatCompletion(BaseLLM):
messages: List[dict],
timeout: float,
api_base: str,
logging_obj: Any,
optional_params: dict,
model_response: ModelResponse,
api_key: Optional[str] = None,
logging_obj=None,
optional_params=None,
client=None,
stream: Optional[bool] = False,
acompletion: bool = False,
@ -239,11 +239,13 @@ class TritonChatCompletion(BaseLLM):
else:
handler = HTTPHandler()
if stream:
return self._handle_stream(
return self._handle_stream( # type: ignore
handler, api_base, json_data_for_triton, model, logging_obj
)
else:
response = handler.post(url=api_base, data=json_data_for_triton, headers=headers)
response = handler.post(
url=api_base, data=json_data_for_triton, headers=headers
)
return self._handle_response(
response, model_response, logging_obj, type_of_model=type_of_model
)
@ -261,7 +263,7 @@ class TritonChatCompletion(BaseLLM):
) -> ModelResponse:
handler = AsyncHTTPHandler()
if stream:
return self._ahandle_stream(
return self._ahandle_stream( # type: ignore
handler, api_base, data_for_triton, model, logging_obj
)
else:

View file

@ -2,7 +2,7 @@ from typing import List, Literal, Tuple
import httpx
from litellm import supports_system_messages, supports_response_schema, verbose_logger
from litellm import supports_response_schema, supports_system_messages, verbose_logger
from litellm.types.llms.vertex_ai import PartType
@ -67,6 +67,8 @@ def _get_vertex_url(
vertex_location: Optional[str],
vertex_api_version: Literal["v1", "v1beta1"],
) -> Tuple[str, str]:
url: Optional[str] = None
endpoint: Optional[str] = None
if mode == "chat":
### SET RUNTIME ENDPOINT ###
endpoint = "generateContent"
@ -88,6 +90,8 @@ def _get_vertex_url(
endpoint = "predict"
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
if not url or not endpoint:
raise ValueError(f"Unable to get vertex url/endpoinit for mode: {mode}")
return url, endpoint

View file

@ -3,7 +3,7 @@ Google AI Studio /batchEmbedContents Embeddings Endpoint
"""
import json
from typing import List, Literal, Optional, Union
from typing import Any, List, Literal, Optional, Union
import httpx
@ -31,9 +31,9 @@ class GoogleBatchEmbeddings(VertexLLM):
model_response: EmbeddingResponse,
custom_llm_provider: Literal["gemini", "vertex_ai"],
optional_params: dict,
logging_obj: Any,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
logging_obj=None,
encoding=None,
vertex_project=None,
vertex_location=None,

View file

@ -43,17 +43,17 @@ class VertexImageGeneration(VertexLLM):
vertex_location: Optional[str],
vertex_credentials: Optional[str],
model_response: litellm.ImageResponse,
logging_obj: Any,
model: Optional[
str
] = "imagegeneration", # vertex ai uses imagegeneration as the default model
client: Optional[Any] = None,
optional_params: Optional[dict] = None,
timeout: Optional[int] = None,
logging_obj=None,
aimg_generation=False,
):
) -> litellm.ImageResponse:
if aimg_generation is True:
return self.aimage_generation(
return self.aimage_generation( # type: ignore
prompt=prompt,
vertex_project=vertex_project,
vertex_location=vertex_location,
@ -138,13 +138,13 @@ class VertexImageGeneration(VertexLLM):
vertex_location: Optional[str],
vertex_credentials: Optional[str],
model_response: litellm.ImageResponse,
logging_obj: Any,
model: Optional[
str
] = "imagegeneration", # vertex ai uses imagegeneration as the default model
client: Optional[AsyncHTTPHandler] = None,
optional_params: Optional[dict] = None,
timeout: Optional[int] = None,
logging_obj=None,
):
response = None
if client is None:

View file

@ -48,7 +48,7 @@ class VertexMultimodalEmbedding(VertexLLM):
aembedding=False,
timeout=300,
client=None,
):
) -> litellm.EmbeddingResponse:
_auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials,
@ -121,7 +121,7 @@ class VertexMultimodalEmbedding(VertexLLM):
)
if aembedding is True:
return self.async_multimodal_embedding(
return self.async_multimodal_embedding( # type: ignore
model=model,
api_base=url,
data=request_data,

View file

@ -62,7 +62,7 @@ class VertexTextToSpeechAPI(VertexLLM):
_is_async: Optional[bool] = False,
optional_params: Optional[dict] = None,
kwargs: Optional[dict] = None,
):
) -> HttpxBinaryResponseContent:
import base64
####### Authenticate with Vertex AI ########
@ -145,7 +145,7 @@ class VertexTextToSpeechAPI(VertexLLM):
########## End of logging ############
####### Send the request ###################
if _is_async is True:
return self.async_audio_speech(
return self.async_audio_speech( # type:ignore
logging_obj=logging_obj, url=url, headers=headers, request=request
)
sync_handler = _get_httpx_client()

View file

@ -52,9 +52,9 @@ class VertexEmbedding(VertexBase):
vertex_credentials: Optional[str] = None,
gemini_api_key: Optional[str] = None,
extra_headers: Optional[dict] = None,
):
) -> litellm.EmbeddingResponse:
if aembedding is True:
return self.async_embedding(
return self.async_embedding( # type: ignore
model=model,
input=input,
logging_obj=logging_obj,

View file

@ -25,13 +25,8 @@ import requests # type: ignore
import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.utils import (
EmbeddingResponse,
ModelResponse,
Usage,
get_secret,
map_finish_reason,
)
from litellm.secret_managers.main import get_secret_str
from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason
from .base import BaseLLM
from .prompt_templates import factory as ptf
@ -184,7 +179,7 @@ class IBMWatsonXAIConfig:
]
def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict) -> str:
# handle anthropic prompts and amazon titan prompts
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
@ -200,14 +195,10 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
eos_token=model_prompt_dict.get("eos_token", ""),
)
return prompt
elif provider == "ibm":
prompt = ptf.prompt_factory(
model=model, messages=messages, custom_llm_provider="watsonx"
)
elif provider == "ibm-mistralai":
prompt = ptf.mistral_instruct_pt(messages=messages)
else:
prompt = ptf.prompt_factory(
prompt: str = ptf.prompt_factory( # type: ignore
model=model, messages=messages, custom_llm_provider="watsonx"
)
return prompt
@ -327,37 +318,37 @@ class IBMWatsonXAI(BaseLLM):
# Load auth variables from environment variables
if url is None:
url = (
get_secret("WATSONX_API_BASE") # consistent with 'AZURE_API_BASE'
or get_secret("WATSONX_URL")
or get_secret("WX_URL")
or get_secret("WML_URL")
get_secret_str("WATSONX_API_BASE") # consistent with 'AZURE_API_BASE'
or get_secret_str("WATSONX_URL")
or get_secret_str("WX_URL")
or get_secret_str("WML_URL")
)
if api_key is None:
api_key = (
get_secret("WATSONX_APIKEY")
or get_secret("WATSONX_API_KEY")
or get_secret("WX_API_KEY")
get_secret_str("WATSONX_APIKEY")
or get_secret_str("WATSONX_API_KEY")
or get_secret_str("WX_API_KEY")
)
if token is None:
token = get_secret("WATSONX_TOKEN") or get_secret("WX_TOKEN")
token = get_secret_str("WATSONX_TOKEN") or get_secret_str("WX_TOKEN")
if project_id is None:
project_id = (
get_secret("WATSONX_PROJECT_ID")
or get_secret("WX_PROJECT_ID")
or get_secret("PROJECT_ID")
get_secret_str("WATSONX_PROJECT_ID")
or get_secret_str("WX_PROJECT_ID")
or get_secret_str("PROJECT_ID")
)
if region_name is None:
region_name = (
get_secret("WATSONX_REGION")
or get_secret("WX_REGION")
or get_secret("REGION")
get_secret_str("WATSONX_REGION")
or get_secret_str("WX_REGION")
or get_secret_str("REGION")
)
if space_id is None:
space_id = (
get_secret("WATSONX_DEPLOYMENT_SPACE_ID")
or get_secret("WATSONX_SPACE_ID")
or get_secret("WX_SPACE_ID")
or get_secret("SPACE_ID")
get_secret_str("WATSONX_DEPLOYMENT_SPACE_ID")
or get_secret_str("WATSONX_SPACE_ID")
or get_secret_str("WX_SPACE_ID")
or get_secret_str("SPACE_ID")
)
# credentials parsing
@ -446,8 +437,8 @@ class IBMWatsonXAI(BaseLLM):
model_response: ModelResponse,
print_verbose: Callable,
encoding,
logging_obj,
optional_params=None,
logging_obj: Any,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
@ -592,13 +583,13 @@ class IBMWatsonXAI(BaseLLM):
model: str,
input: Union[list, str],
model_response: litellm.EmbeddingResponse,
api_key: Optional[str] = None,
logging_obj=None,
optional_params=None,
api_key: Optional[str],
logging_obj: Any,
optional_params: dict,
encoding=None,
print_verbose=None,
aembedding=None,
):
) -> litellm.EmbeddingResponse:
"""
Send a text embedding request to the IBM Watsonx.ai API.
"""
@ -657,7 +648,7 @@ class IBMWatsonXAI(BaseLLM):
try:
if aembedding is True:
return handle_aembedding(req_params)
return handle_aembedding(req_params) # type: ignore
else:
return handle_embedding(req_params)
except WatsonXAIError as e:
@ -669,7 +660,7 @@ class IBMWatsonXAI(BaseLLM):
headers = {}
headers["Content-Type"] = "application/x-www-form-urlencoded"
if api_key is None:
api_key = get_secret("WX_API_KEY") or get_secret("WATSONX_API_KEY")
api_key = get_secret_str("WX_API_KEY") or get_secret_str("WATSONX_API_KEY")
if api_key is None:
raise ValueError("API key is required")
headers["Accept"] = "application/json"
@ -812,22 +803,29 @@ class RequestManager:
request_params["data"] = json.dumps(request_params.pop("json", {}))
method = request_params.pop("method")
retries = 0
resp: Optional[httpx.Response] = None
while retries < 3:
if method.upper() == "POST":
resp = await self.async_handler.post(**request_params)
else:
resp = await self.async_handler.get(**request_params)
if resp.status_code in [429, 503, 504, 520]:
if resp is not None and resp.status_code in [429, 503, 504, 520]:
# to handle rate limiting and service unavailable errors
# see: ibm_watsonx_ai.foundation_models.inference.base_model_inference.BaseModelInference._send_inference_payload
await asyncio.sleep(2**retries)
retries += 1
else:
break
if resp is None:
raise WatsonXAIError(
status_code=500,
message="No response from the server",
)
if resp.is_error:
error_reason = getattr(resp, "reason", "")
raise WatsonXAIError(
status_code=resp.status_code,
message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
message=f"Error {resp.status_code} ({error_reason}): {resp.text}",
)
yield resp
# await async_handler.close()

View file

@ -19,7 +19,8 @@ import threading
import time
import traceback
import uuid
from concurrent.futures import ThreadPoolExecutor
from concurrent import futures
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, List, Literal, Mapping, Optional, Type, Union
@ -647,7 +648,7 @@ def mock_completion(
@client
def completion(
def completion( # type: ignore
model: str,
# Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
messages: List = [],
@ -2940,16 +2941,16 @@ def completion_with_retries(*args, **kwargs):
num_retries = kwargs.pop("num_retries", 3)
retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop("retry_strategy", "constant_retry") # type: ignore
original_function = kwargs.pop("original_function", completion)
if retry_strategy == "constant_retry":
retryer = tenacity.Retrying(
stop=tenacity.stop_after_attempt(num_retries), reraise=True
)
elif retry_strategy == "exponential_backoff_retry":
if retry_strategy == "exponential_backoff_retry":
retryer = tenacity.Retrying(
wait=tenacity.wait_exponential(multiplier=1, max=10),
stop=tenacity.stop_after_attempt(num_retries),
reraise=True,
)
else:
retryer = tenacity.Retrying(
stop=tenacity.stop_after_attempt(num_retries), reraise=True
)
return retryer(original_function, *args, **kwargs)
@ -2968,16 +2969,16 @@ async def acompletion_with_retries(*args, **kwargs):
num_retries = kwargs.pop("num_retries", 3)
retry_strategy = kwargs.pop("retry_strategy", "constant_retry")
original_function = kwargs.pop("original_function", completion)
if retry_strategy == "constant_retry":
retryer = tenacity.Retrying(
stop=tenacity.stop_after_attempt(num_retries), reraise=True
)
elif retry_strategy == "exponential_backoff_retry":
if retry_strategy == "exponential_backoff_retry":
retryer = tenacity.Retrying(
wait=tenacity.wait_exponential(multiplier=1, max=10),
stop=tenacity.stop_after_attempt(num_retries),
reraise=True,
)
else:
retryer = tenacity.Retrying(
stop=tenacity.stop_after_attempt(num_retries), reraise=True
)
return await retryer(original_function, *args, **kwargs)
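
Both retry helpers now treat constant retry as the fallback branch, which keeps `retryer` definitely assigned for the type checker. A stand-alone sketch of the selection logic (the helper name is illustrative):

import tenacity

def build_retryer(retry_strategy: str, num_retries: int = 3) -> tenacity.Retrying:
    """Mirror of the branching above: exponential backoff when requested,
    constant retry with the same attempt cap for anything else."""
    if retry_strategy == "exponential_backoff_retry":
        return tenacity.Retrying(
            wait=tenacity.wait_exponential(multiplier=1, max=10),
            stop=tenacity.stop_after_attempt(num_retries),
            reraise=True,
        )
    return tenacity.Retrying(stop=tenacity.stop_after_attempt(num_retries), reraise=True)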
@ -3045,7 +3046,7 @@ def batch_completion(
temperature=temperature,
top_p=top_p,
n=n,
stream=stream,
stream=stream or False,
stop=stop,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
@ -3124,7 +3125,7 @@ def batch_completion_models(*args, **kwargs):
models = kwargs["models"]
kwargs.pop("models")
futures = {}
with concurrent.futures.ThreadPoolExecutor(max_workers=len(models)) as executor:
with ThreadPoolExecutor(max_workers=len(models)) as executor:
for model in models:
futures[model] = executor.submit(
completion, *args, model=model, **kwargs
@ -3141,9 +3142,7 @@ def batch_completion_models(*args, **kwargs):
kwargs.pop("model_list")
nested_kwargs = kwargs.pop("kwargs", {})
futures = {}
with concurrent.futures.ThreadPoolExecutor(
max_workers=len(deployments)
) as executor:
with ThreadPoolExecutor(max_workers=len(deployments)) as executor:
for deployment in deployments:
for key in kwargs.keys():
if (
@ -3156,9 +3155,7 @@ def batch_completion_models(*args, **kwargs):
while futures:
# wait for the first returned future
print_verbose("\n\n waiting for next result\n\n")
done, _ = concurrent.futures.wait(
futures.values(), return_when=concurrent.futures.FIRST_COMPLETED
)
done, _ = wait(futures.values(), return_when=FIRST_COMPLETED)
print_verbose(f"done list\n{done}")
for future in done:
try:
@ -3214,6 +3211,8 @@ def batch_completion_models_all_responses(*args, **kwargs):
if "models" in kwargs:
models = kwargs["models"]
kwargs.pop("models")
else:
raise Exception("'models' param not in kwargs")
responses = []
@ -3256,6 +3255,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
model=model, api_base=kwargs.get("api_base", None)
)
response: Optional[EmbeddingResponse] = None
if (
custom_llm_provider == "openai"
or custom_llm_provider == "azure"
@ -3294,12 +3294,21 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
elif isinstance(init_response, EmbeddingResponse): ## CACHING SCENARIO
response = init_response
elif asyncio.iscoroutine(init_response):
response = await init_response
response = await init_response # type: ignore
else:
# Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context)
if response is not None and hasattr(response, "_hidden_params"):
if (
response is not None
and isinstance(response, EmbeddingResponse)
and hasattr(response, "_hidden_params")
):
response._hidden_params["custom_llm_provider"] = custom_llm_provider
if response is None:
raise ValueError(
"Unable to get Embedding Response. Please pass a valid llm_provider."
)
return response
except Exception as e:
custom_llm_provider = custom_llm_provider or "openai"
@ -3329,7 +3338,6 @@ def embedding(
user: Optional[str] = None,
custom_llm_provider=None,
litellm_call_id=None,
litellm_logging_obj=None,
logger_fn=None,
**kwargs,
) -> EmbeddingResponse:
@ -3362,6 +3370,7 @@ def embedding(
client = kwargs.pop("client", None)
rpm = kwargs.pop("rpm", None)
tpm = kwargs.pop("tpm", None)
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
cooldown_time = kwargs.get("cooldown_time", None)
max_parallel_requests = kwargs.pop("max_parallel_requests", None)
model_info = kwargs.get("model_info", None)
@ -3491,7 +3500,7 @@ def embedding(
}
)
try:
response = None
response: Optional[EmbeddingResponse] = None
logging: Logging = litellm_logging_obj # type: ignore
logging.update_environment_variables(
model=model,
@ -3691,7 +3700,7 @@ def embedding(
raise ValueError(
"api_base is required for triton. Please pass `api_base`"
)
response = triton_chat_completions.embedding(
response = triton_chat_completions.embedding( # type: ignore
model=model,
input=input,
api_base=api_base,
@ -3783,6 +3792,7 @@ def embedding(
timeout=timeout,
aembedding=aembedding,
print_verbose=print_verbose,
api_key=api_key,
)
elif custom_llm_provider == "oobabooga":
response = oobabooga.embedding(
@ -3793,14 +3803,16 @@ def embedding(
logging_obj=logging,
optional_params=optional_params,
model_response=EmbeddingResponse(),
api_key=api_key,
)
elif custom_llm_provider == "ollama":
api_base = (
litellm.api_base
or api_base
or get_secret("OLLAMA_API_BASE")
or get_secret_str("OLLAMA_API_BASE")
or "http://localhost:11434"
) # type: ignore
if isinstance(input, str):
input = [input]
if not all(isinstance(item, str) for item in input):
@ -3881,13 +3893,13 @@ def embedding(
api_key = (
api_key
or litellm.api_key
or get_secret("XINFERENCE_API_KEY")
or get_secret_str("XINFERENCE_API_KEY")
or "stub-xinference-key"
) # xinference does not need an api key, pass a stub key if user did not set one
api_base = (
api_base
or litellm.api_base
or get_secret("XINFERENCE_API_BASE")
or get_secret_str("XINFERENCE_API_BASE")
or "http://127.0.0.1:9997/v1"
)
response = openai_chat_completions.embedding(
@ -3911,19 +3923,20 @@ def embedding(
optional_params=optional_params,
model_response=EmbeddingResponse(),
aembedding=aembedding,
api_key=api_key,
)
elif custom_llm_provider == "azure_ai":
api_base = (
api_base # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
or litellm.api_base
or get_secret("AZURE_AI_API_BASE")
or get_secret_str("AZURE_AI_API_BASE")
)
# set API KEY
api_key = (
api_key
or litellm.api_key # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or get_secret("AZURE_AI_API_KEY")
or get_secret_str("AZURE_AI_API_KEY")
)
## EMBEDDING CALL
@ -3944,10 +3957,14 @@ def embedding(
raise ValueError(f"No valid embedding model args passed in - {args}")
if response is not None and hasattr(response, "_hidden_params"):
response._hidden_params["custom_llm_provider"] = custom_llm_provider
if response is None:
args = locals()
raise ValueError(f"No valid embedding model args passed in - {args}")
return response
except Exception as e:
## LOGGING
logging.post_call(
litellm_logging_obj.post_call(
input=input,
api_key=api_key,
original_response=str(e),
@ -4018,7 +4035,11 @@ async def atext_completion(
else:
# Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context)
if kwargs.get("stream", False) is True: # return an async generator
if (
kwargs.get("stream", False) is True
or isinstance(response, TextCompletionStreamWrapper)
or isinstance(response, CustomStreamWrapper)
): # return an async generator
return TextCompletionStreamWrapper(
completion_stream=_async_streaming(
response=response,
@ -4153,9 +4174,10 @@ def text_completion(
Your example of how to use this function goes here.
"""
if "engine" in kwargs:
if model is None:
_engine = kwargs["engine"]
if model is None and isinstance(_engine, str):
# only use engine when model not passed
model = kwargs["engine"]
model = _engine
kwargs.pop("engine")
text_completion_response = TextCompletionResponse()
@ -4223,7 +4245,7 @@ def text_completion(
def process_prompt(i, individual_prompt):
decoded_prompt = tokenizer.decode(individual_prompt)
all_params = {**kwargs, **optional_params}
response = text_completion(
response: TextCompletionResponse = text_completion( # type: ignore
model=model,
prompt=decoded_prompt,
num_retries=3, # ensure this does not fail for the batch
@ -4292,6 +4314,8 @@ def text_completion(
model = "text-completion-openai/" + _model
optional_params.pop("custom_llm_provider", None)
if model is None:
raise ValueError("model is not set. Set either via 'model' or 'engine' param.")
kwargs["text_completion"] = True
response = completion(
model=model,
@ -4302,7 +4326,11 @@ def text_completion(
)
if kwargs.get("acompletion", False) is True:
return response
if stream is True or kwargs.get("stream", False) is True:
if (
stream is True
or kwargs.get("stream", False) is True
or isinstance(response, CustomStreamWrapper)
):
response = TextCompletionStreamWrapper(
completion_stream=response,
model=model,
@ -4310,6 +4338,8 @@ def text_completion(
custom_llm_provider=custom_llm_provider,
)
return response
elif isinstance(response, TextCompletionStreamWrapper):
return response
transformed_logprobs = None
# only supported for TGI models
try:
@ -4424,7 +4454,10 @@ def moderation(
):
# only supports open ai for now
api_key = (
api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")
api_key
or litellm.api_key
or litellm.openai_key
or get_secret_str("OPENAI_API_KEY")
)
openai_client = kwargs.get("client", None)
@ -4433,7 +4466,10 @@ def moderation(
api_key=api_key,
)
response = openai_client.moderations.create(input=input, model=model)
if model is not None:
response = openai_client.moderations.create(input=input, model=model)
else:
response = openai_client.moderations.create(input=input)
return response
@ -4441,20 +4477,30 @@ def moderation(
async def amoderation(
input: str, model: Optional[str] = None, api_key: Optional[str] = None, **kwargs
):
from openai import AsyncOpenAI
# only supports open ai for now
api_key = (
api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")
api_key
or litellm.api_key
or litellm.openai_key
or get_secret_str("OPENAI_API_KEY")
)
openai_client = kwargs.get("client", None)
if openai_client is None:
if openai_client is None or not isinstance(openai_client, AsyncOpenAI):
# call helper to get OpenAI client
# _get_openai_client maintains in-memory caching logic for OpenAI clients
openai_client = openai_chat_completions._get_openai_client(
_openai_client: AsyncOpenAI = openai_chat_completions._get_openai_client( # type: ignore
is_async=True,
api_key=api_key,
)
response = await openai_client.moderations.create(input=input, model=model)
else:
_openai_client = openai_client
if model is not None:
response = await openai_client.moderations.create(input=input, model=model)
else:
response = await openai_client.moderations.create(input=input)
return response
@ -4497,7 +4543,7 @@ async def aimage_generation(*args, **kwargs) -> ImageResponse:
init_response = ImageResponse(**init_response)
response = init_response
elif asyncio.iscoroutine(init_response):
response = await init_response
response = await init_response # type: ignore
else:
# Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context)
@ -4527,7 +4573,6 @@ def image_generation(
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
litellm_logging_obj=None,
custom_llm_provider=None,
**kwargs,
) -> ImageResponse:
@ -4543,9 +4588,10 @@ def image_generation(
proxy_server_request = kwargs.get("proxy_server_request", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", {})
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
client = kwargs.get("client", None)
model_response = litellm.utils.ImageResponse()
model_response: ImageResponse = litellm.utils.ImageResponse()
if model is not None or custom_llm_provider is not None:
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
else:
@ -4651,25 +4697,27 @@ def image_generation(
if custom_llm_provider == "azure":
# azure configs
api_type = get_secret("AZURE_API_TYPE") or "azure"
api_type = get_secret_str("AZURE_API_TYPE") or "azure"
api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")
api_version = (
api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
api_version
or litellm.api_version
or get_secret_str("AZURE_API_VERSION")
)
api_key = (
api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
)
azure_ad_token = optional_params.pop("azure_ad_token", None) or get_secret(
"AZURE_AD_TOKEN"
)
azure_ad_token = optional_params.pop(
"azure_ad_token", None
) or get_secret_str("AZURE_AD_TOKEN")
model_response = azure_chat_completions.image_generation(
model=model,
@ -4714,18 +4762,18 @@ def image_generation(
optional_params.pop("vertex_project", None)
or optional_params.pop("vertex_ai_project", None)
or litellm.vertex_project
or get_secret("VERTEXAI_PROJECT")
or get_secret_str("VERTEXAI_PROJECT")
)
vertex_ai_location = (
optional_params.pop("vertex_location", None)
or optional_params.pop("vertex_ai_location", None)
or litellm.vertex_location
or get_secret("VERTEXAI_LOCATION")
or get_secret_str("VERTEXAI_LOCATION")
)
vertex_credentials = (
optional_params.pop("vertex_credentials", None)
or optional_params.pop("vertex_ai_credentials", None)
or get_secret("VERTEXAI_CREDENTIALS")
or get_secret_str("VERTEXAI_CREDENTIALS")
)
model_response = vertex_image_generation.image_generation(
model=model,
@ -4786,7 +4834,7 @@ async def atranscription(*args, **kwargs) -> TranscriptionResponse:
elif isinstance(init_response, TranscriptionResponse): ## CACHING SCENARIO
response = init_response
elif asyncio.iscoroutine(init_response):
response = await init_response
response = await init_response # type: ignore
else:
# Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context)
@ -4820,7 +4868,6 @@ def transcription(
api_base: Optional[str] = None,
api_version: Optional[str] = None,
max_retries: Optional[int] = None,
litellm_logging_obj: Optional[LiteLLMLoggingObj] = None,
custom_llm_provider=None,
**kwargs,
) -> TranscriptionResponse:
@ -4830,6 +4877,7 @@ def transcription(
Allows router to load balance between them
"""
atranscription = kwargs.get("atranscription", False)
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
kwargs.get("litellm_call_id", None)
kwargs.get("logger_fn", None)
kwargs.get("proxy_server_request", None)
@ -4869,22 +4917,17 @@ def transcription(
custom_llm_provider=custom_llm_provider,
drop_params=drop_params,
)
# optional_params = {
# "language": language,
# "prompt": prompt,
# "response_format": response_format,
# "temperature": None, # openai defaults this to 0
# }
response: Optional[TranscriptionResponse] = None
if custom_llm_provider == "azure":
# azure configs
api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")
api_version = (
api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
api_version or litellm.api_version or get_secret_str("AZURE_API_VERSION")
)
azure_ad_token = kwargs.pop("azure_ad_token", None) or get_secret(
azure_ad_token = kwargs.pop("azure_ad_token", None) or get_secret_str(
"AZURE_AD_TOKEN"
)
@ -4892,8 +4935,8 @@ def transcription(
api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_API_KEY")
) # type: ignore
or get_secret_str("AZURE_API_KEY")
)
response = azure_audio_transcriptions.audio_transcriptions(
model=model,
@ -4942,6 +4985,9 @@ def transcription(
api_base=api_base,
api_key=api_key,
)
if response is None:
raise ValueError("Unmapped provider passed in. Unable to get the response.")
return response
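A hedged usage sketch for the transcription path above (assumes a local audio.mp3, an OpenAI key, and the synchronous litellm.transcription entrypoint shown here); the new guard raises ValueError if no provider branch produced a response:

import litellm

with open("audio.mp3", "rb") as audio_file:
    # sketch: sync transcription call routed through the provider branches above
    transcript = litellm.transcription(model="whisper-1", file=audio_file)
    print(transcript.text)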
@ -5149,15 +5195,16 @@ def speech(
vertex_ai_project = (
generic_optional_params.vertex_project
or litellm.vertex_project
or get_secret("VERTEXAI_PROJECT")
or get_secret_str("VERTEXAI_PROJECT")
)
vertex_ai_location = (
generic_optional_params.vertex_location
or litellm.vertex_location
or get_secret("VERTEXAI_LOCATION")
or get_secret_str("VERTEXAI_LOCATION")
)
vertex_credentials = generic_optional_params.vertex_credentials or get_secret(
"VERTEXAI_CREDENTIALS"
vertex_credentials = (
generic_optional_params.vertex_credentials
or get_secret_str("VERTEXAI_CREDENTIALS")
)
if voice is not None and not isinstance(voice, dict):
@ -5234,20 +5281,25 @@ async def ahealth_check(
if custom_llm_provider == "azure":
api_key = (
model_params.get("api_key")
or get_secret("AZURE_API_KEY")
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
or get_secret_str("AZURE_OPENAI_API_KEY")
)
api_base = (
api_base: Optional[str] = (
model_params.get("api_base")
or get_secret("AZURE_API_BASE")
or get_secret("AZURE_OPENAI_API_BASE")
or get_secret_str("AZURE_API_BASE")
or get_secret_str("AZURE_OPENAI_API_BASE")
)
if api_base is None:
raise ValueError(
"Azure API Base cannot be None. Set via 'AZURE_API_BASE' in env var or `.completion(..., api_base=..)`"
)
api_version = (
model_params.get("api_version")
or get_secret("AZURE_API_VERSION")
or get_secret("AZURE_OPENAI_API_VERSION")
or get_secret_str("AZURE_API_VERSION")
or get_secret_str("AZURE_OPENAI_API_VERSION")
)
timeout = (
@ -5273,7 +5325,7 @@ async def ahealth_check(
custom_llm_provider == "openai"
or custom_llm_provider == "text-completion-openai"
):
api_key = model_params.get("api_key") or get_secret("OPENAI_API_KEY")
api_key = model_params.get("api_key") or get_secret_str("OPENAI_API_KEY")
organization = model_params.get("organization")
timeout = (
@ -5282,7 +5334,7 @@ async def ahealth_check(
or default_timeout
)
api_base = model_params.get("api_base") or get_secret("OPENAI_API_BASE")
api_base = model_params.get("api_base") or get_secret_str("OPENAI_API_BASE")
if custom_llm_provider == "text-completion-openai":
mode = "completion"
@ -5377,7 +5429,9 @@ def config_completion(**kwargs):
)
def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List] = None):
def stream_chunk_builder_text_completion(
chunks: list, messages: Optional[List] = None
) -> TextCompletionResponse:
id = chunks[0]["id"]
object = chunks[0]["object"]
created = chunks[0]["created"]
@ -5446,7 +5500,7 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]
response["usage"]["total_tokens"] = (
response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
)
return response
return TextCompletionResponse(**response)
def stream_chunk_builder(
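Most of the changes in this file swap get_secret for get_secret_str. A minimal sketch of the distinction (assuming get_secret may return non-string values while get_secret_str narrows the result to Optional[str]):

from typing import Optional

from litellm.secret_managers.main import get_secret, get_secret_str

api_base: Optional[str] = get_secret_str("AZURE_API_BASE")  # narrows cleanly for pyright
raw_value = get_secret("AZURE_API_BASE")  # may be str, bool, dict, or None depending on the secret manager
# annotating raw_value as Optional[str] is the kind of assignment that previously failed type-checking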

View file

@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict, Extra, Field, Json, model_validator
from typing_extensions import Annotated, TypedDict
from litellm.integrations.SlackAlerting.types import AlertType
from litellm.types.integrations.slack_alerting import AlertType
from litellm.types.router import RouterErrors, UpdateRouterConfig
from litellm.types.utils import ProviderField, StandardCallbackDynamicParams

View file

@ -1,12 +1,13 @@
from typing import Optional
from fastapi import Depends, Request, APIRouter
from fastapi import HTTPException
import copy
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching import RedisCache
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
router = APIRouter(
prefix="/cache",
tags=["caching"],
@ -21,9 +22,9 @@ async def cache_ping():
"""
Endpoint for checking if cache can be pinged
"""
litellm_cache_params = {}
specific_cache_params = {}
try:
litellm_cache_params = {}
specific_cache_params = {}
if litellm.cache is None:
raise HTTPException(
@ -135,7 +136,9 @@ async def cache_redis_info():
raise HTTPException(
status_code=503, detail="Cache not initialized. litellm.cache is None"
)
if litellm.cache.type == "redis":
if litellm.cache.type == "redis" and isinstance(
litellm.cache.cache, RedisCache
):
client_list = litellm.cache.cache.client_list()
redis_info = litellm.cache.cache.info()
num_clients = len(client_list)
@ -177,7 +180,9 @@ async def cache_flushall():
raise HTTPException(
status_code=503, detail="Cache not initialized. litellm.cache is None"
)
if litellm.cache.type == "redis":
if litellm.cache.type == "redis" and isinstance(
litellm.cache.cache, RedisCache
):
litellm.cache.cache.flushall()
return {
"status": "success",

View file

@ -118,13 +118,13 @@ async def create_fine_tuning_job(
version,
)
data = fine_tuning_request.model_dump(exclude_none=True)
try:
if premium_user is not True:
raise ValueError(
f"Only premium users can use this endpoint + {CommonProxyErrors.not_premium_user.value}"
)
# Convert Pydantic model to dict
data = fine_tuning_request.model_dump(exclude_none=True)
verbose_proxy_logger.debug(
"Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)),
@ -146,7 +146,8 @@ async def create_fine_tuning_job(
)
# add llm_provider_config to data
data.update(llm_provider_config)
if llm_provider_config is not None:
data.update(llm_provider_config)
response = await litellm.acreate_fine_tuning_job(**data)
@ -262,7 +263,8 @@ async def list_fine_tuning_jobs(
custom_llm_provider=custom_llm_provider
)
data.update(llm_provider_config)
if llm_provider_config is not None:
data.update(llm_provider_config)
response = await litellm.alist_fine_tuning_jobs(
**data,
@ -378,7 +380,8 @@ async def retrieve_fine_tuning_job(
custom_llm_provider=custom_llm_provider
)
data.update(llm_provider_config)
if llm_provider_config is not None:
data.update(llm_provider_config)
response = await litellm.acancel_fine_tuning_job(
**data,

View file

@ -227,8 +227,8 @@ async def send_management_endpoint_alert(
- An internal user is created, updated, or deleted
- A team is created, updated, or deleted
"""
from litellm.integrations.SlackAlerting.types import AlertType
from litellm.proxy.proxy_server import premium_user, proxy_logging_obj
from litellm.types.integrations.slack_alerting import AlertType
if premium_user is not True:
return

View file

@ -114,10 +114,7 @@ from litellm import (
from litellm._logging import verbose_proxy_logger, verbose_router_logger
from litellm.caching import DualCache, RedisCache
from litellm.exceptions import RejectedRequestError
from litellm.integrations.SlackAlerting.slack_alerting import (
SlackAlerting,
SlackAlertingArgs,
)
from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
from litellm.litellm_core_utils.core_helpers import (
_get_parent_otel_span_from_kwargs,
get_litellm_metadata_from_kwargs,
@ -249,6 +246,7 @@ from litellm.secret_managers.aws_secret_manager import (
)
from litellm.secret_managers.google_kms import load_google_kms
from litellm.secret_managers.main import get_secret, get_secret_str, str_to_bool
from litellm.types.integrations.slack_alerting import SlackAlertingArgs
from litellm.types.llms.anthropic import (
AnthropicMessagesRequest,
AnthropicResponse,

View file

@ -54,7 +54,6 @@ from litellm.exceptions import RejectedRequestError
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
from litellm.integrations.SlackAlerting.types import DEFAULT_ALERT_TYPES
from litellm.integrations.SlackAlerting.utils import _add_langfuse_trace_id_to_alert
from litellm.litellm_core_utils.core_helpers import (
_get_parent_otel_span_from_kwargs,
@ -85,6 +84,7 @@ from litellm.proxy.hooks.parallel_request_limiter import (
_PROXY_MaxParallelRequestsHandler,
)
from litellm.secret_managers.main import str_to_bool
from litellm.types.integrations.slack_alerting import DEFAULT_ALERT_TYPES
from litellm.types.utils import CallTypes, LoggedLiteLLMParams
if TYPE_CHECKING:

View file

@ -208,7 +208,7 @@ async def vertex_proxy_route(
request,
fastapi_response,
user_api_key_dict,
stream=is_streaming_request,
stream=is_streaming_request, # type: ignore
)
return received_value

View file

@ -10,11 +10,10 @@ from litellm.llms.azure_ai.rerank import AzureAIRerank
from litellm.llms.cohere.rerank import CohereRerank
from litellm.llms.together_ai.rerank import TogetherAIRerank
from litellm.secret_managers.main import get_secret
from litellm.types.rerank import RerankRequest, RerankResponse
from litellm.types.router import *
from litellm.utils import client, exception_type, supports_httpx_timeout
from .types import RerankRequest, RerankResponse
####### ENVIRONMENT VARIABLES ###################
# Initialize any necessary instances or variables here
cohere_rerank = CohereRerank()

View file

@ -7,6 +7,7 @@ import httpx
import openai
import litellm
from litellm import get_secret, get_secret_str
from litellm._logging import verbose_router_logger
from litellm.llms.AzureOpenAI.azure import get_azure_ad_token_from_oidc
from litellm.secret_managers.get_azure_ad_token_provider import (
@ -111,17 +112,17 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
api_key = litellm_params.get("api_key") or default_api_key
if api_key and isinstance(api_key, str) and api_key.startswith("os.environ/"):
api_key_env_name = api_key.replace("os.environ/", "")
api_key = litellm.get_secret(api_key_env_name)
api_key = get_secret_str(api_key_env_name)
litellm_params["api_key"] = api_key
api_base = litellm_params.get("api_base")
base_url = litellm_params.get("base_url")
base_url: Optional[str] = litellm_params.get("base_url")
api_base = (
api_base or base_url or default_api_base
) # allow users to pass in `api_base` or `base_url` for azure
if api_base and api_base.startswith("os.environ/"):
api_base_env_name = api_base.replace("os.environ/", "")
api_base = litellm.get_secret(api_base_env_name)
api_base = get_secret_str(api_base_env_name)
litellm_params["api_base"] = api_base
## AZURE AI STUDIO MISTRAL CHECK ##
@ -147,33 +148,37 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
api_version = litellm_params.get("api_version")
if api_version and api_version.startswith("os.environ/"):
api_version_env_name = api_version.replace("os.environ/", "")
api_version = litellm.get_secret(api_version_env_name)
api_version = get_secret_str(api_version_env_name)
litellm_params["api_version"] = api_version
timeout = litellm_params.pop("timeout", None) or litellm.request_timeout
timeout: Optional[float] = (
litellm_params.pop("timeout", None) or litellm.request_timeout
)
if isinstance(timeout, str) and timeout.startswith("os.environ/"):
timeout_env_name = timeout.replace("os.environ/", "")
timeout = litellm.get_secret(timeout_env_name)
timeout = get_secret(timeout_env_name) # type: ignore
litellm_params["timeout"] = timeout
stream_timeout = litellm_params.pop(
stream_timeout: Optional[float] = litellm_params.pop(
"stream_timeout", timeout
) # if no stream_timeout is set, default to timeout
if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"):
stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
stream_timeout = litellm.get_secret(stream_timeout_env_name)
stream_timeout = get_secret(stream_timeout_env_name) # type: ignore
litellm_params["stream_timeout"] = stream_timeout
max_retries = litellm_params.pop("max_retries", 0) # router handles retry logic
max_retries: Optional[int] = litellm_params.pop(
"max_retries", 0
) # router handles retry logic
if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
max_retries_env_name = max_retries.replace("os.environ/", "")
max_retries = litellm.get_secret(max_retries_env_name)
max_retries = get_secret(max_retries_env_name) # type: ignore
litellm_params["max_retries"] = max_retries
organization = litellm_params.get("organization", None)
if isinstance(organization, str) and organization.startswith("os.environ/"):
organization_env_name = organization.replace("os.environ/", "")
organization = litellm.get_secret(organization_env_name)
organization = get_secret_str(organization_env_name)
litellm_params["organization"] = organization
azure_ad_token_provider: Optional[Callable[[], str]] = None
if litellm_params.get("tenant_id"):
@ -227,8 +232,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
azure_ad_token_provider=azure_ad_token_provider,
base_url=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
timeout=timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.AsyncClient(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
@ -253,8 +258,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
azure_ad_token_provider=azure_ad_token_provider,
base_url=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
timeout=timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.Client(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
@ -276,8 +281,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
azure_ad_token_provider=azure_ad_token_provider,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries,
timeout=stream_timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.AsyncClient(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
@ -302,8 +307,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
azure_ad_token_provider=azure_ad_token_provider,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries,
timeout=stream_timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.Client(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
@ -350,8 +355,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
cache_key = f"{model_id}_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore
**azure_client_params,
timeout=timeout,
max_retries=max_retries,
timeout=timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.AsyncClient(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
@ -371,8 +376,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
cache_key = f"{model_id}_client"
_client = openai.AzureOpenAI( # type: ignore
**azure_client_params,
timeout=timeout,
max_retries=max_retries,
timeout=timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.Client(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
@ -391,8 +396,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
cache_key = f"{model_id}_stream_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore
**azure_client_params,
timeout=stream_timeout,
max_retries=max_retries,
timeout=stream_timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.AsyncClient(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
@ -413,8 +418,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
cache_key = f"{model_id}_stream_client"
_client = openai.AzureOpenAI( # type: ignore
**azure_client_params,
timeout=stream_timeout,
max_retries=max_retries,
timeout=stream_timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.Client(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
@ -441,8 +446,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
_client = openai.AsyncOpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
timeout=timeout,
max_retries=max_retries,
timeout=timeout, # type: ignore
max_retries=max_retries, # type: ignore
organization=organization,
http_client=httpx.AsyncClient(
limits=httpx.Limits(
@ -465,8 +470,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
_client = openai.OpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
timeout=timeout,
max_retries=max_retries,
timeout=timeout, # type: ignore
max_retries=max_retries, # type: ignore
organization=organization,
http_client=httpx.Client(
limits=httpx.Limits(
@ -487,8 +492,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
_client = openai.AsyncOpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
timeout=stream_timeout,
max_retries=max_retries,
timeout=stream_timeout, # type: ignore
max_retries=max_retries, # type: ignore
organization=organization,
http_client=httpx.AsyncClient(
limits=httpx.Limits(
@ -512,8 +517,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
_client = openai.OpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
timeout=stream_timeout,
max_retries=max_retries,
timeout=stream_timeout, # type: ignore
max_retries=max_retries, # type: ignore
organization=organization,
http_client=httpx.Client(
limits=httpx.Limits(
@ -542,20 +547,29 @@ def get_azure_ad_token_from_entrata_id(
verbose_router_logger.debug("Getting Azure AD Token from Entrata ID")
if tenant_id.startswith("os.environ/"):
tenant_id = litellm.get_secret(tenant_id)
_tenant_id = get_secret_str(tenant_id)
else:
_tenant_id = tenant_id
if client_id.startswith("os.environ/"):
client_id = litellm.get_secret(client_id)
_client_id = get_secret_str(client_id)
else:
_client_id = client_id
if client_secret.startswith("os.environ/"):
client_secret = litellm.get_secret(client_secret)
_client_secret = get_secret_str(client_secret)
else:
_client_secret = client_secret
verbose_router_logger.debug(
"tenant_id %s, client_id %s, client_secret %s",
tenant_id,
client_id,
client_secret,
_tenant_id,
_client_id,
_client_secret,
)
credential = ClientSecretCredential(tenant_id, client_id, client_secret)
if _tenant_id is None or _client_id is None or _client_secret is None:
raise ValueError("tenant_id, client_id, and client_secret must be provided")
credential = ClientSecretCredential(_tenant_id, _client_id, _client_secret)
verbose_router_logger.debug("credential %s", credential)

View file

@ -2,6 +2,8 @@ import asyncio
import traceback
from typing import TYPE_CHECKING, Any
from litellm.types.integrations.slack_alerting import AlertType
if TYPE_CHECKING:
from litellm.router import Router as _Router
@ -49,5 +51,6 @@ async def send_llm_exception_alert(
await litellm_router_instance.slack_alerting_logger.send_alert(
message=f"LLM API call failed: `{exception_str}`",
level="High",
alert_type="llm_exceptions",
alert_type=AlertType.llm_exceptions,
alerting_metadata={},
)
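For reference, a small sketch of the enum now passed as alert_type (assuming AlertType is a string-valued enum whose members mirror the old literal strings):

from litellm.types.integrations.slack_alerting import AlertType

# the enum member replaces the bare "llm_exceptions" string, so pyright can
# validate alert_type arguments; assumed here to be str-valued
assert AlertType.llm_exceptions.value == "llm_exceptions"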

View file

@ -792,7 +792,7 @@ class EmbeddingResponse(OpenAIObject):
model: Optional[str] = None
"""The model used for embedding."""
data: Optional[List] = None
data: List
"""The actual embedding value"""
object: Literal["list"]
@ -803,6 +803,7 @@ class EmbeddingResponse(OpenAIObject):
_hidden_params: dict = {}
_response_headers: Optional[Dict] = None
_response_ms: Optional[float] = None
def __init__(
self,
@ -822,7 +823,7 @@ class EmbeddingResponse(OpenAIObject):
if data:
data = data
else:
data = None
data = []
if usage:
usage = usage
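A short sketch of what the data: List change buys downstream (assuming EmbeddingResponse is importable from litellm.types.utils): callers can iterate the field without a None guard.

from litellm.types.utils import EmbeddingResponse

resp = EmbeddingResponse()  # data now defaults to [] instead of None
for item in resp.data:  # safe to iterate; previously required a None check
    print(item)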

View file

@ -322,6 +322,9 @@ def function_setup(
original_function: str, rules_obj, start_time, *args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
### NOTICES ###
from litellm import Logging as LiteLLMLogging
from litellm.litellm_core_utils.litellm_logging import set_callbacks
if litellm.set_verbose is True:
verbose_logger.warning(
"`litellm.set_verbose` is deprecated. Please set `os.environ['LITELLM_LOG'] = 'DEBUG'` for debug logs."
@ -333,7 +336,7 @@ def function_setup(
custom_llm_setup()
## LOGGING SETUP
function_id = kwargs["id"] if "id" in kwargs else None
function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None
if len(litellm.callbacks) > 0:
for callback in litellm.callbacks:
@ -375,9 +378,7 @@ def function_setup(
+ litellm.failure_callback
)
)
litellm.litellm_core_utils.litellm_logging.set_callbacks(
callback_list=callback_list, function_id=function_id
)
set_callbacks(callback_list=callback_list, function_id=function_id)
## ASYNC CALLBACKS
if len(litellm.input_callback) > 0:
removed_async_items = []
@ -560,12 +561,12 @@ def function_setup(
else:
messages = "default-message-value"
stream = True if "stream" in kwargs and kwargs["stream"] is True else False
logging_obj = litellm.litellm_core_utils.litellm_logging.Logging(
logging_obj = LiteLLMLogging(
model=model,
messages=messages,
stream=stream,
litellm_call_id=kwargs["litellm_call_id"],
function_id=function_id,
function_id=function_id or "",
call_type=call_type,
start_time=start_time,
dynamic_success_callbacks=dynamic_success_callbacks,
@ -655,10 +656,8 @@ def client(original_function):
json_response_format = optional_params[
"response_format"
]
elif (
_parsing._completions.is_basemodel_type(
optional_params["response_format"]
)
elif _parsing._completions.is_basemodel_type(
optional_params["response_format"] # type: ignore
):
json_response_format = (
type_to_response_format_param(
@ -827,6 +826,7 @@ def client(original_function):
print_verbose("INSIDE CHECKING CACHE")
if (
litellm.cache is not None
and litellm.cache.supported_call_types is not None
and str(original_function.__name__)
in litellm.cache.supported_call_types
):
@ -879,7 +879,7 @@ def client(original_function):
dynamic_api_key,
api_base,
) = litellm.get_llm_provider(
model=model,
model=model or "",
custom_llm_provider=kwargs.get(
"custom_llm_provider", None
),
@ -949,6 +949,8 @@ def client(original_function):
base_model=base_model,
messages=messages,
user_max_tokens=user_max_tokens,
buffer_num=None,
buffer_perc=None,
)
kwargs["max_tokens"] = modified_max_tokens
except Exception as e:
@ -990,6 +992,7 @@ def client(original_function):
# [OPTIONAL] ADD TO CACHE
if (
litellm.cache is not None
and litellm.cache.supported_call_types is not None
and str(original_function.__name__)
in litellm.cache.supported_call_types
) and (kwargs.get("cache", {}).get("no-store", False) is not True):
@ -1006,7 +1009,7 @@ def client(original_function):
"id", None
)
result._hidden_params["api_base"] = get_api_base(
model=model,
model=model or "",
optional_params=getattr(logging_obj, "optional_params", {}),
)
result._hidden_params["response_cost"] = (
@ -1053,7 +1056,7 @@ def client(original_function):
and not _is_litellm_router_call
):
if len(args) > 0:
args[0] = context_window_fallback_dict[model]
args[0] = context_window_fallback_dict[model] # type: ignore
else:
kwargs["model"] = context_window_fallback_dict[model]
return original_function(*args, **kwargs)
@ -1065,12 +1068,6 @@ def client(original_function):
logging_obj.failure_handler(
e, traceback_exception, start_time, end_time
) # DO NOT MAKE THREADED - router retry fallback relies on this!
if hasattr(e, "message"):
if (
liteDebuggerClient
and liteDebuggerClient.dashboard_url is not None
): # make it easy to get to the debugger logs if you've initialized it
e.message += f"\n Check the log in your dashboard - {liteDebuggerClient.dashboard_url}"
raise e
@wraps(original_function)
@ -1126,6 +1123,7 @@ def client(original_function):
print_verbose("INSIDE CHECKING CACHE")
if (
litellm.cache is not None
and litellm.cache.supported_call_types is not None
and str(original_function.__name__)
in litellm.cache.supported_call_types
):
@ -1287,7 +1285,11 @@ def client(original_function):
args=(cached_result, start_time, end_time, cache_hit),
).start()
cache_key = kwargs.get("preset_cache_key", None)
cached_result._hidden_params["cache_key"] = cache_key
if (
isinstance(cached_result, BaseModel)
or isinstance(cached_result, CustomStreamWrapper)
) and hasattr(cached_result, "_hidden_params"):
cached_result._hidden_params["cache_key"] = cache_key # type: ignore
return cached_result
elif (
call_type == CallTypes.aembedding.value
@ -1447,6 +1449,7 @@ def client(original_function):
# [OPTIONAL] ADD TO CACHE
if (
(litellm.cache is not None)
and litellm.cache.supported_call_types is not None
and (
str(original_function.__name__)
in litellm.cache.supported_call_types
@ -1504,11 +1507,12 @@ def client(original_function):
if (
isinstance(result, EmbeddingResponse)
and final_embedding_cached_response is not None
and final_embedding_cached_response.data is not None
):
idx = 0
final_data_list = []
for item in final_embedding_cached_response.data:
if item is None:
if item is None and result.data is not None:
final_data_list.append(result.data[idx])
idx += 1
else:
@ -1575,7 +1579,7 @@ def client(original_function):
and model in context_window_fallback_dict
):
if len(args) > 0:
args[0] = context_window_fallback_dict[model]
args[0] = context_window_fallback_dict[model] # type: ignore
else:
kwargs["model"] = context_window_fallback_dict[model]
return await original_function(*args, **kwargs)
@ -2945,13 +2949,19 @@ def get_optional_params(
response_format=non_default_params["response_format"]
)
# # clean out 'additionalProperties = False'. Causes vertexai/gemini OpenAI API Schema errors - https://github.com/langchain-ai/langchainjs/issues/5240
if non_default_params["response_format"].get("json_schema", {}).get(
"schema"
) is not None and custom_llm_provider in [
"gemini",
"vertex_ai",
"vertex_ai_beta",
]:
if (
non_default_params["response_format"] is not None
and non_default_params["response_format"]
.get("json_schema", {})
.get("schema")
is not None
and custom_llm_provider
in [
"gemini",
"vertex_ai",
"vertex_ai_beta",
]
):
old_schema = copy.deepcopy(
non_default_params["response_format"]
.get("json_schema", {})
@ -3754,7 +3764,11 @@ def get_optional_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=drop_params,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
elif custom_llm_provider == "openrouter":
supported_params = get_supported_openai_params(
@ -3863,7 +3877,11 @@ def get_optional_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=drop_params,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
elif custom_llm_provider == "azure":
supported_params = get_supported_openai_params(
@ -4889,7 +4907,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
try:
split_model, custom_llm_provider, _, _ = get_llm_provider(model=model)
except Exception:
pass
split_model = model
combined_model_name = model
stripped_model_name = _strip_model_name(model=model)
combined_stripped_model_name = stripped_model_name
@ -5865,6 +5883,8 @@ def convert_to_model_response_object(
for idx, choice in enumerate(response_object["choices"]):
## HANDLE JSON MODE - anthropic returns single function call]
tool_calls = choice["message"].get("tool_calls", None)
message: Optional[Message] = None
finish_reason: Optional[str] = None
if (
convert_tool_call_to_json_mode
and tool_calls is not None
@ -5877,7 +5897,7 @@ def convert_to_model_response_object(
if json_mode_content_str is not None:
message = litellm.Message(content=json_mode_content_str)
finish_reason = "stop"
else:
if message is None:
message = Message(
content=choice["message"].get("content", None),
role=choice["message"]["role"] or "assistant",
@ -6066,7 +6086,7 @@ def valid_model(model):
model in litellm.open_ai_chat_completion_models
or model in litellm.open_ai_text_completion_models
):
openai.Model.retrieve(model)
openai.models.retrieve(model)
else:
messages = [{"role": "user", "content": "Hello World"}]
litellm.completion(model=model, messages=messages)
@ -6386,8 +6406,8 @@ class CustomStreamWrapper:
self,
completion_stream,
model,
custom_llm_provider=None,
logging_obj=None,
logging_obj: Any,
custom_llm_provider: Optional[str] = None,
stream_options=None,
make_call: Optional[Callable] = None,
_response_headers: Optional[dict] = None,
@ -6633,36 +6653,6 @@ class CustomStreamWrapper:
"completion_tokens": completion_tokens,
}
def handle_together_ai_chunk(self, chunk):
chunk = chunk.decode("utf-8")
text = ""
is_finished = False
finish_reason = None
if "text" in chunk:
text_index = chunk.find('"text":"') # this checks if text: exists
text_start = text_index + len('"text":"')
text_end = chunk.find('"}', text_start)
if text_index != -1 and text_end != -1:
extracted_text = chunk[text_start:text_end]
text = extracted_text
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
}
elif "[DONE]" in chunk:
return {"text": text, "is_finished": True, "finish_reason": "stop"}
elif "error" in chunk:
raise litellm.together_ai.TogetherAIError(
status_code=422, message=f"{str(chunk)}"
)
else:
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
}
def handle_predibase_chunk(self, chunk):
try:
if not isinstance(chunk, str):
@ -7264,12 +7254,17 @@ class CustomStreamWrapper:
try:
if isinstance(chunk, dict):
parsed_response = chunk
if isinstance(chunk, (str, bytes)):
elif isinstance(chunk, (str, bytes)):
if isinstance(chunk, bytes):
parsed_response = chunk.decode("utf-8")
else:
parsed_response = chunk
data_json = json.loads(parsed_response)
else:
raise ValueError("Unable to parse streaming chunk")
if isinstance(parsed_response, dict):
data_json = parsed_response
else:
data_json = json.loads(parsed_response)
text = (
data_json.get("outputs", "")[0]
.get("data", "")
@ -7331,8 +7326,7 @@ class CustomStreamWrapper:
if (
len(model_response.choices) > 0
and hasattr(model_response.choices[0], "delta")
and model_response.choices[0].delta is not None
and getattr(model_response.choices[0], "delta") is not None
):
# do nothing, if object instantiated
pass
@ -7350,7 +7344,7 @@ class CustomStreamWrapper:
is_empty = False
return is_empty
def chunk_creator(self, chunk):
def chunk_creator(self, chunk): # type: ignore
model_response = self.model_response_creator()
response_obj = {}
try:
@ -7422,11 +7416,6 @@ class CustomStreamWrapper:
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider and self.custom_llm_provider == "together_ai":
response_obj = self.handle_together_ai_chunk(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider and self.custom_llm_provider == "huggingface":
response_obj = self.handle_huggingface_chunk(chunk)
completion_obj["content"] = response_obj["text"]
@ -7475,51 +7464,6 @@ class CustomStreamWrapper:
if self.sent_first_chunk is False:
raise Exception("An unknown error occurred with the stream")
self.received_finish_reason = "stop"
elif self.custom_llm_provider == "gemini":
if hasattr(chunk, "parts") is True:
try:
if len(chunk.parts) > 0:
completion_obj["content"] = chunk.parts[0].text
if len(chunk.parts) > 0 and hasattr(
chunk.parts[0], "finish_reason"
):
self.received_finish_reason = chunk.parts[
0
].finish_reason.name
except Exception:
if chunk.parts[0].finish_reason.name == "SAFETY":
raise Exception(
f"The response was blocked by VertexAI. {str(chunk)}"
)
else:
completion_obj["content"] = str(chunk)
elif self.custom_llm_provider and (
self.custom_llm_provider == "vertex_ai_beta"
):
from litellm.types.utils import (
GenericStreamingChunk as UtilsStreamingChunk,
)
if self.received_finish_reason is not None:
raise StopIteration
response_obj: UtilsStreamingChunk = chunk
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
if (
self.stream_options
and self.stream_options.get("include_usage", False) is True
and response_obj["usage"] is not None
):
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"]["prompt_tokens"],
completion_tokens=response_obj["usage"]["completion_tokens"],
total_tokens=response_obj["usage"]["total_tokens"],
)
if "tool_use" in response_obj and response_obj["tool_use"] is not None:
completion_obj["tool_calls"] = [response_obj["tool_use"]]
elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"):
import proto # type: ignore
@ -7624,53 +7568,7 @@ class CustomStreamWrapper:
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "bedrock":
from litellm.types.llms.bedrock import GenericStreamingChunk
if self.received_finish_reason is not None:
raise StopIteration
response_obj: GenericStreamingChunk = chunk
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
if (
self.stream_options
and self.stream_options.get("include_usage", False) is True
and response_obj["usage"] is not None
):
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"]["inputTokens"],
completion_tokens=response_obj["usage"]["outputTokens"],
total_tokens=response_obj["usage"]["totalTokens"],
)
if "tool_use" in response_obj and response_obj["tool_use"] is not None:
completion_obj["tool_calls"] = [response_obj["tool_use"]]
elif self.custom_llm_provider == "sagemaker":
from litellm.types.llms.bedrock import GenericStreamingChunk
if self.received_finish_reason is not None:
raise StopIteration
response_obj: GenericStreamingChunk = chunk
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
if (
self.stream_options
and self.stream_options.get("include_usage", False) is True
and response_obj["usage"] is not None
):
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"]["inputTokens"],
completion_tokens=response_obj["usage"]["outputTokens"],
total_tokens=response_obj["usage"]["totalTokens"],
)
if "tool_use" in response_obj and response_obj["tool_use"] is not None:
completion_obj["tool_calls"] = [response_obj["tool_use"]]
elif self.custom_llm_provider == "petals":
if len(self.completion_stream) == 0:
if self.received_finish_reason is not None:
@ -8181,9 +8079,11 @@ class CustomStreamWrapper:
target=self.run_success_logging_in_thread,
args=(response, cache_hit),
).start() # log response
self.response_uptil_now += (
response.choices[0].delta.get("content", "") or ""
)
choice = response.choices[0]
if isinstance(choice, StreamingChoices):
self.response_uptil_now += choice.delta.get("content", "") or ""
else:
self.response_uptil_now += ""
self.rules.post_call_rules(
input=self.response_uptil_now, model=self.model
)
@ -8223,8 +8123,11 @@ class CustomStreamWrapper:
)
response = self.model_response_creator()
if complete_streaming_response is not None:
response.usage = complete_streaming_response.usage
response._hidden_params["usage"] = complete_streaming_response.usage # type: ignore
setattr(
response,
"usage",
getattr(complete_streaming_response, "usage"),
)
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler,
@ -8349,9 +8252,11 @@ class CustomStreamWrapper:
processed_chunk, cache_hit=cache_hit
)
)
self.response_uptil_now += (
processed_chunk.choices[0].delta.get("content", "") or ""
)
choice = processed_chunk.choices[0]
if isinstance(choice, StreamingChoices):
self.response_uptil_now += choice.delta.get("content", "") or ""
else:
self.response_uptil_now += ""
self.rules.post_call_rules(
input=self.response_uptil_now, model=self.model
)
@ -8401,9 +8306,13 @@ class CustomStreamWrapper:
)
)
self.response_uptil_now += (
processed_chunk.choices[0].delta.get("content", "") or ""
)
choice = processed_chunk.choices[0]
if isinstance(choice, StreamingChoices):
self.response_uptil_now += (
choice.delta.get("content", "") or ""
)
else:
self.response_uptil_now += ""
self.rules.post_call_rules(
input=self.response_uptil_now, model=self.model
)
@ -8423,7 +8332,11 @@ class CustomStreamWrapper:
)
response = self.model_response_creator()
if complete_streaming_response is not None:
setattr(response, "usage", complete_streaming_response.usage)
setattr(
response,
"usage",
getattr(complete_streaming_response, "usage"),
)
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler,
@ -8464,7 +8377,11 @@ class CustomStreamWrapper:
)
response = self.model_response_creator()
if complete_streaming_response is not None:
response.usage = complete_streaming_response.usage
setattr(
response,
"usage",
getattr(complete_streaming_response, "usage"),
)
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler,
@ -8898,7 +8815,7 @@ def trim_messages(
if len(tool_messages):
messages = messages[: -len(tool_messages)]
current_tokens = token_counter(model=model, messages=messages)
current_tokens = token_counter(model=model or "", messages=messages)
print_verbose(f"Current tokens: {current_tokens}, max tokens: {max_tokens}")
# Do nothing if current tokens under messages
@ -8909,6 +8826,7 @@ def trim_messages(
print_verbose(
f"Need to trim input messages: {messages}, current_tokens{current_tokens}, max_tokens: {max_tokens}"
)
system_message_event: Optional[dict] = None
if system_message:
system_message_event, max_tokens = process_system_message(
system_message=system_message, max_tokens=max_tokens, model=model
@ -8926,7 +8844,7 @@ def trim_messages(
)
# Add system message to the beginning of the final messages
if system_message:
if system_message_event:
final_messages = [system_message_event] + final_messages
if len(tool_messages) > 0:
@ -9214,6 +9132,8 @@ def is_cached_message(message: AllMessageValues) -> bool:
Follows the anthropic format {"cache_control": {"type": "ephemeral"}}
"""
if "content" not in message:
return False
if message["content"] is None or isinstance(message["content"], str):
return False
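An example of the message shape is_cached_message checks for, following the anthropic format from the docstring above (content is a list of blocks, one carrying cache_control):

cached_message = {
    "role": "user",
    "content": [
        {
            "type": "text",
            "text": "Here is the full contract text ...",
            "cache_control": {"type": "ephemeral"},
        }
    ],
}
# is_cached_message(cached_message) should return True; plain string content returns False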

View file

@ -1,6 +1,7 @@
{
"ignore": [],
"exclude": ["**/node_modules", "**/__pycache__", "litellm/tests", "litellm/main.py", "litellm/utils.py", "litellm/types/utils.py"],
"reportMissingImports": false
"exclude": ["**/node_modules", "**/__pycache__", "litellm/types/utils.py"],
"reportMissingImports": false,
"reportPrivateImportUsage": false
}

View file

@ -14,7 +14,7 @@ from typing import Optional
import httpx
from litellm.integrations.SlackAlerting.types import AlertType
from litellm.types.integrations.slack_alerting import AlertType
# import logging
# logging.basicConfig(level=logging.DEBUG)

View file

@ -1188,13 +1188,36 @@ def test_completion_cost_anthropic_prompt_caching():
system_fingerprint=None,
usage=Usage(
completion_tokens=10,
prompt_tokens=14,
total_tokens=24,
prompt_tokens=114,
total_tokens=124,
prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
cache_creation_input_tokens=100,
cache_read_input_tokens=0,
),
)
cost_1 = completion_cost(model=model, completion_response=response_1)
_model_info = litellm.get_model_info(
model="claude-3-5-sonnet-20240620", custom_llm_provider="anthropic"
)
expected_cost = (
(
response_1.usage.prompt_tokens
- response_1.usage.prompt_tokens_details.cached_tokens
)
* _model_info["input_cost_per_token"]
+ response_1.usage.prompt_tokens_details.cached_tokens
* _model_info["cache_read_input_token_cost"]
+ response_1.usage.cache_creation_input_tokens
* _model_info["cache_creation_input_token_cost"]
+ response_1.usage.completion_tokens * _model_info["output_cost_per_token"]
) # Cost of processing (non-cache hit + cache hit) + Cost of cache-writing (cache writing)
assert round(expected_cost, 5) == round(cost_1, 5)
print(f"expected_cost: {expected_cost}, cost_1: {cost_1}")
## READ FROM CACHE ## (LESS EXPENSIVE)
response_2 = ModelResponse(
id="chatcmpl-3f427194-0840-4d08-b571-56bfe38a5424",
@ -1216,14 +1239,14 @@ def test_completion_cost_anthropic_prompt_caching():
system_fingerprint=None,
usage=Usage(
completion_tokens=10,
prompt_tokens=14,
total_tokens=24,
prompt_tokens=114,
total_tokens=134,
prompt_tokens_details=PromptTokensDetails(cached_tokens=100),
cache_creation_input_tokens=0,
cache_read_input_tokens=100,
),
)
cost_1 = completion_cost(model=model, completion_response=response_1)
cost_2 = completion_cost(model=model, completion_response=response_2)
assert cost_1 > cost_2
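To make the expected-cost formula concrete, a worked example with illustrative claude-3-5-sonnet prices (assumed here: $3/M input, $15/M output, $3.75/M cache write, $0.30/M cache read; the test reads the real values from litellm.get_model_info):

input_cost, output_cost = 3e-06, 15e-06
cache_write_cost, cache_read_cost = 3.75e-06, 0.3e-06

# cache-write response: 114 prompt tokens (0 cached), 100 cache-creation tokens, 10 completion tokens
cost_write = (114 - 0) * input_cost + 0 * cache_read_cost + 100 * cache_write_cost + 10 * output_cost
# = 0.000342 + 0 + 0.000375 + 0.00015 = 0.000867

# cache-read response: 114 prompt tokens (100 cached), 0 cache-creation tokens, 10 completion tokens
cost_read = (114 - 100) * input_cost + 100 * cache_read_cost + 0 * cache_write_cost + 10 * output_cost
# = 0.000042 + 0.00003 + 0 + 0.00015 = 0.000222, cheaper, matching `assert cost_1 > cost_2`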

View file

@ -10,12 +10,40 @@ import litellm
import pytest
def _usage_format_tests(usage: litellm.Usage):
"""
OpenAI prompt caching
- prompt_tokens = sum of non-cache hit tokens + cache-hit tokens
- total_tokens = prompt_tokens + completion_tokens
Example
```
"usage": {
"prompt_tokens": 2006,
"completion_tokens": 300,
"total_tokens": 2306,
"prompt_tokens_details": {
"cached_tokens": 1920
},
"completion_tokens_details": {
"reasoning_tokens": 0
}
# ANTHROPIC_ONLY #
"cache_creation_input_tokens": 0
}
```
"""
assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens
assert usage.prompt_tokens > usage.prompt_tokens_details.cached_tokens
@pytest.mark.parametrize(
"model",
[
"anthropic/claude-3-5-sonnet-20240620",
"openai/gpt-4o",
"deepseek/deepseek-chat",
# "openai/gpt-4o",
# "deepseek/deepseek-chat",
],
)
def test_prompt_caching_model(model):
@ -66,9 +94,13 @@ def test_prompt_caching_model(model):
max_tokens=10,
)
_usage_format_tests(response.usage)
print("response=", response)
print("response.usage=", response.usage)
_usage_format_tests(response.usage)
assert "prompt_tokens_details" in response.usage
assert response.usage.prompt_tokens_details.cached_tokens > 0