diff --git a/.circleci/config.yml b/.circleci/config.yml index 6fa8775ba..748bf14f7 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -315,10 +315,10 @@ jobs: python -m pip install --upgrade pip pip install ruff pip install pylint + pip install pyright pip install . - run: python -c "from litellm import *" || (echo '🚨 import failed, this means you introduced unprotected imports! 🚨'; exit 1) - run: ruff check ./litellm - build_and_test: machine: diff --git a/enterprise/enterprise_callbacks/generic_api_callback.py b/enterprise/enterprise_callbacks/generic_api_callback.py index 8c868e328..0b6487a86 100644 --- a/enterprise/enterprise_callbacks/generic_api_callback.py +++ b/enterprise/enterprise_callbacks/generic_api_callback.py @@ -8,7 +8,7 @@ import requests from litellm.proxy._types import UserAPIKeyAuth from litellm.caching import DualCache -from typing import Literal, Union +from typing import Literal, Union, Optional import traceback @@ -26,9 +26,9 @@ from litellm._logging import print_verbose, verbose_logger class GenericAPILogger: # Class variables or attributes - def __init__(self, endpoint=None, headers=None): + def __init__(self, endpoint: Optional[str] = None, headers: Optional[dict] = None): try: - if endpoint == None: + if endpoint is None: # check env for "GENERIC_LOGGER_ENDPOINT" if os.getenv("GENERIC_LOGGER_ENDPOINT"): # Do something with the endpoint @@ -36,9 +36,15 @@ class GenericAPILogger: else: # Handle the case when the endpoint is not found in the environment variables raise ValueError( - f"endpoint not set for GenericAPILogger, GENERIC_LOGGER_ENDPOINT not found in environment variables" + "endpoint not set for GenericAPILogger, GENERIC_LOGGER_ENDPOINT not found in environment variables" ) headers = headers or litellm.generic_logger_headers + + if endpoint is None: + raise ValueError("endpoint not set for GenericAPILogger") + if headers is None: + raise ValueError("headers not set for GenericAPILogger") + self.endpoint = endpoint self.headers = headers diff --git a/enterprise/enterprise_hooks/aporia_ai.py b/enterprise/enterprise_hooks/aporia_ai.py index 2121f105d..9da4b891b 100644 --- a/enterprise/enterprise_hooks/aporia_ai.py +++ b/enterprise/enterprise_hooks/aporia_ai.py @@ -48,8 +48,6 @@ class AporiaGuardrail(CustomGuardrail): ) self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"] self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"] - self.event_hook: GuardrailEventHooks - super().__init__(**kwargs) #### CALL HOOKS - proxy only #### diff --git a/enterprise/enterprise_hooks/blocked_user_list.py b/enterprise/enterprise_hooks/blocked_user_list.py index 9bda140ba..0bcdbce0c 100644 --- a/enterprise/enterprise_hooks/blocked_user_list.py +++ b/enterprise/enterprise_hooks/blocked_user_list.py @@ -84,7 +84,7 @@ class _ENTERPRISE_BlockedUserList(CustomLogger): ) cache_key = f"litellm:end_user_id:{user}" - end_user_cache_obj: LiteLLM_EndUserTable = cache.get_cache( + end_user_cache_obj: Optional[LiteLLM_EndUserTable] = cache.get_cache( # type: ignore key=cache_key ) if end_user_cache_obj is None and self.prisma_client is not None: diff --git a/enterprise/enterprise_hooks/google_text_moderation.py b/enterprise/enterprise_hooks/google_text_moderation.py index c01d9f2d8..918e59f46 100644 --- a/enterprise/enterprise_hooks/google_text_moderation.py +++ b/enterprise/enterprise_hooks/google_text_moderation.py @@ -48,7 +48,7 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger): # Class variables or attributes def __init__(self): try: - from 
google.cloud import language_v1 + from google.cloud import language_v1 # type: ignore except Exception: raise Exception( "Missing google.cloud package. Run `pip install --upgrade google-cloud-language`" @@ -57,8 +57,8 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger): # Instantiates a client self.client = language_v1.LanguageServiceClient() self.moderate_text_request = language_v1.ModerateTextRequest - self.language_document = language_v1.types.Document - self.document_type = language_v1.types.Document.Type.PLAIN_TEXT + self.language_document = language_v1.types.Document # type: ignore + self.document_type = language_v1.types.Document.Type.PLAIN_TEXT # type: ignore default_confidence_threshold = ( litellm.google_moderation_confidence_threshold or 0.8 diff --git a/enterprise/enterprise_hooks/llama_guard.py b/enterprise/enterprise_hooks/llama_guard.py index 4005cd5ad..e87bb45ca 100644 --- a/enterprise/enterprise_hooks/llama_guard.py +++ b/enterprise/enterprise_hooks/llama_guard.py @@ -8,6 +8,7 @@ # Thank you users! We ❤️ you! - Krrish & Ishaan import sys, os +from collections.abc import Iterable sys.path.insert( 0, os.path.abspath("../..") @@ -19,11 +20,12 @@ from litellm.proxy._types import UserAPIKeyAuth from litellm.integrations.custom_logger import CustomLogger from fastapi import HTTPException from litellm._logging import verbose_proxy_logger -from litellm.utils import ( +from litellm.types.utils import ( ModelResponse, EmbeddingResponse, ImageResponse, StreamingChoices, + Choices, ) from datetime import datetime import aiohttp, asyncio @@ -34,7 +36,10 @@ litellm.set_verbose = True class _ENTERPRISE_LlamaGuard(CustomLogger): # Class variables or attributes def __init__(self, model_name: Optional[str] = None): - self.model = model_name or litellm.llamaguard_model_name + _model = model_name or litellm.llamaguard_model_name + if _model is None: + raise ValueError("model_name not set for LlamaGuard") + self.model = _model file_path = litellm.llamaguard_unsafe_content_categories data = None @@ -124,7 +129,13 @@ class _ENTERPRISE_LlamaGuard(CustomLogger): hf_model_name="meta-llama/LlamaGuard-7b", ) - if "unsafe" in response.choices[0].message.content: + if ( + isinstance(response, ModelResponse) + and isinstance(response.choices[0], Choices) + and response.choices[0].message.content is not None + and isinstance(response.choices[0].message.content, Iterable) + and "unsafe" in response.choices[0].message.content + ): raise HTTPException( status_code=400, detail={"error": "Violated content safety policy"} ) diff --git a/enterprise/enterprise_hooks/llm_guard.py b/enterprise/enterprise_hooks/llm_guard.py index 4c919b111..b8c11ba0f 100644 --- a/enterprise/enterprise_hooks/llm_guard.py +++ b/enterprise/enterprise_hooks/llm_guard.py @@ -8,7 +8,11 @@ ## This provides an LLM Guard Integration for content moderation on the proxy from typing import Optional, Literal, Union -import litellm, traceback, sys, uuid, os +import litellm +import traceback +import sys +import uuid +import os from litellm.caching import DualCache from litellm.proxy._types import UserAPIKeyAuth from litellm.integrations.custom_logger import CustomLogger @@ -21,8 +25,10 @@ from litellm.utils import ( StreamingChoices, ) from datetime import datetime -import aiohttp, asyncio +import aiohttp +import asyncio from litellm.utils import get_formatted_prompt +from litellm.secret_managers.main import get_secret_str litellm.set_verbose = True @@ -38,7 +44,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger): self.llm_guard_mode = 
litellm.llm_guard_mode if mock_testing == True: # for testing purposes only return - self.llm_guard_api_base = litellm.get_secret("LLM_GUARD_API_BASE", None) + self.llm_guard_api_base = get_secret_str("LLM_GUARD_API_BASE", None) if self.llm_guard_api_base is None: raise Exception("Missing `LLM_GUARD_API_BASE` from environment") elif not self.llm_guard_api_base.endswith("/"): diff --git a/enterprise/enterprise_hooks/openai_moderation.py b/enterprise/enterprise_hooks/openai_moderation.py index 5fcd8dba3..a6806ae8a 100644 --- a/enterprise/enterprise_hooks/openai_moderation.py +++ b/enterprise/enterprise_hooks/openai_moderation.py @@ -51,8 +51,8 @@ class _ENTERPRISE_OpenAI_Moderation(CustomLogger): "audio_transcription", ], ): + text = "" if "messages" in data and isinstance(data["messages"], list): - text = "" for m in data["messages"]: # assume messages is a list if "content" in m and isinstance(m["content"], str): text += m["content"] @@ -67,7 +67,7 @@ class _ENTERPRISE_OpenAI_Moderation(CustomLogger): ) verbose_proxy_logger.debug("Moderation response: %s", moderation_response) - if moderation_response.results[0].flagged == True: + if moderation_response.results[0].flagged is True: raise HTTPException( status_code=403, detail={"error": "Violated content safety policy"} ) diff --git a/enterprise/utils.py b/enterprise/utils.py index a0b79a61b..f0af1d676 100644 --- a/enterprise/utils.py +++ b/enterprise/utils.py @@ -6,7 +6,9 @@ import collections from datetime import datetime -async def get_spend_by_tags(start_date=None, end_date=None, prisma_client=None): +async def get_spend_by_tags( + prisma_client: PrismaClient, start_date=None, end_date=None +): response = await prisma_client.db.query_raw( """ SELECT diff --git a/litellm/_redis.py b/litellm/_redis.py index 0c8c3be68..289a7d4ae 100644 --- a/litellm/_redis.py +++ b/litellm/_redis.py @@ -191,7 +191,7 @@ def init_redis_cluster(redis_kwargs) -> redis.RedisCluster: new_startup_nodes.append(ClusterNode(**item)) redis_kwargs.pop("startup_nodes") - return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs) + return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs) # type: ignore def _init_redis_sentinel(redis_kwargs) -> redis.Redis: diff --git a/litellm/assistants/utils.py b/litellm/assistants/utils.py index ca5a1293d..f8fc6ee0a 100644 --- a/litellm/assistants/utils.py +++ b/litellm/assistants/utils.py @@ -1,5 +1,8 @@ -import litellm from typing import Optional, Union + +import litellm + +from ..exceptions import UnsupportedParamsError from ..types.llms.openai import * diff --git a/litellm/caching.py b/litellm/caching.py index f067bd03b..68e978e4e 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -10,6 +10,7 @@ import ast import asyncio import hashlib +import inspect import io import json import logging @@ -245,7 +246,8 @@ class RedisCache(BaseCache): self.redis_flush_size = redis_flush_size self.redis_version = "Unknown" try: - self.redis_version = self.redis_client.info()["redis_version"] + if not inspect.iscoroutinefunction(self.redis_client): + self.redis_version = self.redis_client.info()["redis_version"] # type: ignore except Exception: pass @@ -266,7 +268,8 @@ class RedisCache(BaseCache): ### SYNC HEALTH PING ### try: - self.redis_client.ping() + if hasattr(self.redis_client, "ping"): + self.redis_client.ping() # type: ignore except Exception as e: verbose_logger.error( "Error connecting to Sync Redis client", extra={"error": str(e)} @@ -308,7 +311,7 @@ class RedisCache(BaseCache): 
_redis_client = self.redis_client start_time = time.time() try: - result = _redis_client.incr(name=key, amount=value) + result: int = _redis_client.incr(name=key, amount=value) # type: ignore if ttl is not None: # check if key already has ttl, if not -> set ttl @@ -561,7 +564,7 @@ class RedisCache(BaseCache): f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}" ) try: - await redis_client.sadd(key, *value) + await redis_client.sadd(key, *value) # type: ignore if ttl is not None: _td = timedelta(seconds=ttl) await redis_client.expire(key, _td) @@ -712,7 +715,7 @@ class RedisCache(BaseCache): for cache_key in key_list: cache_key = self.check_and_fix_namespace(key=cache_key) _keys.append(cache_key) - results = self.redis_client.mget(keys=_keys) + results: List = self.redis_client.mget(keys=_keys) # type: ignore # Associate the results back with their keys. # 'results' is a list of values corresponding to the order of keys in 'key_list'. @@ -842,7 +845,7 @@ class RedisCache(BaseCache): print_verbose("Pinging Sync Redis Cache") start_time = time.time() try: - response = self.redis_client.ping() + response: bool = self.redis_client.ping() # type: ignore print_verbose(f"Redis Cache PING: {response}") ## LOGGING ## end_time = time.time() @@ -911,8 +914,8 @@ class RedisCache(BaseCache): async with _redis_client as redis_client: await redis_client.delete(*keys) - def client_list(self): - client_list = self.redis_client.client_list() + def client_list(self) -> List: + client_list: List = self.redis_client.client_list() # type: ignore return client_list def info(self): diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py index ac7649e3c..d399fd196 100644 --- a/litellm/cost_calculator.py +++ b/litellm/cost_calculator.py @@ -39,8 +39,8 @@ from litellm.llms.fireworks_ai.cost_calculator import ( ) from litellm.llms.OpenAI.cost_calculation import cost_per_token as openai_cost_per_token from litellm.llms.together_ai.cost_calculator import get_model_params_and_category -from litellm.rerank_api.types import RerankResponse from litellm.types.llms.openai import HttpxBinaryResponseContent +from litellm.types.rerank import RerankResponse from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS from litellm.types.utils import PassthroughCallTypes, Usage from litellm.utils import ( diff --git a/litellm/integrations/SlackAlerting/slack_alerting.py b/litellm/integrations/SlackAlerting/slack_alerting.py index f9de9adcd..b276d37d7 100644 --- a/litellm/integrations/SlackAlerting/slack_alerting.py +++ b/litellm/integrations/SlackAlerting/slack_alerting.py @@ -39,11 +39,11 @@ from litellm.proxy._types import ( VirtualKeyEvent, WebhookEvent, ) +from litellm.types.integrations.slack_alerting import * from litellm.types.router import LiteLLM_Params from ..email_templates.templates import * from .batching_handler import send_to_webhook, squash_payloads -from .types import * from .utils import _add_langfuse_trace_id_to_alert, process_slack_alerting_variables diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index f708e641e..d330f4f17 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -172,14 +172,14 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac "moderation", "audio_transcription", ], - ): + ) -> Any: pass async def async_post_call_streaming_hook( self, user_api_key_dict: UserAPIKeyAuth, response: str, - ): + ) -> Any: pass #### SINGLE-USE #### - 
https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function diff --git a/litellm/integrations/email_alerting.py b/litellm/integrations/email_alerting.py index f0ad43535..c626c7efc 100644 --- a/litellm/integrations/email_alerting.py +++ b/litellm/integrations/email_alerting.py @@ -42,8 +42,10 @@ async def get_all_team_member_emails(team_id: Optional[str] = None) -> list: ) _team_member_user_ids: List[str] = [] for member in _team_members: - if member and isinstance(member, dict) and member.get("user_id") is not None: - _team_member_user_ids.append(member.get("user_id")) + if member and isinstance(member, dict): + _user_id = member.get("user_id") + if _user_id and isinstance(_user_id, str): + _team_member_user_ids.append(_user_id) sql_query = """ SELECT user_email diff --git a/litellm/integrations/lunary.py b/litellm/integrations/lunary.py index 59ee1557b..8eb8eef26 100644 --- a/litellm/integrations/lunary.py +++ b/litellm/integrations/lunary.py @@ -149,7 +149,7 @@ class LunaryLogger: else: error_obj = None - self.lunary_client.track_event( + self.lunary_client.track_event( # type: ignore type, "start", run_id, @@ -164,7 +164,7 @@ class LunaryLogger: params=extra, ) - self.lunary_client.track_event( + self.lunary_client.track_event( # type: ignore type, event, run_id, diff --git a/litellm/integrations/openmeter.py b/litellm/integrations/openmeter.py index 19460e001..b1621afc7 100644 --- a/litellm/integrations/openmeter.py +++ b/litellm/integrations/openmeter.py @@ -100,16 +100,14 @@ class OpenMeterLogger(CustomLogger): } try: - response = self.sync_http_handler.post( + self.sync_http_handler.post( url=_url, data=json.dumps(_data), headers=_headers, ) - - response.raise_for_status() + except httpx.HTTPStatusError as e: + raise Exception(f"OpenMeter logging error: {e.response.text}") except Exception as e: - if hasattr(response, "text"): - litellm.print_verbose(f"\nError Message: {response.text}") raise e async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): @@ -128,18 +126,12 @@ class OpenMeterLogger(CustomLogger): } try: - response = await self.async_http_handler.post( + await self.async_http_handler.post( url=_url, data=json.dumps(_data), headers=_headers, ) - - response.raise_for_status() except httpx.HTTPStatusError as e: - verbose_logger.error( - "Failed OpenMeter logging - {}".format(e.response.text) - ) - raise e + raise Exception(f"OpenMeter logging error: {e.response.text}") except Exception as e: - verbose_logger.error("Failed OpenMeter logging - {}".format(str(e))) raise e diff --git a/litellm/integrations/s3.py b/litellm/integrations/s3.py index d915100b0..1f82406e1 100644 --- a/litellm/integrations/s3.py +++ b/litellm/integrations/s3.py @@ -146,14 +146,14 @@ class S3Logger: import json - payload = json.dumps(payload) + payload_str = json.dumps(payload) - print_verbose(f"\ns3 Logger - Logging payload = {payload}") + print_verbose(f"\ns3 Logger - Logging payload = {payload_str}") response = self.s3_client.put_object( Bucket=self.bucket_name, Key=s3_object_key, - Body=payload, + Body=payload_str, ContentType="application/json", ContentLanguage="en", ContentDisposition=f'inline; filename="{s3_object_download_filename}"', diff --git a/litellm/integrations/traceloop.py b/litellm/integrations/traceloop.py index d2168caab..06ba4b7f7 100644 --- a/litellm/integrations/traceloop.py +++ b/litellm/integrations/traceloop.py @@ -1,6 +1,7 @@ import traceback -from litellm._logging import verbose_logger + import litellm +from 
litellm._logging import verbose_logger class TraceloopLogger: @@ -11,14 +12,15 @@ class TraceloopLogger: def __init__(self): try: - from traceloop.sdk.tracing.tracing import TracerWrapper + from opentelemetry.sdk.trace.export import ConsoleSpanExporter from traceloop.sdk import Traceloop from traceloop.sdk.instruments import Instruments - from opentelemetry.sdk.trace.export import ConsoleSpanExporter + from traceloop.sdk.tracing.tracing import TracerWrapper except ModuleNotFoundError as e: verbose_logger.error( f"Traceloop not installed, try running 'pip install traceloop-sdk' to fix this error: {e}\n{traceback.format_exc()}" ) + raise e Traceloop.init( app_name="Litellm-Server", @@ -38,8 +40,8 @@ class TraceloopLogger: status_message=None, ): from opentelemetry import trace - from opentelemetry.trace import SpanKind, Status, StatusCode from opentelemetry.semconv.ai import SpanAttributes + from opentelemetry.trace import SpanKind, Status, StatusCode try: print_verbose( @@ -94,7 +96,7 @@ class TraceloopLogger: ) if "temperature" in optional_params: span.set_attribute( - SpanAttributes.LLM_REQUEST_TEMPERATURE, + SpanAttributes.LLM_REQUEST_TEMPERATURE, # type: ignore kwargs.get("temperature"), ) diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 3673e6b67..b2b3f5392 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -32,8 +32,8 @@ from litellm.litellm_core_utils.redact_messages import ( redact_message_input_output_from_logging, ) from litellm.proxy._types import CommonProxyErrors -from litellm.rerank_api.types import RerankResponse from litellm.types.llms.openai import HttpxBinaryResponseContent +from litellm.types.rerank import RerankResponse from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS from litellm.types.utils import ( CallTypes, diff --git a/litellm/llms/AzureOpenAI/audio_transcriptions.py b/litellm/llms/AzureOpenAI/audio_transcriptions.py index cecdfdc21..efe183b9b 100644 --- a/litellm/llms/AzureOpenAI/audio_transcriptions.py +++ b/litellm/llms/AzureOpenAI/audio_transcriptions.py @@ -1,5 +1,5 @@ import uuid -from typing import Optional, Union +from typing import Any, Optional, Union import httpx from openai import AsyncAzureOpenAI, AzureOpenAI @@ -24,6 +24,7 @@ class AzureAudioTranscription(AzureChatCompletion): model: str, audio_file: FileTypes, optional_params: dict, + logging_obj: Any, model_response: TranscriptionResponse, timeout: float, max_retries: int, @@ -32,9 +33,8 @@ class AzureAudioTranscription(AzureChatCompletion): api_version: Optional[str] = None, client=None, azure_ad_token: Optional[str] = None, - logging_obj=None, atranscription: bool = False, - ): + ) -> TranscriptionResponse: data = {"model": model, "file": audio_file, **optional_params} # init AzureOpenAI Client @@ -59,7 +59,7 @@ class AzureAudioTranscription(AzureChatCompletion): azure_client_params["max_retries"] = max_retries if atranscription is True: - return self.async_audio_transcriptions( + return self.async_audio_transcriptions( # type: ignore audio_file=audio_file, data=data, model_response=model_response, @@ -105,7 +105,7 @@ class AzureAudioTranscription(AzureChatCompletion): original_response=stringified_response, ) hidden_params = {"model": "whisper-1", "custom_llm_provider": "azure"} - final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, 
response_type="audio_transcription") # type: ignore + final_response: TranscriptionResponse = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore return final_response async def async_audio_transcriptions( @@ -114,12 +114,12 @@ class AzureAudioTranscription(AzureChatCompletion): data: dict, model_response: TranscriptionResponse, timeout: float, + azure_client_params: dict, + logging_obj: Any, api_key: Optional[str] = None, api_base: Optional[str] = None, client=None, - azure_client_params=None, max_retries=None, - logging_obj=None, ): response = None try: diff --git a/litellm/llms/AzureOpenAI/azure.py b/litellm/llms/AzureOpenAI/azure.py index 42c9f48f1..8a89970dc 100644 --- a/litellm/llms/AzureOpenAI/azure.py +++ b/litellm/llms/AzureOpenAI/azure.py @@ -1083,7 +1083,7 @@ class AzureChatCompletion(BaseLLM): azure_ad_token: Optional[str] = None, client=None, aembedding=None, - ): + ) -> litellm.EmbeddingResponse: super().embedding() if self._client_session is None: self._client_session = self.create_client_session() @@ -1128,7 +1128,7 @@ class AzureChatCompletion(BaseLLM): ) if aembedding is True: - response = self.aembedding( + return self.aembedding( # type: ignore data=data, input=input, logging_obj=logging_obj, @@ -1138,7 +1138,6 @@ class AzureChatCompletion(BaseLLM): timeout=timeout, client=client, ) - return response if client is None: azure_client = AzureOpenAI(**azure_client_params) # type: ignore else: @@ -1418,7 +1417,7 @@ class AzureChatCompletion(BaseLLM): logging_obj: LiteLLMLoggingObj, client=None, timeout=None, - ): + ) -> litellm.ImageResponse: response: Optional[dict] = None try: # response = await azure_client.images.generate(**data, timeout=timeout) @@ -1460,7 +1459,7 @@ class AzureChatCompletion(BaseLLM): additional_args={"complete_input_dict": data}, original_response=stringified_response, ) - return convert_to_model_response_object( + return convert_to_model_response_object( # type: ignore response_object=stringified_response, model_response_object=model_response, response_type="image_generation", @@ -1489,7 +1488,7 @@ class AzureChatCompletion(BaseLLM): azure_ad_token: Optional[str] = None, client=None, aimg_generation=None, - ): + ) -> litellm.ImageResponse: try: if model and len(model) > 0: model = model @@ -1531,8 +1530,7 @@ class AzureChatCompletion(BaseLLM): azure_client_params["azure_ad_token"] = azure_ad_token if aimg_generation is True: - response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout) # type: ignore - return response + return self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout) # type: ignore img_gen_api_base = self.create_azure_base_url( azure_client_params=azure_client_params, model=data.get("model", "") @@ -1742,9 +1740,9 @@ class AzureChatCompletion(BaseLLM): async def ahealth_check( self, model: Optional[str], - api_key: str, + api_key: Optional[str], api_base: str, - api_version: str, + api_version: Optional[str], timeout: float, mode: str, messages: Optional[list] = None, diff --git a/litellm/llms/OpenAI/audio_transcriptions.py b/litellm/llms/OpenAI/audio_transcriptions.py index cfa0b0b1a..d4523754c 100644 --- 
a/litellm/llms/OpenAI/audio_transcriptions.py +++ b/litellm/llms/OpenAI/audio_transcriptions.py @@ -77,15 +77,15 @@ class OpenAIAudioTranscription(OpenAIChatCompletion): model_response: TranscriptionResponse, timeout: float, max_retries: int, + logging_obj: LiteLLMLoggingObj, api_key: Optional[str], api_base: Optional[str], client=None, - logging_obj=None, atranscription: bool = False, - ): + ) -> TranscriptionResponse: data = {"model": model, "file": audio_file, **optional_params} if atranscription is True: - return self.async_audio_transcriptions( + return self.async_audio_transcriptions( # type: ignore audio_file=audio_file, data=data, model_response=model_response, @@ -97,7 +97,7 @@ class OpenAIAudioTranscription(OpenAIChatCompletion): logging_obj=logging_obj, ) - openai_client = self._get_openai_client( + openai_client: OpenAI = self._get_openai_client( # type: ignore is_async=False, api_key=api_key, api_base=api_base, @@ -123,7 +123,7 @@ class OpenAIAudioTranscription(OpenAIChatCompletion): original_response=stringified_response, ) hidden_params = {"model": "whisper-1", "custom_llm_provider": "openai"} - final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore + final_response: TranscriptionResponse = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore return final_response async def async_audio_transcriptions( @@ -139,7 +139,7 @@ class OpenAIAudioTranscription(OpenAIChatCompletion): max_retries=None, ): try: - openai_aclient = self._get_openai_client( + openai_aclient: AsyncOpenAI = self._get_openai_client( # type: ignore is_async=True, api_key=api_key, api_base=api_base, diff --git a/litellm/llms/OpenAI/openai.py b/litellm/llms/OpenAI/openai.py index 86b290ab6..3c60ac06a 100644 --- a/litellm/llms/OpenAI/openai.py +++ b/litellm/llms/OpenAI/openai.py @@ -1167,7 +1167,7 @@ class OpenAIChatCompletion(BaseLLM): api_base: Optional[str] = None, client=None, aembedding=None, - ): + ) -> litellm.EmbeddingResponse: super().embedding() try: model = model @@ -1183,7 +1183,7 @@ class OpenAIChatCompletion(BaseLLM): ) if aembedding is True: - async_response = self.aembedding( + return self.aembedding( # type: ignore data=data, input=input, logging_obj=logging_obj, @@ -1194,7 +1194,6 @@ class OpenAIChatCompletion(BaseLLM): client=client, max_retries=max_retries, ) - return async_response openai_client: OpenAI = self._get_openai_client( # type: ignore is_async=False, @@ -1294,7 +1293,7 @@ class OpenAIChatCompletion(BaseLLM): model_response: Optional[litellm.utils.ImageResponse] = None, client=None, aimg_generation=None, - ): + ) -> litellm.ImageResponse: data = {} try: model = model @@ -1304,8 +1303,7 @@ class OpenAIChatCompletion(BaseLLM): raise OpenAIError(status_code=422, message="max retries must be an int") if aimg_generation is True: - response = self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore - return response + return self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore openai_client = 
self._get_openai_client( is_async=False, @@ -1449,7 +1447,7 @@ class OpenAIChatCompletion(BaseLLM): async def ahealth_check( self, model: Optional[str], - api_key: str, + api_key: Optional[str], timeout: float, mode: str, messages: Optional[list] = None, diff --git a/litellm/llms/anthropic/chat/handler.py b/litellm/llms/anthropic/chat/handler.py index bd2a76f68..cd38e47e6 100644 --- a/litellm/llms/anthropic/chat/handler.py +++ b/litellm/llms/anthropic/chat/handler.py @@ -282,7 +282,6 @@ class AnthropicChatCompletion(BaseLLM): prompt_tokens = completion_response["usage"]["input_tokens"] completion_tokens = completion_response["usage"]["output_tokens"] _usage = completion_response["usage"] - total_tokens = prompt_tokens + completion_tokens cache_creation_input_tokens: int = 0 cache_read_input_tokens: int = 0 @@ -290,12 +289,15 @@ class AnthropicChatCompletion(BaseLLM): model_response.model = model if "cache_creation_input_tokens" in _usage: cache_creation_input_tokens = _usage["cache_creation_input_tokens"] + prompt_tokens += cache_creation_input_tokens if "cache_read_input_tokens" in _usage: cache_read_input_tokens = _usage["cache_read_input_tokens"] + prompt_tokens += cache_read_input_tokens prompt_tokens_details = PromptTokensDetails( cached_tokens=cache_read_input_tokens ) + total_tokens = prompt_tokens + completion_tokens usage = Usage( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, diff --git a/litellm/llms/anthropic/cost_calculation.py b/litellm/llms/anthropic/cost_calculation.py index d1742aae9..63075b82f 100644 --- a/litellm/llms/anthropic/cost_calculation.py +++ b/litellm/llms/anthropic/cost_calculation.py @@ -24,16 +24,33 @@ def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]: model_info = get_model_info(model=model, custom_llm_provider="anthropic") ## CALCULATE INPUT COST + ### Cost of processing (non-cache hit + cache hit) + Cost of cache-writing (cache writing) + prompt_cost = 0.0 + ### PROCESSING COST + non_cache_hit_tokens = usage.prompt_tokens + cache_hit_tokens = 0 + if usage.prompt_tokens_details and usage.prompt_tokens_details.cached_tokens: + cache_hit_tokens = usage.prompt_tokens_details.cached_tokens + non_cache_hit_tokens = non_cache_hit_tokens - cache_hit_tokens - prompt_cost: float = usage["prompt_tokens"] * model_info["input_cost_per_token"] - if model_info.get("cache_creation_input_token_cost") is not None: + prompt_cost = float(non_cache_hit_tokens) * model_info["input_cost_per_token"] + + _cache_read_input_token_cost = model_info.get("cache_read_input_token_cost") + if ( + _cache_read_input_token_cost is not None + and usage.prompt_tokens_details + and usage.prompt_tokens_details.cached_tokens + ): prompt_cost += ( - usage._cache_creation_input_tokens # type: ignore - * model_info["cache_creation_input_token_cost"] + float(usage.prompt_tokens_details.cached_tokens) + * _cache_read_input_token_cost ) - if model_info.get("cache_read_input_token_cost") is not None: + + ### CACHE WRITING COST + _cache_creation_input_token_cost = model_info.get("cache_creation_input_token_cost") + if _cache_creation_input_token_cost is not None: prompt_cost += ( - usage._cache_read_input_tokens * model_info["cache_read_input_token_cost"] # type: ignore + float(usage._cache_creation_input_tokens) * _cache_creation_input_token_cost ) ## CALCULATE OUTPUT COST diff --git a/litellm/llms/azure_ai/embed/handler.py b/litellm/llms/azure_ai/embed/handler.py index 2428119b7..682e7e654 100644 --- a/litellm/llms/azure_ai/embed/handler.py +++ 
b/litellm/llms/azure_ai/embed/handler.py @@ -216,7 +216,7 @@ class AzureAIEmbedding(OpenAIChatCompletion): api_base: Optional[str] = None, client=None, aembedding=None, - ): + ) -> litellm.EmbeddingResponse: """ - Separate image url from text -> route image url call to `/image/embeddings` @@ -225,7 +225,7 @@ class AzureAIEmbedding(OpenAIChatCompletion): assemble result in-order, and return """ if aembedding is True: - return self.async_embedding( + return self.async_embedding( # type: ignore model, input, timeout, diff --git a/litellm/llms/azure_ai/rerank/handler.py b/litellm/llms/azure_ai/rerank/handler.py index 9910086fc..a67c893f2 100644 --- a/litellm/llms/azure_ai/rerank/handler.py +++ b/litellm/llms/azure_ai/rerank/handler.py @@ -4,7 +4,7 @@ import httpx from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.llms.cohere.rerank import CohereRerank -from litellm.rerank_api.types import RerankResponse +from litellm.types.rerank import RerankResponse class AzureAIRerank(CohereRerank): diff --git a/litellm/llms/bedrock/image_generation.py b/litellm/llms/bedrock/image_generation.py index a6ddd38cb..65038d12e 100644 --- a/litellm/llms/bedrock/image_generation.py +++ b/litellm/llms/bedrock/image_generation.py @@ -5,7 +5,7 @@ Handles image gen calls to Bedrock's `/invoke` endpoint import copy import json import os -from typing import List +from typing import Any, List from openai.types.image import Image @@ -20,8 +20,8 @@ def image_generation( prompt: str, model_response: ImageResponse, optional_params: dict, + logging_obj: Any, timeout=None, - logging_obj=None, aimg_generation=False, ): """ diff --git a/litellm/llms/cohere/rerank.py b/litellm/llms/cohere/rerank.py index a41c3dfb2..022ffc6f9 100644 --- a/litellm/llms/cohere/rerank.py +++ b/litellm/llms/cohere/rerank.py @@ -13,7 +13,7 @@ from litellm.llms.custom_httpx.http_handler import ( _get_httpx_client, get_async_httpx_client, ) -from litellm.rerank_api.types import RerankRequest, RerankResponse +from litellm.types.rerank import RerankRequest, RerankResponse class CohereRerank(BaseLLM): diff --git a/litellm/llms/fine_tuning_apis/azure.py b/litellm/llms/fine_tuning_apis/azure.py index ff7d40ff8..3e9c335e1 100644 --- a/litellm/llms/fine_tuning_apis/azure.py +++ b/litellm/llms/fine_tuning_apis/azure.py @@ -40,7 +40,7 @@ class AzureOpenAIFineTuningAPI(BaseLLM): organization: Optional[str] = None, client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, api_version: Optional[str] = None, - ) -> Union[FineTuningJob, Union[Coroutine[Any, Any, FineTuningJob]]]: + ) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]: openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( get_azure_openai_client( api_key=api_key, diff --git a/litellm/llms/fine_tuning_apis/openai.py b/litellm/llms/fine_tuning_apis/openai.py index 6f3cd6021..7ce8c3536 100644 --- a/litellm/llms/fine_tuning_apis/openai.py +++ b/litellm/llms/fine_tuning_apis/openai.py @@ -68,7 +68,7 @@ class OpenAIFineTuningAPI(BaseLLM): max_retries: Optional[int], organization: Optional[str], client: Optional[Union[OpenAI, AsyncOpenAI]] = None, - ) -> Union[FineTuningJob, Union[Coroutine[Any, Any, FineTuningJob]]]: + ) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]: openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client( api_key=api_key, api_base=api_base, diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py index 183949ff9..973cded0b 100644 --- 
a/litellm/llms/huggingface_restapi.py +++ b/litellm/llms/huggingface_restapi.py @@ -554,7 +554,7 @@ class Huggingface(BaseLLM): else: prompt = prompt_factory(model=model, messages=messages) data = { - "inputs": prompt, + "inputs": prompt, # type: ignore "parameters": optional_params, "stream": ( # type: ignore True @@ -589,7 +589,7 @@ class Huggingface(BaseLLM): inference_params.pop("details") inference_params.pop("return_full_text") data = { - "inputs": prompt, + "inputs": prompt, # type: ignore } if task == "text-generation-inference": data["parameters"] = inference_params diff --git a/litellm/llms/maritalk.py b/litellm/llms/maritalk.py index 63060b48a..813dfa8ea 100644 --- a/litellm/llms/maritalk.py +++ b/litellm/llms/maritalk.py @@ -4,7 +4,7 @@ import time import traceback import types from enum import Enum -from typing import Callable, List, Optional +from typing import Any, Callable, List, Optional import requests # type: ignore @@ -185,8 +185,8 @@ def completion( def embedding( model: str, input: list, - api_key: Optional[str] = None, - logging_obj=None, + api_key: Optional[str], + logging_obj: Any, model_response=None, encoding=None, ): diff --git a/litellm/llms/oobabooga.py b/litellm/llms/oobabooga.py index 927844c2f..d47e56311 100644 --- a/litellm/llms/oobabooga.py +++ b/litellm/llms/oobabooga.py @@ -2,7 +2,7 @@ import json import os import time from enum import Enum -from typing import Callable, Optional +from typing import Any, Callable, Optional import requests # type: ignore @@ -124,9 +124,9 @@ def embedding( model: str, input: list, model_response: EmbeddingResponse, - api_key: Optional[str] = None, - api_base: Optional[str] = None, - logging_obj=None, + api_key: Optional[str], + api_base: Optional[str], + logging_obj: Any, optional_params=None, encoding=None, ): diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index 2a9453031..f5fdee99a 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -2714,7 +2714,7 @@ def custom_prompt( final_prompt_value: str = "", bos_token: str = "", eos_token: str = "", -): +) -> str: prompt = bos_token + initial_prompt_value bos_open = True ## a bos token is at the start of a system / human message diff --git a/litellm/llms/together_ai/rerank.py b/litellm/llms/together_ai/rerank.py index ea57c46c7..1be73af2d 100644 --- a/litellm/llms/together_ai/rerank.py +++ b/litellm/llms/together_ai/rerank.py @@ -15,7 +15,7 @@ from litellm.llms.custom_httpx.http_handler import ( _get_httpx_client, get_async_httpx_client, ) -from litellm.rerank_api.types import RerankRequest, RerankResponse +from litellm.types.rerank import RerankRequest, RerankResponse class TogetherAIRerank(BaseLLM): diff --git a/litellm/llms/triton.py b/litellm/llms/triton.py index 14a2e828b..be4179ccc 100644 --- a/litellm/llms/triton.py +++ b/litellm/llms/triton.py @@ -47,7 +47,7 @@ class TritonChatCompletion(BaseLLM): data: dict, model_response: litellm.utils.EmbeddingResponse, api_base: str, - logging_obj=None, + logging_obj: Any, api_key: Optional[str] = None, ) -> EmbeddingResponse: async_handler = AsyncHTTPHandler( @@ -93,9 +93,9 @@ class TritonChatCompletion(BaseLLM): timeout: float, api_base: str, model_response: litellm.utils.EmbeddingResponse, + logging_obj: Any, + optional_params: dict, api_key: Optional[str] = None, - logging_obj=None, - optional_params=None, client=None, aembedding: bool = False, ) -> EmbeddingResponse: @@ -122,7 +122,7 @@ class TritonChatCompletion(BaseLLM): ) if 
aembedding: - response = await self.aembedding( + response = await self.aembedding( # type: ignore data=data_for_triton, model_response=model_response, logging_obj=logging_obj, @@ -141,10 +141,10 @@ class TritonChatCompletion(BaseLLM): messages: List[dict], timeout: float, api_base: str, + logging_obj: Any, + optional_params: dict, model_response: ModelResponse, api_key: Optional[str] = None, - logging_obj=None, - optional_params=None, client=None, stream: Optional[bool] = False, acompletion: bool = False, @@ -239,11 +239,13 @@ class TritonChatCompletion(BaseLLM): else: handler = HTTPHandler() if stream: - return self._handle_stream( + return self._handle_stream( # type: ignore handler, api_base, json_data_for_triton, model, logging_obj ) else: - response = handler.post(url=api_base, data=json_data_for_triton, headers=headers) + response = handler.post( + url=api_base, data=json_data_for_triton, headers=headers + ) return self._handle_response( response, model_response, logging_obj, type_of_model=type_of_model ) @@ -261,7 +263,7 @@ class TritonChatCompletion(BaseLLM): ) -> ModelResponse: handler = AsyncHTTPHandler() if stream: - return self._ahandle_stream( + return self._ahandle_stream( # type: ignore handler, api_base, data_for_triton, model, logging_obj ) else: diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py b/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py index 0649b93ea..5fdd8e40c 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/common_utils.py @@ -2,7 +2,7 @@ from typing import List, Literal, Tuple import httpx -from litellm import supports_system_messages, supports_response_schema, verbose_logger +from litellm import supports_response_schema, supports_system_messages, verbose_logger from litellm.types.llms.vertex_ai import PartType @@ -67,6 +67,8 @@ def _get_vertex_url( vertex_location: Optional[str], vertex_api_version: Literal["v1", "v1beta1"], ) -> Tuple[str, str]: + url: Optional[str] = None + endpoint: Optional[str] = None if mode == "chat": ### SET RUNTIME ENDPOINT ### endpoint = "generateContent" @@ -88,6 +90,8 @@ def _get_vertex_url( endpoint = "predict" url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}" + if not url or not endpoint: + raise ValueError(f"Unable to get vertex url/endpoint for mode: {mode}") return url, endpoint diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini_embeddings/batch_embed_content_handler.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini_embeddings/batch_embed_content_handler.py index 1904ff7e4..314e129c2 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini_embeddings/batch_embed_content_handler.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini_embeddings/batch_embed_content_handler.py @@ -3,7 +3,7 @@ Google AI Studio /batchEmbedContents Embeddings Endpoint """ import json -from typing import List, Literal, Optional, Union +from typing import Any, List, Literal, Optional, Union import httpx @@ -31,9 +31,9 @@ class GoogleBatchEmbeddings(VertexLLM): model_response: EmbeddingResponse, custom_llm_provider: Literal["gemini", "vertex_ai"], optional_params: dict, + logging_obj: Any, api_key: Optional[str] = None, api_base: Optional[str] = None, - logging_obj=None, encoding=None, vertex_project=None, vertex_location=None, diff --git 
a/litellm/llms/vertex_ai_and_google_ai_studio/image_generation/image_generation_handler.py b/litellm/llms/vertex_ai_and_google_ai_studio/image_generation/image_generation_handler.py index e1969199b..1531464c8 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/image_generation/image_generation_handler.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/image_generation/image_generation_handler.py @@ -43,17 +43,17 @@ class VertexImageGeneration(VertexLLM): vertex_location: Optional[str], vertex_credentials: Optional[str], model_response: litellm.ImageResponse, + logging_obj: Any, model: Optional[ str ] = "imagegeneration", # vertex ai uses imagegeneration as the default model client: Optional[Any] = None, optional_params: Optional[dict] = None, timeout: Optional[int] = None, - logging_obj=None, aimg_generation=False, - ): + ) -> litellm.ImageResponse: if aimg_generation is True: - return self.aimage_generation( + return self.aimage_generation( # type: ignore prompt=prompt, vertex_project=vertex_project, vertex_location=vertex_location, @@ -138,13 +138,13 @@ class VertexImageGeneration(VertexLLM): vertex_location: Optional[str], vertex_credentials: Optional[str], model_response: litellm.ImageResponse, + logging_obj: Any, model: Optional[ str ] = "imagegeneration", # vertex ai uses imagegeneration as the default model client: Optional[AsyncHTTPHandler] = None, optional_params: Optional[dict] = None, timeout: Optional[int] = None, - logging_obj=None, ): response = None if client is None: diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/multimodal_embeddings/embedding_handler.py b/litellm/llms/vertex_ai_and_google_ai_studio/multimodal_embeddings/embedding_handler.py index 0eda7d875..d8af891b0 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/multimodal_embeddings/embedding_handler.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/multimodal_embeddings/embedding_handler.py @@ -48,7 +48,7 @@ class VertexMultimodalEmbedding(VertexLLM): aembedding=False, timeout=300, client=None, - ): + ) -> litellm.EmbeddingResponse: _auth_header, vertex_project = self._ensure_access_token( credentials=vertex_credentials, @@ -121,7 +121,7 @@ class VertexMultimodalEmbedding(VertexLLM): ) if aembedding is True: - return self.async_multimodal_embedding( + return self.async_multimodal_embedding( # type: ignore model=model, api_base=url, data=request_data, diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/text_to_speech/text_to_speech_handler.py b/litellm/llms/vertex_ai_and_google_ai_studio/text_to_speech/text_to_speech_handler.py index 08d139d2d..170c2765d 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/text_to_speech/text_to_speech_handler.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/text_to_speech/text_to_speech_handler.py @@ -62,7 +62,7 @@ class VertexTextToSpeechAPI(VertexLLM): _is_async: Optional[bool] = False, optional_params: Optional[dict] = None, kwargs: Optional[dict] = None, - ): + ) -> HttpxBinaryResponseContent: import base64 ####### Authenticate with Vertex AI ######## @@ -145,7 +145,7 @@ class VertexTextToSpeechAPI(VertexLLM): ########## End of logging ############ ####### Send the request ################### if _is_async is True: - return self.async_audio_speech( + return self.async_audio_speech( # type:ignore logging_obj=logging_obj, url=url, headers=headers, request=request ) sync_handler = _get_httpx_client() diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py 
b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py index 682b2eb47..0cde5c3b5 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_embeddings/embedding_handler.py @@ -52,9 +52,9 @@ class VertexEmbedding(VertexBase): vertex_credentials: Optional[str] = None, gemini_api_key: Optional[str] = None, extra_headers: Optional[dict] = None, - ): + ) -> litellm.EmbeddingResponse: if aembedding is True: - return self.async_embedding( + return self.async_embedding( # type: ignore model=model, input=input, logging_obj=logging_obj, diff --git a/litellm/llms/watsonx.py b/litellm/llms/watsonx.py index c01efd8ad..a8a6585cc 100644 --- a/litellm/llms/watsonx.py +++ b/litellm/llms/watsonx.py @@ -25,13 +25,8 @@ import requests # type: ignore import litellm from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler -from litellm.utils import ( - EmbeddingResponse, - ModelResponse, - Usage, - get_secret, - map_finish_reason, -) +from litellm.secret_managers.main import get_secret_str +from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason from .base import BaseLLM from .prompt_templates import factory as ptf @@ -184,7 +179,7 @@ class IBMWatsonXAIConfig: ] -def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict): +def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict) -> str: # handle anthropic prompts and amazon titan prompts if model in custom_prompt_dict: # check if the model has a registered custom prompt @@ -200,14 +195,10 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict): eos_token=model_prompt_dict.get("eos_token", ""), ) return prompt - elif provider == "ibm": - prompt = ptf.prompt_factory( - model=model, messages=messages, custom_llm_provider="watsonx" - ) elif provider == "ibm-mistralai": prompt = ptf.mistral_instruct_pt(messages=messages) else: - prompt = ptf.prompt_factory( + prompt: str = ptf.prompt_factory( # type: ignore model=model, messages=messages, custom_llm_provider="watsonx" ) return prompt @@ -327,37 +318,37 @@ class IBMWatsonXAI(BaseLLM): # Load auth variables from environment variables if url is None: url = ( - get_secret("WATSONX_API_BASE") # consistent with 'AZURE_API_BASE' - or get_secret("WATSONX_URL") - or get_secret("WX_URL") - or get_secret("WML_URL") + get_secret_str("WATSONX_API_BASE") # consistent with 'AZURE_API_BASE' + or get_secret_str("WATSONX_URL") + or get_secret_str("WX_URL") + or get_secret_str("WML_URL") ) if api_key is None: api_key = ( - get_secret("WATSONX_APIKEY") - or get_secret("WATSONX_API_KEY") - or get_secret("WX_API_KEY") + get_secret_str("WATSONX_APIKEY") + or get_secret_str("WATSONX_API_KEY") + or get_secret_str("WX_API_KEY") ) if token is None: - token = get_secret("WATSONX_TOKEN") or get_secret("WX_TOKEN") + token = get_secret_str("WATSONX_TOKEN") or get_secret_str("WX_TOKEN") if project_id is None: project_id = ( - get_secret("WATSONX_PROJECT_ID") - or get_secret("WX_PROJECT_ID") - or get_secret("PROJECT_ID") + get_secret_str("WATSONX_PROJECT_ID") + or get_secret_str("WX_PROJECT_ID") + or get_secret_str("PROJECT_ID") ) if region_name is None: region_name = ( - get_secret("WATSONX_REGION") - or get_secret("WX_REGION") - or get_secret("REGION") + get_secret_str("WATSONX_REGION") + or get_secret_str("WX_REGION") + or get_secret_str("REGION") ) if space_id is None: space_id = ( - 
get_secret("WATSONX_DEPLOYMENT_SPACE_ID") - or get_secret("WATSONX_SPACE_ID") - or get_secret("WX_SPACE_ID") - or get_secret("SPACE_ID") + get_secret_str("WATSONX_DEPLOYMENT_SPACE_ID") + or get_secret_str("WATSONX_SPACE_ID") + or get_secret_str("WX_SPACE_ID") + or get_secret_str("SPACE_ID") ) # credentials parsing @@ -446,8 +437,8 @@ class IBMWatsonXAI(BaseLLM): model_response: ModelResponse, print_verbose: Callable, encoding, - logging_obj, - optional_params=None, + logging_obj: Any, + optional_params: dict, acompletion=None, litellm_params=None, logger_fn=None, @@ -592,13 +583,13 @@ class IBMWatsonXAI(BaseLLM): model: str, input: Union[list, str], model_response: litellm.EmbeddingResponse, - api_key: Optional[str] = None, - logging_obj=None, - optional_params=None, + api_key: Optional[str], + logging_obj: Any, + optional_params: dict, encoding=None, print_verbose=None, aembedding=None, - ): + ) -> litellm.EmbeddingResponse: """ Send a text embedding request to the IBM Watsonx.ai API. """ @@ -657,7 +648,7 @@ class IBMWatsonXAI(BaseLLM): try: if aembedding is True: - return handle_aembedding(req_params) + return handle_aembedding(req_params) # type: ignore else: return handle_embedding(req_params) except WatsonXAIError as e: @@ -669,7 +660,7 @@ class IBMWatsonXAI(BaseLLM): headers = {} headers["Content-Type"] = "application/x-www-form-urlencoded" if api_key is None: - api_key = get_secret("WX_API_KEY") or get_secret("WATSONX_API_KEY") + api_key = get_secret_str("WX_API_KEY") or get_secret_str("WATSONX_API_KEY") if api_key is None: raise ValueError("API key is required") headers["Accept"] = "application/json" @@ -812,22 +803,29 @@ class RequestManager: request_params["data"] = json.dumps(request_params.pop("json", {})) method = request_params.pop("method") retries = 0 + resp: Optional[httpx.Response] = None while retries < 3: if method.upper() == "POST": resp = await self.async_handler.post(**request_params) else: resp = await self.async_handler.get(**request_params) - if resp.status_code in [429, 503, 504, 520]: + if resp is not None and resp.status_code in [429, 503, 504, 520]: # to handle rate limiting and service unavailable errors # see: ibm_watsonx_ai.foundation_models.inference.base_model_inference.BaseModelInference._send_inference_payload await asyncio.sleep(2**retries) retries += 1 else: break + if resp is None: + raise WatsonXAIError( + status_code=500, + message="No response from the server", + ) if resp.is_error: + error_reason = getattr(resp, "reason", "") raise WatsonXAIError( status_code=resp.status_code, - message=f"Error {resp.status_code} ({resp.reason}): {resp.text}", + message=f"Error {resp.status_code} ({error_reason}): {resp.text}", ) yield resp # await async_handler.close() diff --git a/litellm/main.py b/litellm/main.py index f54f70a21..87c169f4e 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -19,7 +19,8 @@ import threading import time import traceback import uuid -from concurrent.futures import ThreadPoolExecutor +from concurrent import futures +from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait from copy import deepcopy from functools import partial from typing import Any, Callable, Dict, List, Literal, Mapping, Optional, Type, Union @@ -647,7 +648,7 @@ def mock_completion( @client -def completion( +def completion( # type: ignore model: str, # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create messages: List = [], @@ -2940,16 +2941,16 @@ def completion_with_retries(*args, **kwargs): num_retries = 
kwargs.pop("num_retries", 3) retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop("retry_strategy", "constant_retry") # type: ignore original_function = kwargs.pop("original_function", completion) - if retry_strategy == "constant_retry": - retryer = tenacity.Retrying( - stop=tenacity.stop_after_attempt(num_retries), reraise=True - ) - elif retry_strategy == "exponential_backoff_retry": + if retry_strategy == "exponential_backoff_retry": retryer = tenacity.Retrying( wait=tenacity.wait_exponential(multiplier=1, max=10), stop=tenacity.stop_after_attempt(num_retries), reraise=True, ) + else: + retryer = tenacity.Retrying( + stop=tenacity.stop_after_attempt(num_retries), reraise=True + ) return retryer(original_function, *args, **kwargs) @@ -2968,16 +2969,16 @@ async def acompletion_with_retries(*args, **kwargs): num_retries = kwargs.pop("num_retries", 3) retry_strategy = kwargs.pop("retry_strategy", "constant_retry") original_function = kwargs.pop("original_function", completion) - if retry_strategy == "constant_retry": - retryer = tenacity.Retrying( - stop=tenacity.stop_after_attempt(num_retries), reraise=True - ) - elif retry_strategy == "exponential_backoff_retry": + if retry_strategy == "exponential_backoff_retry": retryer = tenacity.Retrying( wait=tenacity.wait_exponential(multiplier=1, max=10), stop=tenacity.stop_after_attempt(num_retries), reraise=True, ) + else: + retryer = tenacity.Retrying( + stop=tenacity.stop_after_attempt(num_retries), reraise=True + ) return await retryer(original_function, *args, **kwargs) @@ -3045,7 +3046,7 @@ def batch_completion( temperature=temperature, top_p=top_p, n=n, - stream=stream, + stream=stream or False, stop=stop, max_tokens=max_tokens, presence_penalty=presence_penalty, @@ -3124,7 +3125,7 @@ def batch_completion_models(*args, **kwargs): models = kwargs["models"] kwargs.pop("models") futures = {} - with concurrent.futures.ThreadPoolExecutor(max_workers=len(models)) as executor: + with ThreadPoolExecutor(max_workers=len(models)) as executor: for model in models: futures[model] = executor.submit( completion, *args, model=model, **kwargs @@ -3141,9 +3142,7 @@ def batch_completion_models(*args, **kwargs): kwargs.pop("model_list") nested_kwargs = kwargs.pop("kwargs", {}) futures = {} - with concurrent.futures.ThreadPoolExecutor( - max_workers=len(deployments) - ) as executor: + with ThreadPoolExecutor(max_workers=len(deployments)) as executor: for deployment in deployments: for key in kwargs.keys(): if ( @@ -3156,9 +3155,7 @@ def batch_completion_models(*args, **kwargs): while futures: # wait for the first returned future print_verbose("\n\n waiting for next result\n\n") - done, _ = concurrent.futures.wait( - futures.values(), return_when=concurrent.futures.FIRST_COMPLETED - ) + done, _ = wait(futures.values(), return_when=FIRST_COMPLETED) print_verbose(f"done list\n{done}") for future in done: try: @@ -3214,6 +3211,8 @@ def batch_completion_models_all_responses(*args, **kwargs): if "models" in kwargs: models = kwargs["models"] kwargs.pop("models") + else: + raise Exception("'models' param not in kwargs") responses = [] @@ -3256,6 +3255,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse: model=model, api_base=kwargs.get("api_base", None) ) + response: Optional[EmbeddingResponse] = None if ( custom_llm_provider == "openai" or custom_llm_provider == "azure" @@ -3294,12 +3294,21 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse: elif isinstance(init_response, EmbeddingResponse): ## CACHING SCENARIO 
response = init_response elif asyncio.iscoroutine(init_response): - response = await init_response + response = await init_response # type: ignore else: # Call the synchronous function using run_in_executor response = await loop.run_in_executor(None, func_with_context) - if response is not None and hasattr(response, "_hidden_params"): + if ( + response is not None + and isinstance(response, EmbeddingResponse) + and hasattr(response, "_hidden_params") + ): response._hidden_params["custom_llm_provider"] = custom_llm_provider + + if response is None: + raise ValueError( + "Unable to get Embedding Response. Please pass a valid llm_provider." + ) return response except Exception as e: custom_llm_provider = custom_llm_provider or "openai" @@ -3329,7 +3338,6 @@ def embedding( user: Optional[str] = None, custom_llm_provider=None, litellm_call_id=None, - litellm_logging_obj=None, logger_fn=None, **kwargs, ) -> EmbeddingResponse: @@ -3362,6 +3370,7 @@ def embedding( client = kwargs.pop("client", None) rpm = kwargs.pop("rpm", None) tpm = kwargs.pop("tpm", None) + litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore cooldown_time = kwargs.get("cooldown_time", None) max_parallel_requests = kwargs.pop("max_parallel_requests", None) model_info = kwargs.get("model_info", None) @@ -3491,7 +3500,7 @@ def embedding( } ) try: - response = None + response: Optional[EmbeddingResponse] = None logging: Logging = litellm_logging_obj # type: ignore logging.update_environment_variables( model=model, @@ -3691,7 +3700,7 @@ def embedding( raise ValueError( "api_base is required for triton. Please pass `api_base`" ) - response = triton_chat_completions.embedding( + response = triton_chat_completions.embedding( # type: ignore model=model, input=input, api_base=api_base, @@ -3783,6 +3792,7 @@ def embedding( timeout=timeout, aembedding=aembedding, print_verbose=print_verbose, + api_key=api_key, ) elif custom_llm_provider == "oobabooga": response = oobabooga.embedding( @@ -3793,14 +3803,16 @@ def embedding( logging_obj=logging, optional_params=optional_params, model_response=EmbeddingResponse(), + api_key=api_key, ) elif custom_llm_provider == "ollama": api_base = ( litellm.api_base or api_base - or get_secret("OLLAMA_API_BASE") + or get_secret_str("OLLAMA_API_BASE") or "http://localhost:11434" ) # type: ignore + if isinstance(input, str): input = [input] if not all(isinstance(item, str) for item in input): @@ -3881,13 +3893,13 @@ def embedding( api_key = ( api_key or litellm.api_key - or get_secret("XINFERENCE_API_KEY") + or get_secret_str("XINFERENCE_API_KEY") or "stub-xinference-key" ) # xinference does not need an api key, pass a stub key if user did not set one api_base = ( api_base or litellm.api_base - or get_secret("XINFERENCE_API_BASE") + or get_secret_str("XINFERENCE_API_BASE") or "http://127.0.0.1:9997/v1" ) response = openai_chat_completions.embedding( @@ -3911,19 +3923,20 @@ def embedding( optional_params=optional_params, model_response=EmbeddingResponse(), aembedding=aembedding, + api_key=api_key, ) elif custom_llm_provider == "azure_ai": api_base = ( api_base # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there or litellm.api_base - or get_secret("AZURE_AI_API_BASE") + or get_secret_str("AZURE_AI_API_BASE") ) # set API KEY api_key = ( api_key or litellm.api_key # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there or litellm.openai_key - or 
get_secret("AZURE_AI_API_KEY") + or get_secret_str("AZURE_AI_API_KEY") ) ## EMBEDDING CALL @@ -3944,10 +3957,14 @@ def embedding( raise ValueError(f"No valid embedding model args passed in - {args}") if response is not None and hasattr(response, "_hidden_params"): response._hidden_params["custom_llm_provider"] = custom_llm_provider + + if response is None: + args = locals() + raise ValueError(f"No valid embedding model args passed in - {args}") return response except Exception as e: ## LOGGING - logging.post_call( + litellm_logging_obj.post_call( input=input, api_key=api_key, original_response=str(e), @@ -4018,7 +4035,11 @@ async def atext_completion( else: # Call the synchronous function using run_in_executor response = await loop.run_in_executor(None, func_with_context) - if kwargs.get("stream", False) is True: # return an async generator + if ( + kwargs.get("stream", False) is True + or isinstance(response, TextCompletionStreamWrapper) + or isinstance(response, CustomStreamWrapper) + ): # return an async generator return TextCompletionStreamWrapper( completion_stream=_async_streaming( response=response, @@ -4153,9 +4174,10 @@ def text_completion( Your example of how to use this function goes here. """ if "engine" in kwargs: - if model is None: + _engine = kwargs["engine"] + if model is None and isinstance(_engine, str): # only use engine when model not passed - model = kwargs["engine"] + model = _engine kwargs.pop("engine") text_completion_response = TextCompletionResponse() @@ -4223,7 +4245,7 @@ def text_completion( def process_prompt(i, individual_prompt): decoded_prompt = tokenizer.decode(individual_prompt) all_params = {**kwargs, **optional_params} - response = text_completion( + response: TextCompletionResponse = text_completion( # type: ignore model=model, prompt=decoded_prompt, num_retries=3, # ensure this does not fail for the batch @@ -4292,6 +4314,8 @@ def text_completion( model = "text-completion-openai/" + _model optional_params.pop("custom_llm_provider", None) + if model is None: + raise ValueError("model is not set. 
Set either via 'model' or 'engine' param.") kwargs["text_completion"] = True response = completion( model=model, @@ -4302,7 +4326,11 @@ def text_completion( ) if kwargs.get("acompletion", False) is True: return response - if stream is True or kwargs.get("stream", False) is True: + if ( + stream is True + or kwargs.get("stream", False) is True + or isinstance(response, CustomStreamWrapper) + ): response = TextCompletionStreamWrapper( completion_stream=response, model=model, @@ -4310,6 +4338,8 @@ def text_completion( custom_llm_provider=custom_llm_provider, ) return response + elif isinstance(response, TextCompletionStreamWrapper): + return response transformed_logprobs = None # only supported for TGI models try: @@ -4424,7 +4454,10 @@ def moderation( ): # only supports open ai for now api_key = ( - api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY") + api_key + or litellm.api_key + or litellm.openai_key + or get_secret_str("OPENAI_API_KEY") ) openai_client = kwargs.get("client", None) @@ -4433,7 +4466,10 @@ def moderation( api_key=api_key, ) - response = openai_client.moderations.create(input=input, model=model) + if model is not None: + response = openai_client.moderations.create(input=input, model=model) + else: + response = openai_client.moderations.create(input=input) return response @@ -4441,20 +4477,30 @@ def moderation( async def amoderation( input: str, model: Optional[str] = None, api_key: Optional[str] = None, **kwargs ): + from openai import AsyncOpenAI + # only supports open ai for now api_key = ( - api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY") + api_key + or litellm.api_key + or litellm.openai_key + or get_secret_str("OPENAI_API_KEY") ) openai_client = kwargs.get("client", None) - if openai_client is None: + if openai_client is None or not isinstance(openai_client, AsyncOpenAI): # call helper to get OpenAI client # _get_openai_client maintains in-memory caching logic for OpenAI clients - openai_client = openai_chat_completions._get_openai_client( + _openai_client: AsyncOpenAI = openai_chat_completions._get_openai_client( # type: ignore is_async=True, api_key=api_key, ) - response = await openai_client.moderations.create(input=input, model=model) + else: + _openai_client = openai_client + if model is not None: + response = await openai_client.moderations.create(input=input, model=model) + else: + response = await openai_client.moderations.create(input=input) return response @@ -4497,7 +4543,7 @@ async def aimage_generation(*args, **kwargs) -> ImageResponse: init_response = ImageResponse(**init_response) response = init_response elif asyncio.iscoroutine(init_response): - response = await init_response + response = await init_response # type: ignore else: # Call the synchronous function using run_in_executor response = await loop.run_in_executor(None, func_with_context) @@ -4527,7 +4573,6 @@ def image_generation( api_key: Optional[str] = None, api_base: Optional[str] = None, api_version: Optional[str] = None, - litellm_logging_obj=None, custom_llm_provider=None, **kwargs, ) -> ImageResponse: @@ -4543,9 +4588,10 @@ def image_generation( proxy_server_request = kwargs.get("proxy_server_request", None) model_info = kwargs.get("model_info", None) metadata = kwargs.get("metadata", {}) + litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore client = kwargs.get("client", None) - model_response = litellm.utils.ImageResponse() + model_response: ImageResponse = litellm.utils.ImageResponse() if 
model is not None or custom_llm_provider is not None: model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore else: @@ -4651,25 +4697,27 @@ def image_generation( if custom_llm_provider == "azure": # azure configs - api_type = get_secret("AZURE_API_TYPE") or "azure" + api_type = get_secret_str("AZURE_API_TYPE") or "azure" - api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") + api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") api_version = ( - api_version or litellm.api_version or get_secret("AZURE_API_VERSION") + api_version + or litellm.api_version + or get_secret_str("AZURE_API_VERSION") ) api_key = ( api_key or litellm.api_key or litellm.azure_key - or get_secret("AZURE_OPENAI_API_KEY") - or get_secret("AZURE_API_KEY") + or get_secret_str("AZURE_OPENAI_API_KEY") + or get_secret_str("AZURE_API_KEY") ) - azure_ad_token = optional_params.pop("azure_ad_token", None) or get_secret( - "AZURE_AD_TOKEN" - ) + azure_ad_token = optional_params.pop( + "azure_ad_token", None + ) or get_secret_str("AZURE_AD_TOKEN") model_response = azure_chat_completions.image_generation( model=model, @@ -4714,18 +4762,18 @@ def image_generation( optional_params.pop("vertex_project", None) or optional_params.pop("vertex_ai_project", None) or litellm.vertex_project - or get_secret("VERTEXAI_PROJECT") + or get_secret_str("VERTEXAI_PROJECT") ) vertex_ai_location = ( optional_params.pop("vertex_location", None) or optional_params.pop("vertex_ai_location", None) or litellm.vertex_location - or get_secret("VERTEXAI_LOCATION") + or get_secret_str("VERTEXAI_LOCATION") ) vertex_credentials = ( optional_params.pop("vertex_credentials", None) or optional_params.pop("vertex_ai_credentials", None) - or get_secret("VERTEXAI_CREDENTIALS") + or get_secret_str("VERTEXAI_CREDENTIALS") ) model_response = vertex_image_generation.image_generation( model=model, @@ -4786,7 +4834,7 @@ async def atranscription(*args, **kwargs) -> TranscriptionResponse: elif isinstance(init_response, TranscriptionResponse): ## CACHING SCENARIO response = init_response elif asyncio.iscoroutine(init_response): - response = await init_response + response = await init_response # type: ignore else: # Call the synchronous function using run_in_executor response = await loop.run_in_executor(None, func_with_context) @@ -4820,7 +4868,6 @@ def transcription( api_base: Optional[str] = None, api_version: Optional[str] = None, max_retries: Optional[int] = None, - litellm_logging_obj: Optional[LiteLLMLoggingObj] = None, custom_llm_provider=None, **kwargs, ) -> TranscriptionResponse: @@ -4830,6 +4877,7 @@ def transcription( Allows router to load balance between them """ atranscription = kwargs.get("atranscription", False) + litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore kwargs.get("litellm_call_id", None) kwargs.get("logger_fn", None) kwargs.get("proxy_server_request", None) @@ -4869,22 +4917,17 @@ def transcription( custom_llm_provider=custom_llm_provider, drop_params=drop_params, ) - # optional_params = { - # "language": language, - # "prompt": prompt, - # "response_format": response_format, - # "temperature": None, # openai defaults this to 0 - # } + response: Optional[TranscriptionResponse] = None if custom_llm_provider == "azure": # azure configs - api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") + api_base = api_base or litellm.api_base or 
get_secret_str("AZURE_API_BASE") api_version = ( - api_version or litellm.api_version or get_secret("AZURE_API_VERSION") + api_version or litellm.api_version or get_secret_str("AZURE_API_VERSION") ) - azure_ad_token = kwargs.pop("azure_ad_token", None) or get_secret( + azure_ad_token = kwargs.pop("azure_ad_token", None) or get_secret_str( "AZURE_AD_TOKEN" ) @@ -4892,8 +4935,8 @@ def transcription( api_key or litellm.api_key or litellm.azure_key - or get_secret("AZURE_API_KEY") - ) # type: ignore + or get_secret_str("AZURE_API_KEY") + ) response = azure_audio_transcriptions.audio_transcriptions( model=model, @@ -4942,6 +4985,9 @@ def transcription( api_base=api_base, api_key=api_key, ) + + if response is None: + raise ValueError("Unmapped provider passed in. Unable to get the response.") return response @@ -5149,15 +5195,16 @@ def speech( vertex_ai_project = ( generic_optional_params.vertex_project or litellm.vertex_project - or get_secret("VERTEXAI_PROJECT") + or get_secret_str("VERTEXAI_PROJECT") ) vertex_ai_location = ( generic_optional_params.vertex_location or litellm.vertex_location - or get_secret("VERTEXAI_LOCATION") + or get_secret_str("VERTEXAI_LOCATION") ) - vertex_credentials = generic_optional_params.vertex_credentials or get_secret( - "VERTEXAI_CREDENTIALS" + vertex_credentials = ( + generic_optional_params.vertex_credentials + or get_secret_str("VERTEXAI_CREDENTIALS") ) if voice is not None and not isinstance(voice, dict): @@ -5234,20 +5281,25 @@ async def ahealth_check( if custom_llm_provider == "azure": api_key = ( model_params.get("api_key") - or get_secret("AZURE_API_KEY") - or get_secret("AZURE_OPENAI_API_KEY") + or get_secret_str("AZURE_API_KEY") + or get_secret_str("AZURE_OPENAI_API_KEY") ) - api_base = ( + api_base: Optional[str] = ( model_params.get("api_base") - or get_secret("AZURE_API_BASE") - or get_secret("AZURE_OPENAI_API_BASE") + or get_secret_str("AZURE_API_BASE") + or get_secret_str("AZURE_OPENAI_API_BASE") ) + if api_base is None: + raise ValueError( + "Azure API Base cannot be None. 
Set via 'AZURE_API_BASE' in env var or `.completion(..., api_base=..)`" + ) + api_version = ( model_params.get("api_version") - or get_secret("AZURE_API_VERSION") - or get_secret("AZURE_OPENAI_API_VERSION") + or get_secret_str("AZURE_API_VERSION") + or get_secret_str("AZURE_OPENAI_API_VERSION") ) timeout = ( @@ -5273,7 +5325,7 @@ async def ahealth_check( custom_llm_provider == "openai" or custom_llm_provider == "text-completion-openai" ): - api_key = model_params.get("api_key") or get_secret("OPENAI_API_KEY") + api_key = model_params.get("api_key") or get_secret_str("OPENAI_API_KEY") organization = model_params.get("organization") timeout = ( @@ -5282,7 +5334,7 @@ async def ahealth_check( or default_timeout ) - api_base = model_params.get("api_base") or get_secret("OPENAI_API_BASE") + api_base = model_params.get("api_base") or get_secret_str("OPENAI_API_BASE") if custom_llm_provider == "text-completion-openai": mode = "completion" @@ -5377,7 +5429,9 @@ def config_completion(**kwargs): ) -def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List] = None): +def stream_chunk_builder_text_completion( + chunks: list, messages: Optional[List] = None +) -> TextCompletionResponse: id = chunks[0]["id"] object = chunks[0]["object"] created = chunks[0]["created"] @@ -5446,7 +5500,7 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List] response["usage"]["total_tokens"] = ( response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"] ) - return response + return TextCompletionResponse(**response) def stream_chunk_builder( diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 1f279f444..c3b9bc00c 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union from pydantic import BaseModel, ConfigDict, Extra, Field, Json, model_validator from typing_extensions import Annotated, TypedDict -from litellm.integrations.SlackAlerting.types import AlertType +from litellm.types.integrations.slack_alerting import AlertType from litellm.types.router import RouterErrors, UpdateRouterConfig from litellm.types.utils import ProviderField, StandardCallbackDynamicParams diff --git a/litellm/proxy/caching_routes.py b/litellm/proxy/caching_routes.py index bad747793..6f07fcb9a 100644 --- a/litellm/proxy/caching_routes.py +++ b/litellm/proxy/caching_routes.py @@ -1,12 +1,13 @@ -from typing import Optional -from fastapi import Depends, Request, APIRouter -from fastapi import HTTPException import copy +from typing import Optional + +from fastapi import APIRouter, Depends, HTTPException, Request + import litellm from litellm._logging import verbose_proxy_logger +from litellm.caching import RedisCache from litellm.proxy.auth.user_api_key_auth import user_api_key_auth - router = APIRouter( prefix="/cache", tags=["caching"], @@ -21,9 +22,9 @@ async def cache_ping(): """ Endpoint for checking if cache can be pinged """ + litellm_cache_params = {} + specific_cache_params = {} try: - litellm_cache_params = {} - specific_cache_params = {} if litellm.cache is None: raise HTTPException( @@ -135,7 +136,9 @@ async def cache_redis_info(): raise HTTPException( status_code=503, detail="Cache not initialized. 
litellm.cache is None" ) - if litellm.cache.type == "redis": + if litellm.cache.type == "redis" and isinstance( + litellm.cache.cache, RedisCache + ): client_list = litellm.cache.cache.client_list() redis_info = litellm.cache.cache.info() num_clients = len(client_list) @@ -177,7 +180,9 @@ async def cache_flushall(): raise HTTPException( status_code=503, detail="Cache not initialized. litellm.cache is None" ) - if litellm.cache.type == "redis": + if litellm.cache.type == "redis" and isinstance( + litellm.cache.cache, RedisCache + ): litellm.cache.cache.flushall() return { "status": "success", diff --git a/litellm/proxy/fine_tuning_endpoints/endpoints.py b/litellm/proxy/fine_tuning_endpoints/endpoints.py index cda226b5a..02110458e 100644 --- a/litellm/proxy/fine_tuning_endpoints/endpoints.py +++ b/litellm/proxy/fine_tuning_endpoints/endpoints.py @@ -118,13 +118,13 @@ async def create_fine_tuning_job( version, ) + data = fine_tuning_request.model_dump(exclude_none=True) try: if premium_user is not True: raise ValueError( f"Only premium users can use this endpoint + {CommonProxyErrors.not_premium_user.value}" ) # Convert Pydantic model to dict - data = fine_tuning_request.model_dump(exclude_none=True) verbose_proxy_logger.debug( "Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)), @@ -146,7 +146,8 @@ async def create_fine_tuning_job( ) # add llm_provider_config to data - data.update(llm_provider_config) + if llm_provider_config is not None: + data.update(llm_provider_config) response = await litellm.acreate_fine_tuning_job(**data) @@ -262,7 +263,8 @@ async def list_fine_tuning_jobs( custom_llm_provider=custom_llm_provider ) - data.update(llm_provider_config) + if llm_provider_config is not None: + data.update(llm_provider_config) response = await litellm.alist_fine_tuning_jobs( **data, @@ -378,7 +380,8 @@ async def retrieve_fine_tuning_job( custom_llm_provider=custom_llm_provider ) - data.update(llm_provider_config) + if llm_provider_config is not None: + data.update(llm_provider_config) response = await litellm.acancel_fine_tuning_job( **data, diff --git a/litellm/proxy/management_helpers/utils.py b/litellm/proxy/management_helpers/utils.py index 8fa22ee90..7da90c615 100644 --- a/litellm/proxy/management_helpers/utils.py +++ b/litellm/proxy/management_helpers/utils.py @@ -227,8 +227,8 @@ async def send_management_endpoint_alert( - An internal user is created, updated, or deleted - A team is created, updated, or deleted """ - from litellm.integrations.SlackAlerting.types import AlertType from litellm.proxy.proxy_server import premium_user, proxy_logging_obj + from litellm.types.integrations.slack_alerting import AlertType if premium_user is not True: return diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 6efca6eb9..9e214c6ca 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -114,10 +114,7 @@ from litellm import ( from litellm._logging import verbose_proxy_logger, verbose_router_logger from litellm.caching import DualCache, RedisCache from litellm.exceptions import RejectedRequestError -from litellm.integrations.SlackAlerting.slack_alerting import ( - SlackAlerting, - SlackAlertingArgs, -) +from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting from litellm.litellm_core_utils.core_helpers import ( _get_parent_otel_span_from_kwargs, get_litellm_metadata_from_kwargs, @@ -249,6 +246,7 @@ from litellm.secret_managers.aws_secret_manager import ( ) from litellm.secret_managers.google_kms import 
load_google_kms from litellm.secret_managers.main import get_secret, get_secret_str, str_to_bool +from litellm.types.integrations.slack_alerting import SlackAlertingArgs from litellm.types.llms.anthropic import ( AnthropicMessagesRequest, AnthropicResponse, diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 0f57e90fc..a53da4512 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -54,7 +54,6 @@ from litellm.exceptions import RejectedRequestError from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting -from litellm.integrations.SlackAlerting.types import DEFAULT_ALERT_TYPES from litellm.integrations.SlackAlerting.utils import _add_langfuse_trace_id_to_alert from litellm.litellm_core_utils.core_helpers import ( _get_parent_otel_span_from_kwargs, @@ -85,6 +84,7 @@ from litellm.proxy.hooks.parallel_request_limiter import ( _PROXY_MaxParallelRequestsHandler, ) from litellm.secret_managers.main import str_to_bool +from litellm.types.integrations.slack_alerting import DEFAULT_ALERT_TYPES from litellm.types.utils import CallTypes, LoggedLiteLLMParams if TYPE_CHECKING: diff --git a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py index 86ffda67e..98e2a707d 100644 --- a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py +++ b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py @@ -208,7 +208,7 @@ async def vertex_proxy_route( request, fastapi_response, user_api_key_dict, - stream=is_streaming_request, + stream=is_streaming_request, # type: ignore ) return received_value diff --git a/litellm/rerank_api/main.py b/litellm/rerank_api/main.py index db2217b1e..a06aff135 100644 --- a/litellm/rerank_api/main.py +++ b/litellm/rerank_api/main.py @@ -10,11 +10,10 @@ from litellm.llms.azure_ai.rerank import AzureAIRerank from litellm.llms.cohere.rerank import CohereRerank from litellm.llms.together_ai.rerank import TogetherAIRerank from litellm.secret_managers.main import get_secret +from litellm.types.rerank import RerankRequest, RerankResponse from litellm.types.router import * from litellm.utils import client, exception_type, supports_httpx_timeout -from .types import RerankRequest, RerankResponse - ####### ENVIRONMENT VARIABLES ################### # Initialize any necessary instances or variables here cohere_rerank = CohereRerank() diff --git a/litellm/router_utils/client_initalization_utils.py b/litellm/router_utils/client_initalization_utils.py index 4f750336e..a1d4ee5bd 100644 --- a/litellm/router_utils/client_initalization_utils.py +++ b/litellm/router_utils/client_initalization_utils.py @@ -7,6 +7,7 @@ import httpx import openai import litellm +from litellm import get_secret, get_secret_str from litellm._logging import verbose_router_logger from litellm.llms.AzureOpenAI.azure import get_azure_ad_token_from_oidc from litellm.secret_managers.get_azure_ad_token_provider import ( @@ -111,17 +112,17 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): api_key = litellm_params.get("api_key") or default_api_key if api_key and isinstance(api_key, str) and api_key.startswith("os.environ/"): api_key_env_name = api_key.replace("os.environ/", "") - api_key = litellm.get_secret(api_key_env_name) + api_key = get_secret_str(api_key_env_name) litellm_params["api_key"] = api_key api_base = litellm_params.get("api_base") - base_url = 
litellm_params.get("base_url") + base_url: Optional[str] = litellm_params.get("base_url") api_base = ( api_base or base_url or default_api_base ) # allow users to pass in `api_base` or `base_url` for azure if api_base and api_base.startswith("os.environ/"): api_base_env_name = api_base.replace("os.environ/", "") - api_base = litellm.get_secret(api_base_env_name) + api_base = get_secret_str(api_base_env_name) litellm_params["api_base"] = api_base ## AZURE AI STUDIO MISTRAL CHECK ## @@ -147,33 +148,37 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): api_version = litellm_params.get("api_version") if api_version and api_version.startswith("os.environ/"): api_version_env_name = api_version.replace("os.environ/", "") - api_version = litellm.get_secret(api_version_env_name) + api_version = get_secret_str(api_version_env_name) litellm_params["api_version"] = api_version - timeout = litellm_params.pop("timeout", None) or litellm.request_timeout + timeout: Optional[float] = ( + litellm_params.pop("timeout", None) or litellm.request_timeout + ) if isinstance(timeout, str) and timeout.startswith("os.environ/"): timeout_env_name = timeout.replace("os.environ/", "") - timeout = litellm.get_secret(timeout_env_name) + timeout = get_secret(timeout_env_name) # type: ignore litellm_params["timeout"] = timeout - stream_timeout = litellm_params.pop( + stream_timeout: Optional[float] = litellm_params.pop( "stream_timeout", timeout ) # if no stream_timeout is set, default to timeout if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"): stream_timeout_env_name = stream_timeout.replace("os.environ/", "") - stream_timeout = litellm.get_secret(stream_timeout_env_name) + stream_timeout = get_secret(stream_timeout_env_name) # type: ignore litellm_params["stream_timeout"] = stream_timeout - max_retries = litellm_params.pop("max_retries", 0) # router handles retry logic + max_retries: Optional[int] = litellm_params.pop( + "max_retries", 0 + ) # router handles retry logic if isinstance(max_retries, str) and max_retries.startswith("os.environ/"): max_retries_env_name = max_retries.replace("os.environ/", "") - max_retries = litellm.get_secret(max_retries_env_name) + max_retries = get_secret(max_retries_env_name) # type: ignore litellm_params["max_retries"] = max_retries organization = litellm_params.get("organization", None) if isinstance(organization, str) and organization.startswith("os.environ/"): organization_env_name = organization.replace("os.environ/", "") - organization = litellm.get_secret(organization_env_name) + organization = get_secret_str(organization_env_name) litellm_params["organization"] = organization azure_ad_token_provider: Optional[Callable[[], str]] = None if litellm_params.get("tenant_id"): @@ -227,8 +232,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): azure_ad_token_provider=azure_ad_token_provider, base_url=api_base, api_version=api_version, - timeout=timeout, - max_retries=max_retries, + timeout=timeout, # type: ignore + max_retries=max_retries, # type: ignore http_client=httpx.AsyncClient( limits=httpx.Limits( max_connections=1000, max_keepalive_connections=100 @@ -253,8 +258,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): azure_ad_token_provider=azure_ad_token_provider, base_url=api_base, api_version=api_version, - timeout=timeout, - max_retries=max_retries, + timeout=timeout, # type: ignore + max_retries=max_retries, # type: ignore http_client=httpx.Client( limits=httpx.Limits( 
max_connections=1000, max_keepalive_connections=100 @@ -276,8 +281,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): azure_ad_token_provider=azure_ad_token_provider, base_url=api_base, api_version=api_version, - timeout=stream_timeout, - max_retries=max_retries, + timeout=stream_timeout, # type: ignore + max_retries=max_retries, # type: ignore http_client=httpx.AsyncClient( limits=httpx.Limits( max_connections=1000, max_keepalive_connections=100 @@ -302,8 +307,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): azure_ad_token_provider=azure_ad_token_provider, base_url=api_base, api_version=api_version, - timeout=stream_timeout, - max_retries=max_retries, + timeout=stream_timeout, # type: ignore + max_retries=max_retries, # type: ignore http_client=httpx.Client( limits=httpx.Limits( max_connections=1000, max_keepalive_connections=100 @@ -350,8 +355,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): cache_key = f"{model_id}_async_client" _client = openai.AsyncAzureOpenAI( # type: ignore **azure_client_params, - timeout=timeout, - max_retries=max_retries, + timeout=timeout, # type: ignore + max_retries=max_retries, # type: ignore http_client=httpx.AsyncClient( limits=httpx.Limits( max_connections=1000, max_keepalive_connections=100 @@ -371,8 +376,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): cache_key = f"{model_id}_client" _client = openai.AzureOpenAI( # type: ignore **azure_client_params, - timeout=timeout, - max_retries=max_retries, + timeout=timeout, # type: ignore + max_retries=max_retries, # type: ignore http_client=httpx.Client( limits=httpx.Limits( max_connections=1000, max_keepalive_connections=100 @@ -391,8 +396,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): cache_key = f"{model_id}_stream_async_client" _client = openai.AsyncAzureOpenAI( # type: ignore **azure_client_params, - timeout=stream_timeout, - max_retries=max_retries, + timeout=stream_timeout, # type: ignore + max_retries=max_retries, # type: ignore http_client=httpx.AsyncClient( limits=httpx.Limits( max_connections=1000, max_keepalive_connections=100 @@ -413,8 +418,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): cache_key = f"{model_id}_stream_client" _client = openai.AzureOpenAI( # type: ignore **azure_client_params, - timeout=stream_timeout, - max_retries=max_retries, + timeout=stream_timeout, # type: ignore + max_retries=max_retries, # type: ignore http_client=httpx.Client( limits=httpx.Limits( max_connections=1000, max_keepalive_connections=100 @@ -441,8 +446,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): _client = openai.AsyncOpenAI( # type: ignore api_key=api_key, base_url=api_base, - timeout=timeout, - max_retries=max_retries, + timeout=timeout, # type: ignore + max_retries=max_retries, # type: ignore organization=organization, http_client=httpx.AsyncClient( limits=httpx.Limits( @@ -465,8 +470,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): _client = openai.OpenAI( # type: ignore api_key=api_key, base_url=api_base, - timeout=timeout, - max_retries=max_retries, + timeout=timeout, # type: ignore + max_retries=max_retries, # type: ignore organization=organization, http_client=httpx.Client( limits=httpx.Limits( @@ -487,8 +492,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): _client = openai.AsyncOpenAI( # type: ignore api_key=api_key, base_url=api_base, - timeout=stream_timeout, - 
max_retries=max_retries, + timeout=stream_timeout, # type: ignore + max_retries=max_retries, # type: ignore organization=organization, http_client=httpx.AsyncClient( limits=httpx.Limits( @@ -512,8 +517,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): _client = openai.OpenAI( # type: ignore api_key=api_key, base_url=api_base, - timeout=stream_timeout, - max_retries=max_retries, + timeout=stream_timeout, # type: ignore + max_retries=max_retries, # type: ignore organization=organization, http_client=httpx.Client( limits=httpx.Limits( @@ -542,20 +547,29 @@ def get_azure_ad_token_from_entrata_id( verbose_router_logger.debug("Getting Azure AD Token from Entrata ID") if tenant_id.startswith("os.environ/"): - tenant_id = litellm.get_secret(tenant_id) + _tenant_id = get_secret_str(tenant_id) + else: + _tenant_id = tenant_id if client_id.startswith("os.environ/"): - client_id = litellm.get_secret(client_id) + _client_id = get_secret_str(client_id) + else: + _client_id = client_id if client_secret.startswith("os.environ/"): - client_secret = litellm.get_secret(client_secret) + _client_secret = get_secret_str(client_secret) + else: + _client_secret = client_secret + verbose_router_logger.debug( "tenant_id %s, client_id %s, client_secret %s", - tenant_id, - client_id, - client_secret, + _tenant_id, + _client_id, + _client_secret, ) - credential = ClientSecretCredential(tenant_id, client_id, client_secret) + if _tenant_id is None or _client_id is None or _client_secret is None: + raise ValueError("tenant_id, client_id, and client_secret must be provided") + credential = ClientSecretCredential(_tenant_id, _client_id, _client_secret) verbose_router_logger.debug("credential %s", credential) diff --git a/litellm/router_utils/handle_error.py b/litellm/router_utils/handle_error.py index d848fd82b..fd08d13a9 100644 --- a/litellm/router_utils/handle_error.py +++ b/litellm/router_utils/handle_error.py @@ -2,6 +2,8 @@ import asyncio import traceback from typing import TYPE_CHECKING, Any +from litellm.types.integrations.slack_alerting import AlertType + if TYPE_CHECKING: from litellm.router import Router as _Router @@ -49,5 +51,6 @@ async def send_llm_exception_alert( await litellm_router_instance.slack_alerting_logger.send_alert( message=f"LLM API call failed: `{exception_str}`", level="High", - alert_type="llm_exceptions", + alert_type=AlertType.llm_exceptions, + alerting_metadata={}, ) diff --git a/litellm/integrations/SlackAlerting/types.py b/litellm/types/integrations/slack_alerting.py similarity index 100% rename from litellm/integrations/SlackAlerting/types.py rename to litellm/types/integrations/slack_alerting.py diff --git a/litellm/rerank_api/types.py b/litellm/types/rerank.py similarity index 100% rename from litellm/rerank_api/types.py rename to litellm/types/rerank.py diff --git a/litellm/types/utils.py b/litellm/types/utils.py index faae0ff22..4e4699afa 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -792,7 +792,7 @@ class EmbeddingResponse(OpenAIObject): model: Optional[str] = None """The model used for embedding.""" - data: Optional[List] = None + data: List """The actual embedding value""" object: Literal["list"] @@ -803,6 +803,7 @@ class EmbeddingResponse(OpenAIObject): _hidden_params: dict = {} _response_headers: Optional[Dict] = None + _response_ms: Optional[float] = None def __init__( self, @@ -822,7 +823,7 @@ class EmbeddingResponse(OpenAIObject): if data: data = data else: - data = None + data = [] if usage: usage = usage diff --git 
a/litellm/utils.py b/litellm/utils.py index f1597f221..753e07f80 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -322,6 +322,9 @@ def function_setup( original_function: str, rules_obj, start_time, *args, **kwargs ): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc. ### NOTICES ### + from litellm import Logging as LiteLLMLogging + from litellm.litellm_core_utils.litellm_logging import set_callbacks + if litellm.set_verbose is True: verbose_logger.warning( "`litellm.set_verbose` is deprecated. Please set `os.environ['LITELLM_LOG'] = 'DEBUG'` for debug logs." @@ -333,7 +336,7 @@ def function_setup( custom_llm_setup() ## LOGGING SETUP - function_id = kwargs["id"] if "id" in kwargs else None + function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None if len(litellm.callbacks) > 0: for callback in litellm.callbacks: @@ -375,9 +378,7 @@ def function_setup( + litellm.failure_callback ) ) - litellm.litellm_core_utils.litellm_logging.set_callbacks( - callback_list=callback_list, function_id=function_id - ) + set_callbacks(callback_list=callback_list, function_id=function_id) ## ASYNC CALLBACKS if len(litellm.input_callback) > 0: removed_async_items = [] @@ -560,12 +561,12 @@ def function_setup( else: messages = "default-message-value" stream = True if "stream" in kwargs and kwargs["stream"] is True else False - logging_obj = litellm.litellm_core_utils.litellm_logging.Logging( + logging_obj = LiteLLMLogging( model=model, messages=messages, stream=stream, litellm_call_id=kwargs["litellm_call_id"], - function_id=function_id, + function_id=function_id or "", call_type=call_type, start_time=start_time, dynamic_success_callbacks=dynamic_success_callbacks, @@ -655,10 +656,8 @@ def client(original_function): json_response_format = optional_params[ "response_format" ] - elif ( - _parsing._completions.is_basemodel_type( - optional_params["response_format"] - ) + elif _parsing._completions.is_basemodel_type( + optional_params["response_format"] # type: ignore ): json_response_format = ( type_to_response_format_param( @@ -827,6 +826,7 @@ def client(original_function): print_verbose("INSIDE CHECKING CACHE") if ( litellm.cache is not None + and litellm.cache.supported_call_types is not None and str(original_function.__name__) in litellm.cache.supported_call_types ): @@ -879,7 +879,7 @@ def client(original_function): dynamic_api_key, api_base, ) = litellm.get_llm_provider( - model=model, + model=model or "", custom_llm_provider=kwargs.get( "custom_llm_provider", None ), @@ -949,6 +949,8 @@ def client(original_function): base_model=base_model, messages=messages, user_max_tokens=user_max_tokens, + buffer_num=None, + buffer_perc=None, ) kwargs["max_tokens"] = modified_max_tokens except Exception as e: @@ -990,6 +992,7 @@ def client(original_function): # [OPTIONAL] ADD TO CACHE if ( litellm.cache is not None + and litellm.cache.supported_call_types is not None and str(original_function.__name__) in litellm.cache.supported_call_types ) and (kwargs.get("cache", {}).get("no-store", False) is not True): @@ -1006,7 +1009,7 @@ def client(original_function): "id", None ) result._hidden_params["api_base"] = get_api_base( - model=model, + model=model or "", optional_params=getattr(logging_obj, "optional_params", {}), ) result._hidden_params["response_cost"] = ( @@ -1053,7 +1056,7 @@ def client(original_function): and not _is_litellm_router_call ): if len(args) > 0: - args[0] = context_window_fallback_dict[model] + args[0] = context_window_fallback_dict[model] # 
type: ignore else: kwargs["model"] = context_window_fallback_dict[model] return original_function(*args, **kwargs) @@ -1065,12 +1068,6 @@ def client(original_function): logging_obj.failure_handler( e, traceback_exception, start_time, end_time ) # DO NOT MAKE THREADED - router retry fallback relies on this! - if hasattr(e, "message"): - if ( - liteDebuggerClient - and liteDebuggerClient.dashboard_url is not None - ): # make it easy to get to the debugger logs if you've initialized it - e.message += f"\n Check the log in your dashboard - {liteDebuggerClient.dashboard_url}" raise e @wraps(original_function) @@ -1126,6 +1123,7 @@ def client(original_function): print_verbose("INSIDE CHECKING CACHE") if ( litellm.cache is not None + and litellm.cache.supported_call_types is not None and str(original_function.__name__) in litellm.cache.supported_call_types ): @@ -1287,7 +1285,11 @@ def client(original_function): args=(cached_result, start_time, end_time, cache_hit), ).start() cache_key = kwargs.get("preset_cache_key", None) - cached_result._hidden_params["cache_key"] = cache_key + if ( + isinstance(cached_result, BaseModel) + or isinstance(cached_result, CustomStreamWrapper) + ) and hasattr(cached_result, "_hidden_params"): + cached_result._hidden_params["cache_key"] = cache_key # type: ignore return cached_result elif ( call_type == CallTypes.aembedding.value @@ -1447,6 +1449,7 @@ def client(original_function): # [OPTIONAL] ADD TO CACHE if ( (litellm.cache is not None) + and litellm.cache.supported_call_types is not None and ( str(original_function.__name__) in litellm.cache.supported_call_types @@ -1504,11 +1507,12 @@ def client(original_function): if ( isinstance(result, EmbeddingResponse) and final_embedding_cached_response is not None + and final_embedding_cached_response.data is not None ): idx = 0 final_data_list = [] for item in final_embedding_cached_response.data: - if item is None: + if item is None and result.data is not None: final_data_list.append(result.data[idx]) idx += 1 else: @@ -1575,7 +1579,7 @@ def client(original_function): and model in context_window_fallback_dict ): if len(args) > 0: - args[0] = context_window_fallback_dict[model] + args[0] = context_window_fallback_dict[model] # type: ignore else: kwargs["model"] = context_window_fallback_dict[model] return await original_function(*args, **kwargs) @@ -2945,13 +2949,19 @@ def get_optional_params( response_format=non_default_params["response_format"] ) # # clean out 'additionalProperties = False'. 
Causes vertexai/gemini OpenAI API Schema errors - https://github.com/langchain-ai/langchainjs/issues/5240 - if non_default_params["response_format"].get("json_schema", {}).get( - "schema" - ) is not None and custom_llm_provider in [ - "gemini", - "vertex_ai", - "vertex_ai_beta", - ]: + if ( + non_default_params["response_format"] is not None + and non_default_params["response_format"] + .get("json_schema", {}) + .get("schema") + is not None + and custom_llm_provider + in [ + "gemini", + "vertex_ai", + "vertex_ai_beta", + ] + ): old_schema = copy.deepcopy( non_default_params["response_format"] .get("json_schema", {}) @@ -3754,7 +3764,11 @@ def get_optional_params( non_default_params=non_default_params, optional_params=optional_params, model=model, - drop_params=drop_params, + drop_params=( + drop_params + if drop_params is not None and isinstance(drop_params, bool) + else False + ), ) elif custom_llm_provider == "openrouter": supported_params = get_supported_openai_params( @@ -3863,7 +3877,11 @@ def get_optional_params( non_default_params=non_default_params, optional_params=optional_params, model=model, - drop_params=drop_params, + drop_params=( + drop_params + if drop_params is not None and isinstance(drop_params, bool) + else False + ), ) elif custom_llm_provider == "azure": supported_params = get_supported_openai_params( @@ -4889,7 +4907,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod try: split_model, custom_llm_provider, _, _ = get_llm_provider(model=model) except Exception: - pass + split_model = model combined_model_name = model stripped_model_name = _strip_model_name(model=model) combined_stripped_model_name = stripped_model_name @@ -5865,6 +5883,8 @@ def convert_to_model_response_object( for idx, choice in enumerate(response_object["choices"]): ## HANDLE JSON MODE - anthropic returns single function call] tool_calls = choice["message"].get("tool_calls", None) + message: Optional[Message] = None + finish_reason: Optional[str] = None if ( convert_tool_call_to_json_mode and tool_calls is not None @@ -5877,7 +5897,7 @@ def convert_to_model_response_object( if json_mode_content_str is not None: message = litellm.Message(content=json_mode_content_str) finish_reason = "stop" - else: + if message is None: message = Message( content=choice["message"].get("content", None), role=choice["message"]["role"] or "assistant", @@ -6066,7 +6086,7 @@ def valid_model(model): model in litellm.open_ai_chat_completion_models or model in litellm.open_ai_text_completion_models ): - openai.Model.retrieve(model) + openai.models.retrieve(model) else: messages = [{"role": "user", "content": "Hello World"}] litellm.completion(model=model, messages=messages) @@ -6386,8 +6406,8 @@ class CustomStreamWrapper: self, completion_stream, model, - custom_llm_provider=None, - logging_obj=None, + logging_obj: Any, + custom_llm_provider: Optional[str] = None, stream_options=None, make_call: Optional[Callable] = None, _response_headers: Optional[dict] = None, @@ -6633,36 +6653,6 @@ class CustomStreamWrapper: "completion_tokens": completion_tokens, } - def handle_together_ai_chunk(self, chunk): - chunk = chunk.decode("utf-8") - text = "" - is_finished = False - finish_reason = None - if "text" in chunk: - text_index = chunk.find('"text":"') # this checks if text: exists - text_start = text_index + len('"text":"') - text_end = chunk.find('"}', text_start) - if text_index != -1 and text_end != -1: - extracted_text = chunk[text_start:text_end] - text = extracted_text - return { - "text": 
text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } - elif "[DONE]" in chunk: - return {"text": text, "is_finished": True, "finish_reason": "stop"} - elif "error" in chunk: - raise litellm.together_ai.TogetherAIError( - status_code=422, message=f"{str(chunk)}" - ) - else: - return { - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } - def handle_predibase_chunk(self, chunk): try: if not isinstance(chunk, str): @@ -7264,12 +7254,17 @@ class CustomStreamWrapper: try: if isinstance(chunk, dict): parsed_response = chunk - if isinstance(chunk, (str, bytes)): + elif isinstance(chunk, (str, bytes)): if isinstance(chunk, bytes): parsed_response = chunk.decode("utf-8") else: parsed_response = chunk - data_json = json.loads(parsed_response) + else: + raise ValueError("Unable to parse streaming chunk") + if isinstance(parsed_response, dict): + data_json = parsed_response + else: + data_json = json.loads(parsed_response) text = ( data_json.get("outputs", "")[0] .get("data", "") @@ -7331,8 +7326,7 @@ class CustomStreamWrapper: if ( len(model_response.choices) > 0 - and hasattr(model_response.choices[0], "delta") - and model_response.choices[0].delta is not None + and getattr(model_response.choices[0], "delta") is not None ): # do nothing, if object instantiated pass @@ -7350,7 +7344,7 @@ class CustomStreamWrapper: is_empty = False return is_empty - def chunk_creator(self, chunk): + def chunk_creator(self, chunk): # type: ignore model_response = self.model_response_creator() response_obj = {} try: @@ -7422,11 +7416,6 @@ class CustomStreamWrapper: completion_obj["content"] = response_obj["text"] if response_obj["is_finished"]: self.received_finish_reason = response_obj["finish_reason"] - elif self.custom_llm_provider and self.custom_llm_provider == "together_ai": - response_obj = self.handle_together_ai_chunk(chunk) - completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - self.received_finish_reason = response_obj["finish_reason"] elif self.custom_llm_provider and self.custom_llm_provider == "huggingface": response_obj = self.handle_huggingface_chunk(chunk) completion_obj["content"] = response_obj["text"] @@ -7475,51 +7464,6 @@ class CustomStreamWrapper: if self.sent_first_chunk is False: raise Exception("An unknown error occurred with the stream") self.received_finish_reason = "stop" - elif self.custom_llm_provider == "gemini": - if hasattr(chunk, "parts") is True: - try: - if len(chunk.parts) > 0: - completion_obj["content"] = chunk.parts[0].text - if len(chunk.parts) > 0 and hasattr( - chunk.parts[0], "finish_reason" - ): - self.received_finish_reason = chunk.parts[ - 0 - ].finish_reason.name - except Exception: - if chunk.parts[0].finish_reason.name == "SAFETY": - raise Exception( - f"The response was blocked by VertexAI. 
{str(chunk)}" - ) - else: - completion_obj["content"] = str(chunk) - elif self.custom_llm_provider and ( - self.custom_llm_provider == "vertex_ai_beta" - ): - from litellm.types.utils import ( - GenericStreamingChunk as UtilsStreamingChunk, - ) - - if self.received_finish_reason is not None: - raise StopIteration - response_obj: UtilsStreamingChunk = chunk - completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - self.received_finish_reason = response_obj["finish_reason"] - - if ( - self.stream_options - and self.stream_options.get("include_usage", False) is True - and response_obj["usage"] is not None - ): - model_response.usage = litellm.Usage( - prompt_tokens=response_obj["usage"]["prompt_tokens"], - completion_tokens=response_obj["usage"]["completion_tokens"], - total_tokens=response_obj["usage"]["total_tokens"], - ) - - if "tool_use" in response_obj and response_obj["tool_use"] is not None: - completion_obj["tool_calls"] = [response_obj["tool_use"]] elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"): import proto # type: ignore @@ -7624,53 +7568,7 @@ class CustomStreamWrapper: completion_obj["content"] = response_obj["text"] if response_obj["is_finished"]: self.received_finish_reason = response_obj["finish_reason"] - elif self.custom_llm_provider == "bedrock": - from litellm.types.llms.bedrock import GenericStreamingChunk - if self.received_finish_reason is not None: - raise StopIteration - response_obj: GenericStreamingChunk = chunk - completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - self.received_finish_reason = response_obj["finish_reason"] - - if ( - self.stream_options - and self.stream_options.get("include_usage", False) is True - and response_obj["usage"] is not None - ): - model_response.usage = litellm.Usage( - prompt_tokens=response_obj["usage"]["inputTokens"], - completion_tokens=response_obj["usage"]["outputTokens"], - total_tokens=response_obj["usage"]["totalTokens"], - ) - - if "tool_use" in response_obj and response_obj["tool_use"] is not None: - completion_obj["tool_calls"] = [response_obj["tool_use"]] - - elif self.custom_llm_provider == "sagemaker": - from litellm.types.llms.bedrock import GenericStreamingChunk - - if self.received_finish_reason is not None: - raise StopIteration - response_obj: GenericStreamingChunk = chunk - completion_obj["content"] = response_obj["text"] - if response_obj["is_finished"]: - self.received_finish_reason = response_obj["finish_reason"] - - if ( - self.stream_options - and self.stream_options.get("include_usage", False) is True - and response_obj["usage"] is not None - ): - model_response.usage = litellm.Usage( - prompt_tokens=response_obj["usage"]["inputTokens"], - completion_tokens=response_obj["usage"]["outputTokens"], - total_tokens=response_obj["usage"]["totalTokens"], - ) - - if "tool_use" in response_obj and response_obj["tool_use"] is not None: - completion_obj["tool_calls"] = [response_obj["tool_use"]] elif self.custom_llm_provider == "petals": if len(self.completion_stream) == 0: if self.received_finish_reason is not None: @@ -8181,9 +8079,11 @@ class CustomStreamWrapper: target=self.run_success_logging_in_thread, args=(response, cache_hit), ).start() # log response - self.response_uptil_now += ( - response.choices[0].delta.get("content", "") or "" - ) + choice = response.choices[0] + if isinstance(choice, StreamingChoices): + self.response_uptil_now += choice.delta.get("content", "") or "" + else: + self.response_uptil_now += "" 
self.rules.post_call_rules( input=self.response_uptil_now, model=self.model ) @@ -8223,8 +8123,11 @@ class CustomStreamWrapper: ) response = self.model_response_creator() if complete_streaming_response is not None: - response.usage = complete_streaming_response.usage - response._hidden_params["usage"] = complete_streaming_response.usage # type: ignore + setattr( + response, + "usage", + getattr(complete_streaming_response, "usage"), + ) ## LOGGING threading.Thread( target=self.logging_obj.success_handler, @@ -8349,9 +8252,11 @@ class CustomStreamWrapper: processed_chunk, cache_hit=cache_hit ) ) - self.response_uptil_now += ( - processed_chunk.choices[0].delta.get("content", "") or "" - ) + choice = processed_chunk.choices[0] + if isinstance(choice, StreamingChoices): + self.response_uptil_now += choice.delta.get("content", "") or "" + else: + self.response_uptil_now += "" self.rules.post_call_rules( input=self.response_uptil_now, model=self.model ) @@ -8401,9 +8306,13 @@ class CustomStreamWrapper: ) ) - self.response_uptil_now += ( - processed_chunk.choices[0].delta.get("content", "") or "" - ) + choice = processed_chunk.choices[0] + if isinstance(choice, StreamingChoices): + self.response_uptil_now += ( + choice.delta.get("content", "") or "" + ) + else: + self.response_uptil_now += "" self.rules.post_call_rules( input=self.response_uptil_now, model=self.model ) @@ -8423,7 +8332,11 @@ class CustomStreamWrapper: ) response = self.model_response_creator() if complete_streaming_response is not None: - setattr(response, "usage", complete_streaming_response.usage) + setattr( + response, + "usage", + getattr(complete_streaming_response, "usage"), + ) ## LOGGING threading.Thread( target=self.logging_obj.success_handler, @@ -8464,7 +8377,11 @@ class CustomStreamWrapper: ) response = self.model_response_creator() if complete_streaming_response is not None: - response.usage = complete_streaming_response.usage + setattr( + response, + "usage", + getattr(complete_streaming_response, "usage"), + ) ## LOGGING threading.Thread( target=self.logging_obj.success_handler, @@ -8898,7 +8815,7 @@ def trim_messages( if len(tool_messages): messages = messages[: -len(tool_messages)] - current_tokens = token_counter(model=model, messages=messages) + current_tokens = token_counter(model=model or "", messages=messages) print_verbose(f"Current tokens: {current_tokens}, max tokens: {max_tokens}") # Do nothing if current tokens under messages @@ -8909,6 +8826,7 @@ def trim_messages( print_verbose( f"Need to trim input messages: {messages}, current_tokens{current_tokens}, max_tokens: {max_tokens}" ) + system_message_event: Optional[dict] = None if system_message: system_message_event, max_tokens = process_system_message( system_message=system_message, max_tokens=max_tokens, model=model @@ -8926,7 +8844,7 @@ def trim_messages( ) # Add system message to the beginning of the final messages - if system_message: + if system_message_event: final_messages = [system_message_event] + final_messages if len(tool_messages) > 0: @@ -9214,6 +9132,8 @@ def is_cached_message(message: AllMessageValues) -> bool: Follows the anthropic format {"cache_control": {"type": "ephemeral"}} """ + if "content" not in message: + return False if message["content"] is None or isinstance(message["content"], str): return False diff --git a/pyrightconfig.json b/pyrightconfig.json index 86a21c65e..9a43abda7 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -1,6 +1,7 @@ { "ignore": [], - "exclude": ["**/node_modules", "**/__pycache__", 
"litellm/tests", "litellm/main.py", "litellm/utils.py", "litellm/types/utils.py"], - "reportMissingImports": false + "exclude": ["**/node_modules", "**/__pycache__", "litellm/types/utils.py"], + "reportMissingImports": false, + "reportPrivateImportUsage": false } \ No newline at end of file diff --git a/tests/local_testing/test_alerting.py b/tests/local_testing/test_alerting.py index 829f124a9..5785e829b 100644 --- a/tests/local_testing/test_alerting.py +++ b/tests/local_testing/test_alerting.py @@ -14,7 +14,7 @@ from typing import Optional import httpx -from litellm.integrations.SlackAlerting.types import AlertType +from litellm.types.integrations.slack_alerting import AlertType # import logging # logging.basicConfig(level=logging.DEBUG) diff --git a/tests/local_testing/test_completion_cost.py b/tests/local_testing/test_completion_cost.py index d12a90cc3..b220e94ea 100644 --- a/tests/local_testing/test_completion_cost.py +++ b/tests/local_testing/test_completion_cost.py @@ -1188,13 +1188,36 @@ def test_completion_cost_anthropic_prompt_caching(): system_fingerprint=None, usage=Usage( completion_tokens=10, - prompt_tokens=14, - total_tokens=24, + prompt_tokens=114, + total_tokens=124, + prompt_tokens_details=PromptTokensDetails(cached_tokens=0), cache_creation_input_tokens=100, cache_read_input_tokens=0, ), ) + cost_1 = completion_cost(model=model, completion_response=response_1) + + _model_info = litellm.get_model_info( + model="claude-3-5-sonnet-20240620", custom_llm_provider="anthropic" + ) + expected_cost = ( + ( + response_1.usage.prompt_tokens + - response_1.usage.prompt_tokens_details.cached_tokens + ) + * _model_info["input_cost_per_token"] + + response_1.usage.prompt_tokens_details.cached_tokens + * _model_info["cache_read_input_token_cost"] + + response_1.usage.cache_creation_input_tokens + * _model_info["cache_creation_input_token_cost"] + + response_1.usage.completion_tokens * _model_info["output_cost_per_token"] + ) # Cost of processing (non-cache hit + cache hit) + Cost of cache-writing (cache writing) + + assert round(expected_cost, 5) == round(cost_1, 5) + + print(f"expected_cost: {expected_cost}, cost_1: {cost_1}") + ## READ FROM CACHE ## (LESS EXPENSIVE) response_2 = ModelResponse( id="chatcmpl-3f427194-0840-4d08-b571-56bfe38a5424", @@ -1216,14 +1239,14 @@ def test_completion_cost_anthropic_prompt_caching(): system_fingerprint=None, usage=Usage( completion_tokens=10, - prompt_tokens=14, - total_tokens=24, + prompt_tokens=114, + total_tokens=134, + prompt_tokens_details=PromptTokensDetails(cached_tokens=100), cache_creation_input_tokens=0, cache_read_input_tokens=100, ), ) - cost_1 = completion_cost(model=model, completion_response=response_1) cost_2 = completion_cost(model=model, completion_response=response_2) assert cost_1 > cost_2 diff --git a/tests/local_testing/test_prompt_caching.py b/tests/local_testing/test_prompt_caching.py index d714c58a7..301ead3aa 100644 --- a/tests/local_testing/test_prompt_caching.py +++ b/tests/local_testing/test_prompt_caching.py @@ -10,12 +10,40 @@ import litellm import pytest +def _usage_format_tests(usage: litellm.Usage): + """ + OpenAI prompt caching + - prompt_tokens = sum of non-cache hit tokens + cache-hit tokens + - total_tokens = prompt_tokens + completion_tokens + + Example + ``` + "usage": { + "prompt_tokens": 2006, + "completion_tokens": 300, + "total_tokens": 2306, + "prompt_tokens_details": { + "cached_tokens": 1920 + }, + "completion_tokens_details": { + "reasoning_tokens": 0 + } + # ANTHROPIC_ONLY # + 
"cache_creation_input_tokens": 0 + } + ``` + """ + assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens + + assert usage.prompt_tokens > usage.prompt_tokens_details.cached_tokens + + @pytest.mark.parametrize( "model", [ "anthropic/claude-3-5-sonnet-20240620", - "openai/gpt-4o", - "deepseek/deepseek-chat", + # "openai/gpt-4o", + # "deepseek/deepseek-chat", ], ) def test_prompt_caching_model(model): @@ -66,9 +94,13 @@ def test_prompt_caching_model(model): max_tokens=10, ) + _usage_format_tests(response.usage) + print("response=", response) print("response.usage=", response.usage) + _usage_format_tests(response.usage) + assert "prompt_tokens_details" in response.usage assert response.usage.prompt_tokens_details.cached_tokens > 0