diff --git a/docs/my-website/docs/proxy/prod.md b/docs/my-website/docs/proxy/prod.md
index 654d8cc2d..2fb4dd3b3 100644
--- a/docs/my-website/docs/proxy/prod.md
+++ b/docs/my-website/docs/proxy/prod.md
@@ -107,9 +107,9 @@ export LITELLM_SALT_KEY="sk-1234"
 |--------------|-------|
 | Avg latency | `50ms` |
 | Median latency | `51ms` |
-| `/chat/completions` Requests/second | `35` |
-| `/chat/completions` Requests/minute | `2100` |
-| `/chat/completions` Requests/hour | `126K` |
+| `/chat/completions` Requests/second | `100` |
+| `/chat/completions` Requests/minute | `6000` |
+| `/chat/completions` Requests/hour | `360K` |
 
 ### Verifying Debugging logs are off
 
diff --git a/enterprise/enterprise_hooks/aporia_ai.py b/enterprise/enterprise_hooks/aporia_ai.py
index 14a9a06f6..7f5339f30 100644
--- a/enterprise/enterprise_hooks/aporia_ai.py
+++ b/enterprise/enterprise_hooks/aporia_ai.py
@@ -26,7 +26,10 @@ from typing import List
 from datetime import datetime
 import aiohttp, asyncio
 from litellm._logging import verbose_proxy_logger
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
 import httpx
 import json
 from litellm.types.guardrails import GuardrailEventHooks
@@ -40,8 +43,8 @@ class AporiaGuardrail(CustomGuardrail):
     def __init__(
         self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
     ):
-        self.async_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+        self.async_handler = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.GuardrailCallback
         )
         self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"]
         self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"]
diff --git a/litellm/__init__.py b/litellm/__init__.py
index cf13edce4..57b9f6a71 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -24,6 +24,7 @@ from litellm.proxy._types import (
 )
 import httpx
 import dotenv
+from enum import Enum
 
 litellm_mode = os.getenv("LITELLM_MODE", "DEV")  # "PRODUCTION", "DEV"
 if litellm_mode == "DEV":
@@ -678,62 +679,66 @@ model_list = (
     + gemini_models
 )
 
-provider_list: List = [
-    "openai",
-    "custom_openai",
-    "text-completion-openai",
-    "cohere",
-    "cohere_chat",
-    "clarifai",
-    "anthropic",
-    "replicate",
-    "huggingface",
-    "together_ai",
-    "openrouter",
-    "vertex_ai",
-    "vertex_ai_beta",
-    "palm",
-    "gemini",
-    "ai21",
-    "baseten",
-    "azure",
-    "azure_text",
-    "azure_ai",
-    "sagemaker",
-    "sagemaker_chat",
-    "bedrock",
-    "vllm",
-    "nlp_cloud",
-    "petals",
-    "oobabooga",
-    "ollama",
-    "ollama_chat",
-    "deepinfra",
-    "perplexity",
-    "anyscale",
-    "mistral",
-    "groq",
-    "nvidia_nim",
-    "cerebras",
-    "ai21_chat",
-    "volcengine",
-    "codestral",
-    "text-completion-codestral",
-    "deepseek",
-    "maritalk",
-    "voyage",
-    "cloudflare",
-    "xinference",
-    "fireworks_ai",
-    "friendliai",
-    "watsonx",
-    "triton",
-    "predibase",
-    "databricks",
-    "empower",
-    "github",
-    "custom",  # custom apis
-]
+
+class LlmProviders(str, Enum):
+    OPENAI = "openai"
+    CUSTOM_OPENAI = "custom_openai"
+    TEXT_COMPLETION_OPENAI = "text-completion-openai"
+    COHERE = "cohere"
+    COHERE_CHAT = "cohere_chat"
+    CLARIFAI = "clarifai"
+    ANTHROPIC = "anthropic"
+    REPLICATE = "replicate"
+    HUGGINGFACE = "huggingface"
+    TOGETHER_AI = "together_ai"
+    OPENROUTER = "openrouter"
+    VERTEX_AI = "vertex_ai"
+    VERTEX_AI_BETA = "vertex_ai_beta"
+    PALM = "palm"
+    GEMINI = "gemini"
+    AI21 = "ai21"
+    BASETEN = "baseten"
+    AZURE = "azure"
+    AZURE_TEXT = "azure_text"
+    AZURE_AI = "azure_ai"
+    SAGEMAKER = "sagemaker"
+    SAGEMAKER_CHAT = "sagemaker_chat"
+    BEDROCK = "bedrock"
+    VLLM = "vllm"
+    NLP_CLOUD = "nlp_cloud"
+    PETALS = "petals"
+    OOBABOOGA = "oobabooga"
+    OLLAMA = "ollama"
+    OLLAMA_CHAT = "ollama_chat"
+    DEEPINFRA = "deepinfra"
+    PERPLEXITY = "perplexity"
+    ANYSCALE = "anyscale"
+    MISTRAL = "mistral"
+    GROQ = "groq"
+    NVIDIA_NIM = "nvidia_nim"
+    CEREBRAS = "cerebras"
+    AI21_CHAT = "ai21_chat"
+    VOLCENGINE = "volcengine"
+    CODESTRAL = "codestral"
+    TEXT_COMPLETION_CODESTRAL = "text-completion-codestral"
+    DEEPSEEK = "deepseek"
+    MARITALK = "maritalk"
+    VOYAGE = "voyage"
+    CLOUDFLARE = "cloudflare"
+    XINFERENCE = "xinference"
+    FIREWORKS_AI = "fireworks_ai"
+    FRIENDLIAI = "friendliai"
+    WATSONX = "watsonx"
+    TRITON = "triton"
+    PREDIBASE = "predibase"
+    DATABRICKS = "databricks"
+    EMPOWER = "empower"
+    GITHUB = "github"
+    CUSTOM = "custom"
+
+
+provider_list: List[Union[LlmProviders, str]] = list(LlmProviders)
+
 
 models_by_provider: dict = {
     "openai": open_ai_chat_completion_models + open_ai_text_completion_models,
diff --git a/litellm/caching.py b/litellm/caching.py
index 7f67ee455..5add0cd8e 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -1242,8 +1242,9 @@ class QdrantSemanticCache(BaseCache):
         import os
 
         from litellm.llms.custom_httpx.http_handler import (
-            _get_async_httpx_client,
             _get_httpx_client,
+            get_async_httpx_client,
+            httpxSpecialProvider,
         )
 
         if collection_name is None:
@@ -1290,7 +1291,9 @@ class QdrantSemanticCache(BaseCache):
         self.headers = headers
 
         self.sync_client = _get_httpx_client()
-        self.async_client = _get_async_httpx_client()
+        self.async_client = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.Caching
+        )
 
         if quantization_config is None:
             print_verbose(
diff --git a/litellm/integrations/braintrust_logging.py b/litellm/integrations/braintrust_logging.py
index 5128fcc49..b08cea54e 100644
--- a/litellm/integrations/braintrust_logging.py
+++ b/litellm/integrations/braintrust_logging.py
@@ -17,10 +17,17 @@ from pydantic import BaseModel
 import litellm
 from litellm import verbose_logger
 from litellm.integrations.custom_logger import CustomLogger
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
 from litellm.utils import get_formatted_prompt
 
-global_braintrust_http_handler = AsyncHTTPHandler()
+global_braintrust_http_handler = get_async_httpx_client(
+    llm_provider=httpxSpecialProvider.LoggingCallback
+)
 global_braintrust_sync_http_handler = HTTPHandler()
 
 API_BASE = "https://api.braintrustdata.com/v1"
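Note on the logging integrations above and below: each one drops its own `AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))` and instead asks for a client keyed on `httpxSpecialProvider.LoggingCallback`. A minimal sketch of the intended effect (not part of the diff), assuming the helper and enum introduced further down in `litellm/llms/custom_httpx/http_handler.py` and `litellm/llms/custom_httpx/types.py`, and no extra `params`:

```python
# Sketch only: with the same special provider and no params, every logging
# integration is expected to resolve to one shared, cached AsyncHTTPHandler
# rather than building its own handler with a hard-coded timeout.
from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
    httpxSpecialProvider,
)

braintrust_client = get_async_httpx_client(
    llm_provider=httpxSpecialProvider.LoggingCallback
)
langsmith_client = get_async_httpx_client(
    llm_provider=httpxSpecialProvider.LoggingCallback
)

# Same cache key -> same client instance (per the cache-lookup hunk further down).
assert braintrust_client is langsmith_client
```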
diff --git a/litellm/integrations/galileo.py b/litellm/integrations/galileo.py
index 51d845fcb..70b5c3b01 100644
--- a/litellm/integrations/galileo.py
+++ b/litellm/integrations/galileo.py
@@ -8,7 +8,10 @@ from pydantic import BaseModel, Field
 
 import litellm
 from litellm._logging import verbose_logger
 from litellm.integrations.custom_logger import CustomLogger
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
 
 # from here: https://docs.rungalileo.io/galileo/gen-ai-studio-products/galileo-observe/how-to/logging-data-via-restful-apis#structuring-your-records
@@ -39,8 +42,8 @@ class GalileoObserve(CustomLogger):
         self.base_url = os.getenv("GALILEO_BASE_URL", None)
         self.project_id = os.getenv("GALILEO_PROJECT_ID", None)
         self.headers = None
-        self.async_httpx_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+        self.async_httpx_handler = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.LoggingCallback
         )
         pass
 
diff --git a/litellm/integrations/gcs_bucket_base.py b/litellm/integrations/gcs_bucket_base.py
index 073f0f265..0f71fe829 100644
--- a/litellm/integrations/gcs_bucket_base.py
+++ b/litellm/integrations/gcs_bucket_base.py
@@ -13,15 +13,18 @@ from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.logging_utils import (
     convert_litellm_response_object_to_dict,
 )
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
 
 
 class GCSBucketBase(CustomLogger):
     def __init__(self, bucket_name: Optional[str] = None) -> None:
         from litellm.proxy.proxy_server import premium_user
 
-        self.async_httpx_client = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+        self.async_httpx_client = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.LoggingCallback
         )
         self.path_service_account_json = os.getenv("GCS_PATH_SERVICE_ACCOUNT", None)
         self.BUCKET_NAME = bucket_name or os.getenv("GCS_BUCKET_NAME", None)
diff --git a/litellm/integrations/lago.py b/litellm/integrations/lago.py
index b3db53301..e0238e311 100644
--- a/litellm/integrations/lago.py
+++ b/litellm/integrations/lago.py
@@ -13,7 +13,11 @@ import httpx
 import litellm
 from litellm import verbose_logger
 from litellm.integrations.custom_logger import CustomLogger
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    HTTPHandler,
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
 
 
 def get_utc_datetime():
@@ -30,7 +34,9 @@ class LagoLogger(CustomLogger):
     def __init__(self) -> None:
         super().__init__()
         self.validate_environment()
-        self.async_http_handler = AsyncHTTPHandler()
+        self.async_http_handler = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.LoggingCallback
+        )
         self.sync_http_handler = HTTPHandler()
 
     def validate_environment(self):
diff --git a/litellm/integrations/langsmith.py b/litellm/integrations/langsmith.py
index d245a6e2c..3ee55b3e9 100644
--- a/litellm/integrations/langsmith.py
+++ b/litellm/integrations/langsmith.py
@@ -16,7 +16,11 @@ from pydantic import BaseModel  # type: ignore
 
 import litellm
 from litellm._logging import verbose_logger
 from litellm.integrations.custom_logger import CustomLogger
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
 
 class LangsmithInputs(BaseModel):
@@ -61,8 +65,8 @@ class LangsmithLogger(CustomLogger):
         self.langsmith_base_url = os.getenv(
             "LANGSMITH_BASE_URL", "https://api.smith.langchain.com"
         )
-        self.async_httpx_client = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+        self.async_httpx_client = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.LoggingCallback
         )
 
     def _prepare_log_data(self, kwargs, response_obj, start_time, end_time):
diff --git a/litellm/integrations/openmeter.py b/litellm/integrations/openmeter.py
index 6905fd789..19460e001 100644
--- a/litellm/integrations/openmeter.py
+++ b/litellm/integrations/openmeter.py
@@ -12,7 +12,12 @@ import httpx
 
 import litellm
 from litellm import verbose_logger
 from litellm.integrations.custom_logger import CustomLogger
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
 
 
 def get_utc_datetime():
@@ -29,7 +34,9 @@ class OpenMeterLogger(CustomLogger):
     def __init__(self) -> None:
         super().__init__()
         self.validate_environment()
-        self.async_http_handler = AsyncHTTPHandler()
+        self.async_http_handler = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.LoggingCallback
+        )
         self.sync_http_handler = HTTPHandler()
 
     def validate_environment(self):
diff --git a/litellm/integrations/prometheus_helpers/prometheus_api.py b/litellm/integrations/prometheus_helpers/prometheus_api.py
index 13ccc1562..472c3b61b 100644
--- a/litellm/integrations/prometheus_helpers/prometheus_api.py
+++ b/litellm/integrations/prometheus_helpers/prometheus_api.py
@@ -8,11 +8,16 @@ import time
 
 import litellm
 from litellm._logging import verbose_logger
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
 
 PROMETHEUS_URL = litellm.get_secret("PROMETHEUS_URL")
 PROMETHEUS_SELECTED_INSTANCE = litellm.get_secret("PROMETHEUS_SELECTED_INSTANCE")
-async_http_handler = AsyncHTTPHandler()
+async_http_handler = get_async_httpx_client(
+    llm_provider=httpxSpecialProvider.LoggingCallback
+)
 
 
 async def get_metric_from_prometheus(
diff --git a/litellm/integrations/slack_alerting.py b/litellm/integrations/slack_alerting.py
index eab3f9c38..42e289b87 100644
--- a/litellm/integrations/slack_alerting.py
+++ b/litellm/integrations/slack_alerting.py
@@ -25,7 +25,11 @@ from litellm._logging import verbose_logger, verbose_proxy_logger
 from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.litellm_logging import Logging
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
 from litellm.proxy._types import (
     AlertType,
     CallInfo,
@@ -187,7 +191,9 @@ class SlackAlerting(CustomLogger):
         self.alerting = alerting
         self.alert_types = alert_types
         self.internal_usage_cache = internal_usage_cache or DualCache()
-        self.async_http_handler = AsyncHTTPHandler()
+        self.async_http_handler = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.LoggingCallback
+        )
         self.alert_to_webhook_url = alert_to_webhook_url
         self.is_running = False
         self.alerting_args = SlackAlertingArgs(**alerting_args)
diff --git a/litellm/llms/anthropic/chat.py b/litellm/llms/anthropic/chat.py
index 18e530bb7..0520506bb 100644
--- a/litellm/llms/anthropic/chat.py
+++ b/litellm/llms/anthropic/chat.py
@@ -25,8 +25,8 @@ from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
-    _get_async_httpx_client,
     _get_httpx_client,
+    get_async_httpx_client,
 )
 from litellm.types.llms.anthropic import (
     AnthopicMessagesAssistantMessageParam,
@@ -918,7 +918,9 @@ class AnthropicChatCompletion(BaseLLM):
         headers={},
         client=None,
     ) -> Union[ModelResponse, CustomStreamWrapper]:
-        async_handler = _get_async_httpx_client()
+        async_handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.ANTHROPIC
+        )
 
         try:
             response = await async_handler.post(
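The anthropic hunk above shows the call-site pattern repeated in the provider modules that follow: the private `_get_async_httpx_client()` is replaced by the public `get_async_httpx_client(...)`, keyed on the `LlmProviders` enum added in `litellm/__init__.py`. A hedged sketch of that pattern, using only names that appear in the diff:

```python
# Sketch only, not part of the diff.
import litellm
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client

# The provider becomes part of the cache key, so Anthropic calls reuse a client
# cached for Anthropic rather than one shared by every params-less caller.
async_handler = get_async_httpx_client(llm_provider=litellm.LlmProviders.ANTHROPIC)

# LlmProviders subclasses str, so string comparisons keep working and
# provider_list (now list(LlmProviders)) still answers membership checks by name.
assert litellm.LlmProviders.ANTHROPIC == "anthropic"
assert "anthropic" in litellm.provider_list
```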
diff --git a/litellm/llms/bedrock/chat.py b/litellm/llms/bedrock/chat.py
index ee09797ba..8d6d98ba6 100644
--- a/litellm/llms/bedrock/chat.py
+++ b/litellm/llms/bedrock/chat.py
@@ -35,8 +35,8 @@ from litellm.litellm_core_utils.litellm_logging import Logging
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
-    _get_async_httpx_client,
     _get_httpx_client,
+    get_async_httpx_client,
 )
 from litellm.types.llms.bedrock import *
 from litellm.types.llms.openai import (
@@ -209,7 +209,9 @@ async def make_call(
 ):
     try:
         if client is None:
-            client = _get_async_httpx_client()  # Create a new client if none provided
+            client = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders.BEDROCK
+            )  # Create a new client if none provided
 
         response = await client.post(
             api_base,
@@ -1041,7 +1043,7 @@ class BedrockLLM(BaseAWSLLM):
                 if isinstance(timeout, float) or isinstance(timeout, int):
                     timeout = httpx.Timeout(timeout)
                 _params["timeout"] = timeout
-            client = _get_async_httpx_client(_params)  # type: ignore
+            client = get_async_httpx_client(params=_params, llm_provider=litellm.LlmProviders.BEDROCK)  # type: ignore
         else:
             client = client  # type: ignore
 
@@ -1498,7 +1500,9 @@ class BedrockConverseLLM(BaseAWSLLM):
                 if isinstance(timeout, float) or isinstance(timeout, int):
                     timeout = httpx.Timeout(timeout)
                 _params["timeout"] = timeout
-            client = _get_async_httpx_client(_params)  # type: ignore
+            client = get_async_httpx_client(
+                params=_params, llm_provider=litellm.LlmProviders.BEDROCK
+            )
         else:
             client = client  # type: ignore
 
diff --git a/litellm/llms/bedrock/embed/embedding.py b/litellm/llms/bedrock/embed/embedding.py
index a7a6c173c..7d2e441da 100644
--- a/litellm/llms/bedrock/embed/embedding.py
+++ b/litellm/llms/bedrock/embed/embedding.py
@@ -15,8 +15,8 @@ from litellm.llms.cohere.embed import embedding as cohere_embedding
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
-    _get_async_httpx_client,
     _get_httpx_client,
+    get_async_httpx_client,
 )
 from litellm.secret_managers.main import get_secret
 from litellm.types.llms.bedrock import AmazonEmbeddingRequest, CohereEmbeddingRequest
@@ -130,7 +130,9 @@ class BedrockEmbedding(BaseAWSLLM):
                 if isinstance(timeout, float) or isinstance(timeout, int):
                     timeout = httpx.Timeout(timeout)
                 _params["timeout"] = timeout
-            client = _get_async_httpx_client(_params)  # type: ignore
+            client = get_async_httpx_client(
+                params=_params, llm_provider=litellm.LlmProviders.BEDROCK
+            )
         else:
             client = client
 
diff --git a/litellm/llms/cohere/rerank.py b/litellm/llms/cohere/rerank.py
index 97cd7e399..64afcae4f 100644
--- a/litellm/llms/cohere/rerank.py
+++ b/litellm/llms/cohere/rerank.py
@@ -9,10 +9,11 @@ from typing import Any, Dict, List, Optional, Union
 import httpx
 from pydantic import BaseModel
 
+import litellm
 from litellm.llms.base import BaseLLM
 from litellm.llms.custom_httpx.http_handler import (
-    _get_async_httpx_client,
     _get_httpx_client,
+    get_async_httpx_client,
 )
 from litellm.rerank_api.types import RerankRequest, RerankResponse
 
@@ -65,7 +66,7 @@ class CohereRerank(BaseLLM):
         api_key: str,
         api_base: str,
     ) -> RerankResponse:
-        client = _get_async_httpx_client()
+        client = get_async_httpx_client(llm_provider=litellm.LlmProviders.COHERE)
 
         response = await client.post(
             api_base,
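The Bedrock hunks above (and the Vertex AI hunks later in the diff) keep their explicit timeouts by passing them through `params`, which the helper folds into the cache key alongside the provider. A small sketch of that variant (not part of the diff), assuming the signature introduced in the next file:

```python
# Sketch only: an explicit timeout still works, and the provider + params pair
# maps to its own cached client.
import httpx
import litellm
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client

_params = {"timeout": httpx.Timeout(timeout=600.0, connect=5.0)}
client = get_async_httpx_client(
    params=_params, llm_provider=litellm.LlmProviders.BEDROCK
)
```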
diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py
index 2f07ee2f7..e63bd3f54 100644
--- a/litellm/llms/custom_httpx/http_handler.py
+++ b/litellm/llms/custom_httpx/http_handler.py
@@ -1,12 +1,19 @@
 import asyncio
 import os
 import traceback
-from typing import Any, Mapping, Optional, Union
+from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
 
 import httpx
 
 import litellm
 
+from .types import httpxSpecialProvider
+
+if TYPE_CHECKING:
+    from litellm import LlmProviders
+else:
+    LlmProviders = Any
+
 try:
     from litellm._version import version
 except:
@@ -378,7 +385,10 @@ class HTTPHandler:
         pass
 
 
-def _get_async_httpx_client(params: Optional[dict] = None) -> AsyncHTTPHandler:
+def get_async_httpx_client(
+    llm_provider: Union[LlmProviders, httpxSpecialProvider],
+    params: Optional[dict] = None,
+) -> AsyncHTTPHandler:
     """
     Retrieves the async HTTP client from the cache
     If not present, creates a new client
@@ -393,7 +403,7 @@ def _get_async_httpx_client(params: Optional[dict] = None) -> AsyncHTTPHandler:
             except Exception:
                 pass
 
-    _cache_key_name = "async_httpx_client" + _params_key_name
+    _cache_key_name = "async_httpx_client" + _params_key_name + llm_provider
     if _cache_key_name in litellm.in_memory_llm_clients_cache:
         return litellm.in_memory_llm_clients_cache[_cache_key_name]
 
diff --git a/litellm/llms/custom_httpx/types.py b/litellm/llms/custom_httpx/types.py
new file mode 100644
index 000000000..dc0958118
--- /dev/null
+++ b/litellm/llms/custom_httpx/types.py
@@ -0,0 +1,10 @@
+from enum import Enum
+
+import litellm
+
+
+class httpxSpecialProvider(str, Enum):
+    LoggingCallback = "logging_callback"
+    GuardrailCallback = "guardrail_callback"
+    Caching = "caching"
+    Oauth2Check = "oauth2_check"
diff --git a/litellm/llms/sagemaker/sagemaker.py b/litellm/llms/sagemaker/sagemaker.py
index a7b36134b..b124308eb 100644
--- a/litellm/llms/sagemaker/sagemaker.py
+++ b/litellm/llms/sagemaker/sagemaker.py
@@ -19,8 +19,8 @@ from litellm.litellm_core_utils.asyncify import asyncify
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
-    _get_async_httpx_client,
     _get_httpx_client,
+    get_async_httpx_client,
 )
 from litellm.types.llms.openai import (
     ChatCompletionToolCallChunk,
@@ -565,8 +565,8 @@ class SagemakerLLM(BaseAWSLLM):
     ):
         try:
             if client is None:
-                client = (
-                    _get_async_httpx_client()
+                client = get_async_httpx_client(
+                    llm_provider=litellm.LlmProviders.SAGEMAKER
                 )  # Create a new client if none provided
             response = await client.post(
                 api_base,
@@ -673,7 +673,9 @@ class SagemakerLLM(BaseAWSLLM):
         model_id: Optional[str],
     ):
         timeout = 300.0
-        async_handler = _get_async_httpx_client()
+        async_handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.SAGEMAKER
+        )
 
         async_transform_prompt = asyncify(self._transform_prompt)
diff --git a/litellm/llms/togetherai/rerank.py b/litellm/llms/togetherai/rerank.py
index 5d905071c..ea57c46c7 100644
--- a/litellm/llms/togetherai/rerank.py
+++ b/litellm/llms/togetherai/rerank.py
@@ -9,10 +9,11 @@ from typing import Any, Dict, List, Optional, Union
 import httpx
 from pydantic import BaseModel
 
+import litellm
 from litellm.llms.base import BaseLLM
 from litellm.llms.custom_httpx.http_handler import (
-    _get_async_httpx_client,
     _get_httpx_client,
+    get_async_httpx_client,
 )
 from litellm.rerank_api.types import RerankRequest, RerankResponse
 
@@ -77,7 +78,9 @@ class TogetherAIRerank(BaseLLM):
         request_data_dict: Dict[str, Any],
         api_key: str,
     ) -> RerankResponse:
-        client = _get_async_httpx_client()  # Use async client
+        client = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.TOGETHER_AI
+        )  # Use async client
 
         response = await client.post(
             "https://api.together.xyz/v1/rerank",
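Tying the `http_handler.py` and `types.py` hunks above together: the cache key is built by plain string concatenation, which is why both `LlmProviders` and `httpxSpecialProvider` are declared as `str` Enums. A short sketch of the key composition (not part of the diff):

```python
# Sketch only: str-Enum members concatenate like ordinary strings, so the key
# "async_httpx_client" + params-key + provider stays a plain string.
import litellm
from litellm.llms.custom_httpx.http_handler import httpxSpecialProvider

_params_key_name = ""  # empty when no params are passed

bedrock_key = "async_httpx_client" + _params_key_name + litellm.LlmProviders.BEDROCK
caching_key = "async_httpx_client" + _params_key_name + httpxSpecialProvider.Caching

assert bedrock_key == "async_httpx_clientbedrock"
assert caching_key == "async_httpx_clientcaching"
assert bedrock_key != caching_key  # distinct providers -> distinct cached handlers
```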
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
index 000512d94..9b57c293e 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
@@ -22,7 +22,7 @@ from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
-    _get_async_httpx_client,
+    get_async_httpx_client,
 )
 from litellm.llms.prompt_templates.factory import (
     convert_url_to_base64,
@@ -1294,7 +1294,9 @@ class VertexLLM(BaseLLM):
         if timeout:
             _async_client_params["timeout"] = timeout
         if client is None or not isinstance(client, AsyncHTTPHandler):
-            client = _get_async_httpx_client(params=_async_client_params)
+            client = get_async_httpx_client(
+                params=_async_client_params, llm_provider=litellm.LlmProviders.VERTEX_AI
+            )
         else:
             client = client  # type: ignore
         ## LOGGING
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/text_to_speech/text_to_speech_handler.py b/litellm/llms/vertex_ai_and_google_ai_studio/text_to_speech/text_to_speech_handler.py
index bc2424ecc..f3837d94b 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/text_to_speech/text_to_speech_handler.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/text_to_speech/text_to_speech_handler.py
@@ -4,13 +4,14 @@ from typing import Any, Coroutine, Literal, Optional, TypedDict, Union
 
 import httpx
 
+import litellm
 from litellm._logging import verbose_logger
 from litellm.llms.base import BaseLLM
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
-    _get_async_httpx_client,
     _get_httpx_client,
+    get_async_httpx_client,
 )
 from litellm.llms.OpenAI.openai import HttpxBinaryResponseContent
 from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
@@ -178,7 +179,9 @@ class VertexTextToSpeechAPI(VertexLLM):
     ) -> HttpxBinaryResponseContent:
         import base64
 
-        async_handler = _get_async_httpx_client()
+        async_handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.VERTEX_AI
+        )
 
         response = await async_handler.post(
             url=url,
diff --git a/litellm/proxy/auth/oauth2_check.py b/litellm/proxy/auth/oauth2_check.py
index ed5a3e26b..85a112ef1 100644
--- a/litellm/proxy/auth/oauth2_check.py
+++ b/litellm/proxy/auth/oauth2_check.py
@@ -20,7 +20,10 @@ async def check_oauth2_token(token: str) -> UserAPIKeyAuth:
     import httpx
 
     from litellm._logging import verbose_proxy_logger
-    from litellm.llms.custom_httpx.http_handler import _get_async_httpx_client
+    from litellm.llms.custom_httpx.http_handler import (
+        get_async_httpx_client,
+        httpxSpecialProvider,
+    )
     from litellm.proxy._types import CommonProxyErrors
     from litellm.proxy.proxy_server import premium_user
 
@@ -40,7 +43,7 @@ async def check_oauth2_token(token: str) -> UserAPIKeyAuth:
     if not token_info_endpoint:
         raise ValueError("OAUTH_TOKEN_INFO_ENDPOINT environment variable is not set")
 
-    client = _get_async_httpx_client()
+    client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
 
     headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
 
     try:
diff --git a/litellm/proxy/guardrails/guardrail_hooks/aporia_ai.py b/litellm/proxy/guardrails/guardrail_hooks/aporia_ai.py
index c16c0543d..23ac33927 100644
--- a/litellm/proxy/guardrails/guardrail_hooks/aporia_ai.py
+++ b/litellm/proxy/guardrails/guardrail_hooks/aporia_ai.py
@@ -30,7 +30,11 @@ from litellm.integrations.custom_guardrail import CustomGuardrail
 from litellm.litellm_core_utils.logging_utils import (
     convert_litellm_response_object_to_str,
 )
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
 from litellm.types.guardrails import GuardrailEventHooks
@@ -44,8 +48,8 @@ class AporiaGuardrail(CustomGuardrail):
     def __init__(
         self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
     ):
-        self.async_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+        self.async_handler = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.GuardrailCallback
         )
         self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"]
         self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"]
diff --git a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py
index eee26bd42..4c47aeb86 100644
--- a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py
+++ b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py
@@ -33,7 +33,8 @@ from litellm.litellm_core_utils.logging_utils import (
 from litellm.llms.base_aws_llm import BaseAWSLLM
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
-    _get_async_httpx_client,
+    get_async_httpx_client,
+    httpxSpecialProvider,
 )
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
@@ -55,7 +56,9 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
         guardrailVersion: Optional[str] = None,
         **kwargs,
     ):
-        self.async_handler = _get_async_httpx_client()
+        self.async_handler = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.GuardrailCallback
+        )
         self.guardrailIdentifier = guardrailIdentifier
         self.guardrailVersion = guardrailVersion
 
diff --git a/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py b/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py
index 5aebbcf3e..99b56daa8 100644
--- a/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py
+++ b/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py
@@ -21,7 +21,11 @@ from fastapi import HTTPException
 import litellm
 from litellm._logging import verbose_proxy_logger
 from litellm.integrations.custom_guardrail import CustomGuardrail
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
 from litellm.secret_managers.main import get_secret
@@ -50,8 +54,8 @@ class lakeraAI_Moderation(CustomGuardrail):
         api_key: Optional[str] = None,
         **kwargs,
     ):
-        self.async_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+        self.async_handler = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.GuardrailCallback
         )
         self.lakera_api_key = api_key or os.environ["LAKERA_API_KEY"]
         self.moderation_check = moderation_check
diff --git a/litellm/proxy/guardrails/guardrail_hooks/presidio.py b/litellm/proxy/guardrails/guardrail_hooks/presidio.py
index 857704bf2..eb09e5203 100644
--- a/litellm/proxy/guardrails/guardrail_hooks/presidio.py
+++ b/litellm/proxy/guardrails/guardrail_hooks/presidio.py
@@ -22,7 +22,6 @@ import litellm  # noqa: E401
 from litellm._logging import verbose_proxy_logger
 from litellm.caching import DualCache
 from litellm.integrations.custom_guardrail import CustomGuardrail
-from litellm.llms.custom_httpx.http_handler import _get_async_httpx_client
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.utils import (
     EmbeddingResponse,