diff --git a/litellm/caching.py b/litellm/caching.py
index 2c0a5a2da..5add0cd8e 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -1244,6 +1244,7 @@ class QdrantSemanticCache(BaseCache):
         from litellm.llms.custom_httpx.http_handler import (
             _get_httpx_client,
             get_async_httpx_client,
+            httpxSpecialProvider,
         )
 
         if collection_name is None:
@@ -1290,7 +1291,9 @@ class QdrantSemanticCache(BaseCache):
         self.headers = headers
 
         self.sync_client = _get_httpx_client()
-        self.async_client = get_async_httpx_client()
+        self.async_client = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.Caching
+        )
 
         if quantization_config is None:
             print_verbose(
diff --git a/litellm/llms/anthropic/chat.py b/litellm/llms/anthropic/chat.py
index a5cee1715..0520506bb 100644
--- a/litellm/llms/anthropic/chat.py
+++ b/litellm/llms/anthropic/chat.py
@@ -918,7 +918,9 @@ class AnthropicChatCompletion(BaseLLM):
         headers={},
         client=None,
     ) -> Union[ModelResponse, CustomStreamWrapper]:
-        async_handler = get_async_httpx_client()
+        async_handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.ANTHROPIC
+        )
 
         try:
             response = await async_handler.post(
diff --git a/litellm/llms/bedrock/chat.py b/litellm/llms/bedrock/chat.py
index 80a546bd5..96c47c20c 100644
--- a/litellm/llms/bedrock/chat.py
+++ b/litellm/llms/bedrock/chat.py
@@ -209,7 +209,9 @@ async def make_call(
 ):
     try:
         if client is None:
-            client = get_async_httpx_client()  # Create a new client if none provided
+            client = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders.BEDROCK
+            )  # Create a new client if none provided
 
         response = await client.post(
             api_base,
@@ -1041,7 +1043,7 @@ class BedrockLLM(BaseAWSLLM):
                 if isinstance(timeout, float) or isinstance(timeout, int):
                     timeout = httpx.Timeout(timeout)
                 _params["timeout"] = timeout
-            client = get_async_httpx_client(_params)  # type: ignore
+            client = get_async_httpx_client(_params, llm_provider=litellm.LlmProviders.BEDROCK)  # type: ignore
         else:
             client = client  # type: ignore
 
@@ -1498,7 +1500,7 @@ class BedrockConverseLLM(BaseAWSLLM):
                 if isinstance(timeout, float) or isinstance(timeout, int):
                     timeout = httpx.Timeout(timeout)
                 _params["timeout"] = timeout
-            client = get_async_httpx_client(_params)  # type: ignore
+            client = get_async_httpx_client(_params, llm_provider=litellm.LlmProviders.BEDROCK)  # type: ignore
         else:
             client = client  # type: ignore
 
diff --git a/litellm/llms/bedrock/embed/embedding.py b/litellm/llms/bedrock/embed/embedding.py
index ddc8f2162..b0fbaa5e0 100644
--- a/litellm/llms/bedrock/embed/embedding.py
+++ b/litellm/llms/bedrock/embed/embedding.py
@@ -130,7 +130,7 @@ class BedrockEmbedding(BaseAWSLLM):
                 if isinstance(timeout, float) or isinstance(timeout, int):
                     timeout = httpx.Timeout(timeout)
                 _params["timeout"] = timeout
-            client = get_async_httpx_client(_params)  # type: ignore
+            client = get_async_httpx_client(_params, llm_provider=litellm.LlmProviders.BEDROCK)  # type: ignore
         else:
             client = client
 
diff --git a/litellm/llms/cohere/rerank.py b/litellm/llms/cohere/rerank.py
index 137624ade..64afcae4f 100644
--- a/litellm/llms/cohere/rerank.py
+++ b/litellm/llms/cohere/rerank.py
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Union
 import httpx
 from pydantic import BaseModel
 
+import litellm
 from litellm.llms.base import BaseLLM
 from litellm.llms.custom_httpx.http_handler import (
     _get_httpx_client,
@@ -65,7 +66,7 @@ class CohereRerank(BaseLLM):
         api_key: str,
         api_base: str,
     ) -> RerankResponse:
-        client = get_async_httpx_client()
+        client = get_async_httpx_client(llm_provider=litellm.LlmProviders.COHERE)
 
         response = await client.post(
             api_base,
diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py
index babab8940..ed5773be8 100644
--- a/litellm/llms/custom_httpx/http_handler.py
+++ b/litellm/llms/custom_httpx/http_handler.py
@@ -7,6 +7,8 @@ import httpx
 
 import litellm
 
+from .types import httpxSpecialProvider
+
 try:
     from litellm._version import version
 except:
@@ -378,7 +380,10 @@ class HTTPHandler:
         pass
 
 
-def get_async_httpx_client(params: Optional[dict] = None) -> AsyncHTTPHandler:
+def get_async_httpx_client(
+    llm_provider: Union[litellm.LlmProviders, httpxSpecialProvider],
+    params: Optional[dict] = None,
+) -> AsyncHTTPHandler:
     """
     Retrieves the async HTTP client from the cache
     If not present, creates a new client
@@ -393,7 +398,7 @@
     except Exception:
         pass
 
-    _cache_key_name = "async_httpx_client" + _params_key_name
+    _cache_key_name = "async_httpx_client" + _params_key_name + llm_provider
     if _cache_key_name in litellm.in_memory_llm_clients_cache:
         return litellm.in_memory_llm_clients_cache[_cache_key_name]
 
diff --git a/litellm/llms/custom_httpx/types.py b/litellm/llms/custom_httpx/types.py
new file mode 100644
index 000000000..dc0958118
--- /dev/null
+++ b/litellm/llms/custom_httpx/types.py
@@ -0,0 +1,10 @@
+from enum import Enum
+
+import litellm
+
+
+class httpxSpecialProvider(str, Enum):
+    LoggingCallback = "logging_callback"
+    GuardrailCallback = "guardrail_callback"
+    Caching = "caching"
+    Oauth2Check = "oauth2_check"
diff --git a/litellm/llms/sagemaker/sagemaker.py b/litellm/llms/sagemaker/sagemaker.py
index 733826aee..b124308eb 100644
--- a/litellm/llms/sagemaker/sagemaker.py
+++ b/litellm/llms/sagemaker/sagemaker.py
@@ -565,8 +565,8 @@ class SagemakerLLM(BaseAWSLLM):
     ):
         try:
             if client is None:
-                client = (
-                    get_async_httpx_client()
+                client = get_async_httpx_client(
+                    llm_provider=litellm.LlmProviders.SAGEMAKER
                 )  # Create a new client if none provided
                 response = await client.post(
                     api_base,
@@ -673,7 +673,9 @@ class SagemakerLLM(BaseAWSLLM):
         model_id: Optional[str],
     ):
         timeout = 300.0
-        async_handler = get_async_httpx_client()
+        async_handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.SAGEMAKER
+        )
 
         async_transform_prompt = asyncify(self._transform_prompt)
 
diff --git a/litellm/llms/togetherai/rerank.py b/litellm/llms/togetherai/rerank.py
index 17e4ce7e5..ea57c46c7 100644
--- a/litellm/llms/togetherai/rerank.py
+++ b/litellm/llms/togetherai/rerank.py
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Union
 import httpx
 from pydantic import BaseModel
 
+import litellm
 from litellm.llms.base import BaseLLM
 from litellm.llms.custom_httpx.http_handler import (
     _get_httpx_client,
@@ -77,7 +78,9 @@ class TogetherAIRerank(BaseLLM):
         request_data_dict: Dict[str, Any],
         api_key: str,
     ) -> RerankResponse:
-        client = get_async_httpx_client()  # Use async client
+        client = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.TOGETHER_AI
+        )  # Use async client
 
         response = await client.post(
             "https://api.together.xyz/v1/rerank",
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
index d2456fd1f..7d3c09271 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
@@ -1293,7 +1293,9 @@ class VertexLLM(BaseLLM):
             _async_client_params = {}
             if timeout:
                 _async_client_params["timeout"] = timeout
-            client = get_async_httpx_client(params=_async_client_params)
+            client = get_async_httpx_client(
+                params=_async_client_params, llm_provider=litellm.LlmProviders.VERTEX_AI
+            )
         ## LOGGING
         logging_obj.pre_call(
             input=messages,
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/text_to_speech/text_to_speech_handler.py b/litellm/llms/vertex_ai_and_google_ai_studio/text_to_speech/text_to_speech_handler.py
index dbb38a715..f3837d94b 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/text_to_speech/text_to_speech_handler.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/text_to_speech/text_to_speech_handler.py
@@ -4,6 +4,7 @@ from typing import Any, Coroutine, Literal, Optional, TypedDict, Union
 
 import httpx
 
+import litellm
 from litellm._logging import verbose_logger
 from litellm.llms.base import BaseLLM
 from litellm.llms.custom_httpx.http_handler import (
@@ -178,7 +179,9 @@ class VertexTextToSpeechAPI(VertexLLM):
     ) -> HttpxBinaryResponseContent:
         import base64
 
-        async_handler = get_async_httpx_client()
+        async_handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.VERTEX_AI
+        )
 
         response = await async_handler.post(
             url=url,
diff --git a/litellm/proxy/auth/oauth2_check.py b/litellm/proxy/auth/oauth2_check.py
index f2e175632..85a112ef1 100644
--- a/litellm/proxy/auth/oauth2_check.py
+++ b/litellm/proxy/auth/oauth2_check.py
@@ -20,7 +20,10 @@ async def check_oauth2_token(token: str) -> UserAPIKeyAuth:
     import httpx
 
     from litellm._logging import verbose_proxy_logger
-    from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
+    from litellm.llms.custom_httpx.http_handler import (
+        get_async_httpx_client,
+        httpxSpecialProvider,
+    )
     from litellm.proxy._types import CommonProxyErrors
     from litellm.proxy.proxy_server import premium_user
 
@@ -40,7 +43,7 @@ async def check_oauth2_token(token: str) -> UserAPIKeyAuth:
     if not token_info_endpoint:
         raise ValueError("OAUTH_TOKEN_INFO_ENDPOINT environment variable is not set")
 
-    client = get_async_httpx_client()
+    client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
     headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
 
     try:
diff --git a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py
index f4e1a345e..4c47aeb86 100644
--- a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py
+++ b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py
@@ -34,6 +34,7 @@ from litellm.llms.base_aws_llm import BaseAWSLLM
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     get_async_httpx_client,
+    httpxSpecialProvider,
 )
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
@@ -55,7 +56,9 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
         guardrailVersion: Optional[str] = None,
         **kwargs,
     ):
-        self.async_handler = get_async_httpx_client()
+        self.async_handler = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.GuardrailCallback
+        )
         self.guardrailIdentifier = guardrailIdentifier
         self.guardrailVersion = guardrailVersion
 
diff --git a/litellm/proxy/guardrails/guardrail_hooks/presidio.py b/litellm/proxy/guardrails/guardrail_hooks/presidio.py
index db22c1872..eb09e5203 100644
--- a/litellm/proxy/guardrails/guardrail_hooks/presidio.py
+++ b/litellm/proxy/guardrails/guardrail_hooks/presidio.py
@@ -22,7 +22,6 @@ import litellm  # noqa: E401
 from litellm._logging import verbose_proxy_logger
 from litellm.caching import DualCache
 from litellm.integrations.custom_guardrail import CustomGuardrail
-from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.utils import (
     EmbeddingResponse,
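
Usage sketch (not part of the patch): after this change, get_async_httpx_client requires an llm_provider argument, so each call site identifies itself and cached AsyncHTTPHandler instances are keyed per provider. A minimal illustration of the new call pattern, assuming litellm.LlmProviders exposes an ANTHROPIC member as the anthropic/chat.py hunk suggests:

    import litellm
    from litellm.llms.custom_httpx.http_handler import (
        get_async_httpx_client,
        httpxSpecialProvider,
    )

    # LLM call sites pass their provider enum member.
    anthropic_client = get_async_httpx_client(
        llm_provider=litellm.LlmProviders.ANTHROPIC
    )

    # Non-LLM callers (caching, guardrails, oauth2 checks) use httpxSpecialProvider.
    cache_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.Caching
    )

    # Per the updated cache key in http_handler.py, repeating a call with the same
    # provider (and same params) should return the cached handler.
    assert cache_client is get_async_httpx_client(
        llm_provider=httpxSpecialProvider.Caching
    )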