diff --git a/.circleci/config.yml b/.circleci/config.yml
index b0a369a35..db7c4ef5b 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -771,6 +771,7 @@ jobs:
       - run: python ./tests/code_coverage_tests/litellm_logging_code_coverage.py
       - run: python ./tests/documentation_tests/test_env_keys.py
       - run: python ./tests/documentation_tests/test_api_docs.py
+      - run: python ./tests/code_coverage_tests/ensure_async_clients_test.py
       - run: helm lint ./deploy/charts/litellm-helm

   db_migration_disable_update_check:
diff --git a/litellm/__init__.py b/litellm/__init__.py
index 9a8c56a56..c978b24ee 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -133,7 +133,7 @@ use_client: bool = False
 ssl_verify: Union[str, bool] = True
 ssl_certificate: Optional[str] = None
 disable_streaming_logging: bool = False
-in_memory_llm_clients_cache: dict = {}
+in_memory_llm_clients_cache: InMemoryCache = InMemoryCache()
 safe_memory_mode: bool = False
 enable_azure_ad_token_refresh: Optional[bool] = False
 ### DEFAULT AZURE API VERSION ###
diff --git a/litellm/llms/AzureOpenAI/azure.py b/litellm/llms/AzureOpenAI/azure.py
index 39dea14e2..f6a1790b6 100644
--- a/litellm/llms/AzureOpenAI/azure.py
+++ b/litellm/llms/AzureOpenAI/azure.py
@@ -12,7 +12,11 @@ from typing_extensions import overload
 import litellm
 from litellm.caching.caching import DualCache
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)
 from litellm.types.utils import EmbeddingResponse
 from litellm.utils import (
     CustomStreamWrapper,
@@ -977,7 +981,10 @@ class AzureChatCompletion(BaseLLM):
             else:
                 _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)

-            async_handler = AsyncHTTPHandler(**_params)  # type: ignore
+            async_handler = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders.AZURE,
+                params=_params,
+            )
         else:
             async_handler = client  # type: ignore
diff --git a/litellm/llms/OpenAI/openai.py b/litellm/llms/OpenAI/openai.py
index 7d701d26c..057340b51 100644
--- a/litellm/llms/OpenAI/openai.py
+++ b/litellm/llms/OpenAI/openai.py
@@ -18,6 +18,7 @@ import litellm
 from litellm import LlmProviders
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.utils import ProviderField
 from litellm.utils import (
@@ -562,8 +563,9 @@ class OpenAIChatCompletion(BaseLLM):

         _cache_key = f"hashed_api_key={hashed_api_key},api_base={api_base},timeout={timeout},max_retries={max_retries},organization={organization},is_async={is_async}"

-        if _cache_key in litellm.in_memory_llm_clients_cache:
-            return litellm.in_memory_llm_clients_cache[_cache_key]
+        _cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key)
+        if _cached_client:
+            return _cached_client
         if is_async:
             _new_client: Union[OpenAI, AsyncOpenAI] = AsyncOpenAI(
                 api_key=api_key,
@@ -584,7 +586,11 @@ class OpenAIChatCompletion(BaseLLM):
             )

             ## SAVE CACHE KEY
-            litellm.in_memory_llm_clients_cache[_cache_key] = _new_client
+            litellm.in_memory_llm_clients_cache.set_cache(
+                key=_cache_key,
+                value=_new_client,
+                ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
+            )
             return _new_client
         else:
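The two changed call sites in openai.py above follow the same get-check-set shape against litellm.in_memory_llm_clients_cache. Below is a condensed sketch of that pattern, using only the get_cache/set_cache API and TTL constant visible in this diff; client construction details (timeout, retries, organization, http_client) are elided, so treat it as an illustration rather than the exact litellm code:

```python
import litellm
from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
from openai import AsyncOpenAI


def get_cached_async_openai_client(cache_key: str, api_key: str, api_base: str):
    """Sketch of the get-or-create flow: return a cached client, else build one and cache it with a TTL."""
    cached_client = litellm.in_memory_llm_clients_cache.get_cache(cache_key)
    if cached_client:
        return cached_client

    new_client = AsyncOpenAI(api_key=api_key, base_url=api_base)  # the real call passes more kwargs
    litellm.in_memory_llm_clients_cache.set_cache(
        key=cache_key,
        value=new_client,
        ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,  # 3600s, so a stale client is rebuilt after an hour
    )
    return new_client
```

Because the cache key embeds the hashed API key, base URL, timeout, retries, and organization, two requests only share a client when every one of those settings matches.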
diff --git a/litellm/llms/anthropic/completion.py b/litellm/llms/anthropic/completion.py
index 89a50db6a..dc06401d6 100644
--- a/litellm/llms/anthropic/completion.py
+++ b/litellm/llms/anthropic/completion.py
@@ -13,7 +13,11 @@ import httpx
 import requests

 import litellm
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage

 from ..base import BaseLLM
@@ -162,7 +166,10 @@ class AnthropicTextCompletion(BaseLLM):
         client=None,
     ):
         if client is None:
-            client = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
+            client = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders.ANTHROPIC,
+                params={"timeout": httpx.Timeout(timeout=600.0, connect=5.0)},
+            )

         response = await client.post(api_base, headers=headers, data=json.dumps(data))

@@ -198,7 +205,10 @@ class AnthropicTextCompletion(BaseLLM):
         client=None,
     ):
         if client is None:
-            client = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
+            client = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders.ANTHROPIC,
+                params={"timeout": httpx.Timeout(timeout=600.0, connect=5.0)},
+            )

         response = await client.post(api_base, headers=headers, data=json.dumps(data))
diff --git a/litellm/llms/azure_ai/embed/handler.py b/litellm/llms/azure_ai/embed/handler.py
index 638a77479..2946a84dd 100644
--- a/litellm/llms/azure_ai/embed/handler.py
+++ b/litellm/llms/azure_ai/embed/handler.py
@@ -74,7 +74,10 @@ class AzureAIEmbedding(OpenAIChatCompletion):
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ) -> EmbeddingResponse:
         if client is None or not isinstance(client, AsyncHTTPHandler):
-            client = AsyncHTTPHandler(timeout=timeout, concurrent_limit=1)
+            client = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders.AZURE_AI,
+                params={"timeout": timeout},
+            )

         url = "{}/images/embeddings".format(api_base)
diff --git a/litellm/llms/clarifai.py b/litellm/llms/clarifai.py
index 2011c0bee..61d445423 100644
--- a/litellm/llms/clarifai.py
+++ b/litellm/llms/clarifai.py
@@ -9,7 +9,10 @@ import httpx
 import requests

 import litellm
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    get_async_httpx_client,
+)
 from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage

 from .prompt_templates.factory import custom_prompt, prompt_factory
@@ -185,7 +188,10 @@ async def async_completion(
     headers={},
 ):

-    async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
+    async_handler = get_async_httpx_client(
+        llm_provider=litellm.LlmProviders.CLARIFAI,
+        params={"timeout": 600.0},
+    )
     response = await async_handler.post(
         url=model, headers=headers, data=json.dumps(data)
     )
diff --git a/litellm/llms/cohere/embed/handler.py b/litellm/llms/cohere/embed/handler.py
index 95cbec225..5b224c375 100644
--- a/litellm/llms/cohere/embed/handler.py
+++ b/litellm/llms/cohere/embed/handler.py
@@ -11,7 +11,11 @@ import requests  # type: ignore

 import litellm
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)
 from litellm.types.llms.bedrock import CohereEmbeddingRequest
 from litellm.utils import Choices, Message, ModelResponse, Usage
@@ -71,7 +75,10 @@ async def async_embedding(
     )
     ## COMPLETION CALL
     if client is None:
-        client = AsyncHTTPHandler(concurrent_limit=1, timeout=timeout)
+        client = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.COHERE,
+            params={"timeout": timeout},
+        )
     try:
         response = await client.post(api_base, headers=headers, data=json.dumps(data))
diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py
index 020af7e90..f1b78ea63 100644
--- a/litellm/llms/custom_httpx/http_handler.py
+++ b/litellm/llms/custom_httpx/http_handler.py
@@ -7,6 +7,7 @@ import httpx
 from httpx import USE_CLIENT_DEFAULT, AsyncHTTPTransport, HTTPTransport

 import litellm
+from litellm.caching import InMemoryCache

 from .types import httpxSpecialProvider
@@ -26,6 +27,7 @@ headers = {
 # https://www.python-httpx.org/advanced/timeouts
 _DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
+_DEFAULT_TTL_FOR_HTTPX_CLIENTS = 3600  # 1 hour, re-use the same httpx client for 1 hour


 class AsyncHTTPHandler:
@@ -476,8 +478,9 @@ def get_async_httpx_client(
         pass

     _cache_key_name = "async_httpx_client" + _params_key_name + llm_provider
-    if _cache_key_name in litellm.in_memory_llm_clients_cache:
-        return litellm.in_memory_llm_clients_cache[_cache_key_name]
+    _cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key_name)
+    if _cached_client:
+        return _cached_client

     if params is not None:
         _new_client = AsyncHTTPHandler(**params)
@@ -485,7 +488,11 @@
         _new_client = AsyncHTTPHandler(
             timeout=httpx.Timeout(timeout=600.0, connect=5.0)
         )
-    litellm.in_memory_llm_clients_cache[_cache_key_name] = _new_client
+    litellm.in_memory_llm_clients_cache.set_cache(
+        key=_cache_key_name,
+        value=_new_client,
+        ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
+    )
     return _new_client


@@ -505,13 +512,18 @@ def _get_httpx_client(params: Optional[dict] = None) -> HTTPHandler:
         pass

     _cache_key_name = "httpx_client" + _params_key_name
-    if _cache_key_name in litellm.in_memory_llm_clients_cache:
-        return litellm.in_memory_llm_clients_cache[_cache_key_name]
+    _cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key_name)
+    if _cached_client:
+        return _cached_client

     if params is not None:
         _new_client = HTTPHandler(**params)
     else:
         _new_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))

-    litellm.in_memory_llm_clients_cache[_cache_key_name] = _new_client
+    litellm.in_memory_llm_clients_cache.set_cache(
+        key=_cache_key_name,
+        value=_new_client,
+        ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
+    )
     return _new_client
diff --git a/litellm/llms/databricks/chat.py b/litellm/llms/databricks/chat.py
index 79e885646..e752f4d98 100644
--- a/litellm/llms/databricks/chat.py
+++ b/litellm/llms/databricks/chat.py
@@ -393,7 +393,10 @@ class DatabricksChatCompletion(BaseLLM):
         if timeout is None:
             timeout = httpx.Timeout(timeout=600.0, connect=5.0)

-        self.async_handler = AsyncHTTPHandler(timeout=timeout)
+        self.async_handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.DATABRICKS,
+            params={"timeout": timeout},
+        )

         try:
             response = await self.async_handler.post(
@@ -610,7 +613,10 @@ class DatabricksChatCompletion(BaseLLM):
         response = None
         try:
             if client is None or isinstance(client, AsyncHTTPHandler):
-                self.async_client = AsyncHTTPHandler(timeout=timeout)  # type: ignore
+                self.async_client = get_async_httpx_client(
+                    llm_provider=litellm.LlmProviders.DATABRICKS,
+                    params={"timeout": timeout},
+                )
             else:
                 self.async_client = client
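get_async_httpx_client above keys its cache on the serialized params plus the provider name, so call sites that pass different timeouts get distinct pooled clients while identical calls share one. A rough, self-contained sketch of that keying idea follows; make_client_cache_key is a hypothetical helper, and the real _params_key_name construction in http_handler.py is not shown in this diff, so the serialization here is an assumption:

```python
import json
from typing import Optional


def make_client_cache_key(llm_provider: str, params: Optional[dict] = None) -> str:
    """Hypothetical helper: build a stable cache key from the provider name plus call-site params."""
    # sort_keys makes the key independent of dict insertion order
    params_key = json.dumps(params or {}, sort_keys=True, default=str)
    return "async_httpx_client" + params_key + llm_provider


# Same provider + same params -> same key -> the cached client is reused.
assert make_client_cache_key("anthropic", {"timeout": 600.0}) == make_client_cache_key(
    "anthropic", {"timeout": 600.0}
)
# A different timeout produces a different key, so it gets its own client.
assert make_client_cache_key("anthropic", {"timeout": 600.0}) != make_client_cache_key(
    "anthropic", {"timeout": 30.0}
)
```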
diff --git a/litellm/llms/fine_tuning_apis/vertex_ai.py b/litellm/llms/fine_tuning_apis/vertex_ai.py
index 11d052191..fd418103e 100644
--- a/litellm/llms/fine_tuning_apis/vertex_ai.py
+++ b/litellm/llms/fine_tuning_apis/vertex_ai.py
@@ -5,9 +5,14 @@ from typing import Any, Coroutine, Literal, Optional, Union
 import httpx
 from openai.types.fine_tuning.fine_tuning_job import FineTuningJob, Hyperparameters

+import litellm
 from litellm._logging import verbose_logger
 from litellm.llms.base import BaseLLM
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)
 from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     VertexLLM,
 )
@@ -26,8 +31,9 @@ class VertexFineTuningAPI(VertexLLM):

     def __init__(self) -> None:
         super().__init__()
-        self.async_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+        self.async_handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.VERTEX_AI,
+            params={"timeout": 600.0},
         )

     def convert_response_created_at(self, response: ResponseTuningJob):
diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py
index 907d72a60..8b45f1ae7 100644
--- a/litellm/llms/huggingface_restapi.py
+++ b/litellm/llms/huggingface_restapi.py
@@ -263,7 +263,11 @@ def get_hf_task_for_model(model: str) -> Tuple[hf_tasks, str]:
         return "text-generation-inference", model  # default to tgi

-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)


 def get_hf_task_embedding_for_model(
@@ -301,7 +305,9 @@ async def async_get_hf_task_embedding_for_model(
             task_type, hf_tasks_embeddings
         )
     )
-    http_client = AsyncHTTPHandler(concurrent_limit=1)
+    http_client = get_async_httpx_client(
+        llm_provider=litellm.LlmProviders.HUGGINGFACE,
+    )

     model_info = await http_client.get(url=api_base)

@@ -1067,7 +1073,9 @@ class Huggingface(BaseLLM):
         )
         ## COMPLETION CALL
         if client is None:
-            client = AsyncHTTPHandler(concurrent_limit=1)
+            client = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders.HUGGINGFACE,
+            )

         response = await client.post(api_base, headers=headers, data=json.dumps(data))
diff --git a/litellm/llms/openai_like/embedding/handler.py b/litellm/llms/openai_like/embedding/handler.py
index 7ddf43cb8..e786b5db8 100644
--- a/litellm/llms/openai_like/embedding/handler.py
+++ b/litellm/llms/openai_like/embedding/handler.py
@@ -45,7 +45,10 @@ class OpenAILikeEmbeddingHandler(OpenAILikeBase):
         response = None
         try:
             if client is None or isinstance(client, AsyncHTTPHandler):
-                self.async_client = AsyncHTTPHandler(timeout=timeout)  # type: ignore
+                self.async_client = get_async_httpx_client(
+                    llm_provider=litellm.LlmProviders.OPENAI,
+                    params={"timeout": timeout},
+                )
             else:
                 self.async_client = client
diff --git a/litellm/llms/predibase.py b/litellm/llms/predibase.py
index 96796f9dc..e80964551 100644
--- a/litellm/llms/predibase.py
+++ b/litellm/llms/predibase.py
@@ -19,7 +19,10 @@ import litellm.litellm_core_utils
 import litellm.litellm_core_utils.litellm_logging
 from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    get_async_httpx_client,
+)
 from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage

 from .base import BaseLLM
@@ -549,7 +552,10 @@ class PredibaseChatCompletion(BaseLLM):
         headers={},
     ) -> ModelResponse:

-        async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=timeout))
+        async_handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.PREDIBASE,
+            params={"timeout": timeout},
+        )
         try:
             response = await async_handler.post(
                 api_base, headers=headers, data=json.dumps(data)
diff --git a/litellm/llms/replicate.py b/litellm/llms/replicate.py
index 094110234..2e9bbb333 100644
--- a/litellm/llms/replicate.py
+++ b/litellm/llms/replicate.py
@@ -9,7 +9,10 @@ import httpx  # type: ignore
 import requests  # type: ignore

 import litellm
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    get_async_httpx_client,
+)
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage

 from .prompt_templates.factory import custom_prompt, prompt_factory
@@ -325,7 +328,7 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbose
 async def async_handle_prediction_response_streaming(
     prediction_url, api_token, print_verbose
 ):
-    http_handler = AsyncHTTPHandler(concurrent_limit=1)
+    http_handler = get_async_httpx_client(llm_provider=litellm.LlmProviders.REPLICATE)
     previous_output = ""
     output_string = ""

@@ -560,7 +563,9 @@ async def async_completion(
     logging_obj,
     print_verbose,
 ) -> Union[ModelResponse, CustomStreamWrapper]:
-    http_handler = AsyncHTTPHandler(concurrent_limit=1)
+    http_handler = get_async_httpx_client(
+        llm_provider=litellm.LlmProviders.REPLICATE,
+    )
     prediction_url = await async_start_prediction(
         version_id,
         input_data,
diff --git a/litellm/llms/text_completion_codestral.py b/litellm/llms/text_completion_codestral.py
index 21582d26c..d3c1ae3cb 100644
--- a/litellm/llms/text_completion_codestral.py
+++ b/litellm/llms/text_completion_codestral.py
@@ -18,7 +18,10 @@ import litellm
 from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    get_async_httpx_client,
+)
 from litellm.types.llms.databricks import GenericStreamingChunk
 from litellm.utils import (
     Choices,
@@ -479,8 +482,9 @@ class CodestralTextCompletion(BaseLLM):
         headers={},
     ) -> TextCompletionResponse:

-        async_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=timeout), concurrent_limit=1
+        async_handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.TEXT_COMPLETION_CODESTRAL,
+            params={"timeout": timeout},
         )
         try:
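The provider changes above (Predibase, Replicate, Codestral, and the others in this patch) all stop building an AsyncHTTPHandler per request. The benefit is connection reuse: an httpx.AsyncClient keeps a connection pool, so a fresh client per call repeats TCP and TLS setup every time, which is the latency the new test's "+500ms per request" warning refers to. Below is a small illustrative comparison in plain httpx; the URL is a placeholder and actual savings vary by endpoint:

```python
import asyncio

import httpx

# Created once at module scope: its connection pool survives across calls.
shared_client = httpx.AsyncClient(timeout=httpx.Timeout(600.0, connect=5.0))


async def call_with_shared_client(url: str) -> int:
    # Reuses pooled connections; no new TCP/TLS handshake after the first request.
    response = await shared_client.get(url)
    return response.status_code


async def call_with_fresh_client(url: str) -> int:
    # The anti-pattern this PR removes: a brand-new pool (and handshake) per request.
    async with httpx.AsyncClient(timeout=httpx.Timeout(600.0, connect=5.0)) as client:
        response = await client.get(url)
        return response.status_code


async def main() -> None:
    url = "https://example.com"  # placeholder endpoint
    await call_with_shared_client(url)
    await call_with_fresh_client(url)
    await shared_client.aclose()


if __name__ == "__main__":
    asyncio.run(main())
```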
diff --git a/litellm/llms/triton.py b/litellm/llms/triton.py
index be4179ccc..efd0d0a2d 100644
--- a/litellm/llms/triton.py
+++ b/litellm/llms/triton.py
@@ -8,7 +8,11 @@ import httpx  # type: ignore
 import requests  # type: ignore

 import litellm
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)
 from litellm.utils import (
     Choices,
     CustomStreamWrapper,
@@ -50,8 +54,8 @@ class TritonChatCompletion(BaseLLM):
         logging_obj: Any,
         api_key: Optional[str] = None,
     ) -> EmbeddingResponse:
-        async_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+        async_handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.TRITON, params={"timeout": 600.0}
         )

         response = await async_handler.post(url=api_base, data=json.dumps(data))

@@ -261,7 +265,9 @@ class TritonChatCompletion(BaseLLM):
         model_response,
         type_of_model,
     ) -> ModelResponse:
-        handler = AsyncHTTPHandler()
+        handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.TRITON, params={"timeout": 600.0}
+        )
         if stream:
             return self._ahandle_stream(  # type: ignore
                 handler, api_base, data_for_triton, model, logging_obj
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
index 39c63dbb3..f2fc599ed 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
@@ -1026,7 +1026,9 @@ async def make_call(
     logging_obj,
 ):
     if client is None:
-        client = AsyncHTTPHandler()  # Create a new client if none provided
+        client = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.VERTEX_AI,
+        )

     try:
         response = await client.post(api_base, headers=headers, data=data, stream=True)
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini_embeddings/batch_embed_content_handler.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini_embeddings/batch_embed_content_handler.py
index 314e129c2..8e2d1f39a 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini_embeddings/batch_embed_content_handler.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini_embeddings/batch_embed_content_handler.py
@@ -7,8 +7,13 @@ from typing import Any, List, Literal, Optional, Union

 import httpx

+import litellm
 from litellm import EmbeddingResponse
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)
 from litellm.types.llms.openai import EmbeddingInput
 from litellm.types.llms.vertex_ai import (
     VertexAIBatchEmbeddingsRequestBody,
@@ -150,7 +155,10 @@ class GoogleBatchEmbeddings(VertexLLM):
             else:
                 _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)

-            async_handler: AsyncHTTPHandler = AsyncHTTPHandler(**_params)  # type: ignore
+            async_handler: AsyncHTTPHandler = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders.VERTEX_AI,
+                params={"timeout": timeout},
+            )
         else:
             async_handler = client  # type: ignore
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/image_generation/image_generation_handler.py b/litellm/llms/vertex_ai_and_google_ai_studio/image_generation/image_generation_handler.py
index 1531464c8..6cb5771e6 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/image_generation/image_generation_handler.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/image_generation/image_generation_handler.py
@@ -5,7 +5,11 @@ import httpx
 from openai.types.image import Image

 import litellm
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)
 from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     VertexLLM,
 )
@@ -156,7 +160,10 @@ class VertexImageGeneration(VertexLLM):
             else:
                 _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)

-            self.async_handler = AsyncHTTPHandler(**_params)  # type: ignore
+            self.async_handler = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders.VERTEX_AI,
+                params={"timeout": timeout},
+            )
         else:
             self.async_handler = client  # type: ignore
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/multimodal_embeddings/embedding_handler.py b/litellm/llms/vertex_ai_and_google_ai_studio/multimodal_embeddings/embedding_handler.py
index d8af891b0..27b77fdd9 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/multimodal_embeddings/embedding_handler.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/multimodal_embeddings/embedding_handler.py
@@ -5,7 +5,11 @@ import httpx

 import litellm
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)
 from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     VertexAIError,
     VertexLLM,
@@ -172,7 +176,10 @@ class VertexMultimodalEmbedding(VertexLLM):
             if isinstance(timeout, float) or isinstance(timeout, int):
                 timeout = httpx.Timeout(timeout)
             _params["timeout"] = timeout
-            client = AsyncHTTPHandler(**_params)  # type: ignore
+            client = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders.VERTEX_AI,
+                params={"timeout": timeout},
+            )
         else:
             client = client  # type: ignore
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_non_gemini.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_non_gemini.py
index 80295ec40..829bf6528 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_non_gemini.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_non_gemini.py
@@ -14,6 +14,7 @@ from pydantic import BaseModel
 import litellm
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
 from litellm.llms.prompt_templates.factory import (
     convert_to_anthropic_image_obj,
     convert_to_gemini_tool_call_invoke,
@@ -93,11 +94,15 @@ def _get_client_cache_key(


 def _get_client_from_cache(client_cache_key: str):
-    return litellm.in_memory_llm_clients_cache.get(client_cache_key, None)
+    return litellm.in_memory_llm_clients_cache.get_cache(client_cache_key)


 def _set_client_in_cache(client_cache_key: str, vertex_llm_model: Any):
-    litellm.in_memory_llm_clients_cache[client_cache_key] = vertex_llm_model
+    litellm.in_memory_llm_clients_cache.set_cache(
+        key=client_cache_key,
+        value=vertex_llm_model,
+        ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
+    )


 def completion(  # noqa: PLR0915
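_set_client_in_cache above now passes ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS, so a cached Vertex client is served for at most an hour and then rebuilt, rather than living for the life of the process as it did in the old plain dict. The snippet below is a toy, self-contained model of what a TTL entry does; it is not litellm's InMemoryCache implementation, just the expiry idea, with a deliberately short TTL so it is observable:

```python
import time

_store: dict = {}  # key -> (value, absolute expiry time)


def set_cache(key, value, ttl):
    _store[key] = (value, time.monotonic() + ttl)


def get_cache(key):
    hit = _store.get(key)
    if hit is None:
        return None
    value, expires_at = hit
    if time.monotonic() > expires_at:
        _store.pop(key, None)  # expired: the caller falls through and rebuilds the client
        return None
    return value


set_cache("vertex-client-demo", "client-object", ttl=0.2)
assert get_cache("vertex-client-demo") == "client-object"  # fresh entry: cache hit
time.sleep(0.3)
assert get_cache("vertex-client-demo") is None  # expired: caller rebuilds and re-caches
```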
diff --git a/litellm/llms/watsonx/completion/handler.py b/litellm/llms/watsonx/completion/handler.py
index fda25ba0f..9618f6342 100644
--- a/litellm/llms/watsonx/completion/handler.py
+++ b/litellm/llms/watsonx/completion/handler.py
@@ -24,7 +24,10 @@ import httpx  # type: ignore
 import requests  # type: ignore

 import litellm
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    get_async_httpx_client,
+)
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.watsonx import WatsonXAIEndpoint
 from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason
@@ -710,10 +713,13 @@ class RequestManager:
         if stream:
             request_params["stream"] = stream
         try:
-            self.async_handler = AsyncHTTPHandler(
-                timeout=httpx.Timeout(
-                    timeout=request_params.pop("timeout", 600.0), connect=5.0
-                ),
+            self.async_handler = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders.WATSONX,
+                params={
+                    "timeout": httpx.Timeout(
+                        timeout=request_params.pop("timeout", 600.0), connect=5.0
+                    ),
+                },
             )
             if "json" in request_params:
                 request_params["data"] = json.dumps(request_params.pop("json", {}))
diff --git a/tests/code_coverage_tests/ensure_async_clients_test.py b/tests/code_coverage_tests/ensure_async_clients_test.py
new file mode 100644
index 000000000..a509e5509
--- /dev/null
+++ b/tests/code_coverage_tests/ensure_async_clients_test.py
@@ -0,0 +1,88 @@
+import ast
+import os
+
+ALLOWED_FILES = [
+    # local files
+    "../../litellm/__init__.py",
+    "../../litellm/llms/custom_httpx/http_handler.py",
+    # when running on ci/cd
+    "./litellm/__init__.py",
+    "./litellm/llms/custom_httpx/http_handler.py",
+]
+
+warning_msg = "this is a serious violation that can impact latency. Creating Async clients per request can add +500ms per request"
+
+
+def check_for_async_http_handler(file_path):
+    """
+    Checks if AsyncHttpHandler is instantiated in the given file.
+    Returns a list of line numbers where AsyncHttpHandler is used.
+    """
+    print("..checking file=", file_path)
+    if file_path in ALLOWED_FILES:
+        return []
+    with open(file_path, "r") as file:
+        try:
+            tree = ast.parse(file.read())
+        except SyntaxError:
+            print(f"Warning: Syntax error in file {file_path}")
+            return []
+
+    violations = []
+    target_names = [
+        "AsyncHttpHandler",
+        "AsyncHTTPHandler",
+        "AsyncClient",
+        "httpx.AsyncClient",
+    ]  # Add variations here
+    for node in ast.walk(tree):
+        if isinstance(node, ast.Call):
+            if isinstance(node.func, ast.Name) and node.func.id.lower() in [
+                name.lower() for name in target_names
+            ]:
+                raise ValueError(
+                    f"found violation in file {file_path} line: {node.lineno}. Please use `get_async_httpx_client` instead. {warning_msg}"
+                )
+    return violations
+
+
+def scan_directory_for_async_handler(base_dir):
+    """
+    Scans all Python files in the directory tree for AsyncHttpHandler usage.
+    Returns a dict of files and line numbers where violations were found.
+    """
+    violations = {}
+
+    for root, _, files in os.walk(base_dir):
+        for file in files:
+            if file.endswith(".py"):
+                file_path = os.path.join(root, file)
+                file_violations = check_for_async_http_handler(file_path)
+                if file_violations:
+                    violations[file_path] = file_violations
+
+    return violations
+
+
+def test_no_async_http_handler_usage():
+    """
+    Test to ensure AsyncHttpHandler is not used anywhere in the codebase.
+    """
+    base_dir = "./litellm"  # Adjust this path as needed
+
+    # base_dir = "../../litellm" # LOCAL TESTING
+    violations = scan_directory_for_async_handler(base_dir)
+
+    if violations:
+        violation_messages = []
+        for file_path, line_numbers in violations.items():
+            violation_messages.append(
+                f"Found AsyncHttpHandler in {file_path} at lines: {line_numbers}"
+            )
+        raise AssertionError(
+            "AsyncHttpHandler usage detected:\n" + "\n".join(violation_messages)
+        )
+
+
+if __name__ == "__main__":
+    test_no_async_http_handler_usage()
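The new test walks every Python file under ./litellm with ast and fails CI whenever a call expression's name matches AsyncHTTPHandler or AsyncClient, unless the file is allow-listed. Here is a self-contained illustration of that detection idea on in-memory source strings; the sample snippets and the find_forbidden_calls helper are made up for the demo:

```python
import ast

FORBIDDEN_CALL_NAMES = {"asynchttphandler", "asyncclient"}


def find_forbidden_calls(source: str) -> list:
    """Return the line numbers of calls such as AsyncHTTPHandler(...) in the given source text."""
    line_numbers = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
            if node.func.id.lower() in FORBIDDEN_CALL_NAMES:
                line_numbers.append(node.lineno)
    return line_numbers


bad_snippet = "client = AsyncHTTPHandler(timeout=600)\n"
good_snippet = "client = get_async_httpx_client(llm_provider=provider)\n"

assert find_forbidden_calls(bad_snippet) == [1]
assert find_forbidden_calls(good_snippet) == []
```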
diff --git a/tests/image_gen_tests/test_image_generation.py b/tests/image_gen_tests/test_image_generation.py
index 692a0e4e9..6605b3e3d 100644
--- a/tests/image_gen_tests/test_image_generation.py
+++ b/tests/image_gen_tests/test_image_generation.py
@@ -8,6 +8,7 @@ import traceback

 from dotenv import load_dotenv
 from openai.types.image import Image
+from litellm.caching import InMemoryCache

 logging.basicConfig(level=logging.DEBUG)
 load_dotenv()
@@ -107,7 +108,7 @@ class TestVertexImageGeneration(BaseImageGenTest):
         # comment this when running locally
         load_vertex_ai_credentials()

-        litellm.in_memory_llm_clients_cache = {}
+        litellm.in_memory_llm_clients_cache = InMemoryCache()
         return {
             "model": "vertex_ai/imagegeneration@006",
             "vertex_ai_project": "adroit-crow-413218",
@@ -118,13 +119,13 @@ class TestVertexImageGeneration(BaseImageGenTest):

 class TestBedrockSd3(BaseImageGenTest):
     def get_base_image_generation_call_args(self) -> dict:
-        litellm.in_memory_llm_clients_cache = {}
+        litellm.in_memory_llm_clients_cache = InMemoryCache()
         return {"model": "bedrock/stability.sd3-large-v1:0"}


 class TestBedrockSd1(BaseImageGenTest):
     def get_base_image_generation_call_args(self) -> dict:
-        litellm.in_memory_llm_clients_cache = {}
+        litellm.in_memory_llm_clients_cache = InMemoryCache()
         return {"model": "bedrock/stability.sd3-large-v1:0"}


@@ -181,7 +182,7 @@ def test_image_generation_azure_dall_e_3():
 @pytest.mark.asyncio
 async def test_aimage_generation_bedrock_with_optional_params():
     try:
-        litellm.in_memory_llm_clients_cache = {}
+        litellm.in_memory_llm_clients_cache = InMemoryCache()
         response = await litellm.aimage_generation(
             prompt="A cute baby sea otter",
             model="bedrock/stability.stable-diffusion-xl-v1",
diff --git a/tests/local_testing/test_alangfuse.py b/tests/local_testing/test_alangfuse.py
index 8c69f567b..ec0cb335e 100644
--- a/tests/local_testing/test_alangfuse.py
+++ b/tests/local_testing/test_alangfuse.py
@@ -12,6 +12,7 @@ sys.path.insert(0, os.path.abspath("../.."))

 import litellm
 from litellm import completion
+from litellm.caching import InMemoryCache

 litellm.num_retries = 3
 litellm.success_callback = ["langfuse"]
@@ -29,15 +30,20 @@ def langfuse_client():
         f"{os.environ['LANGFUSE_PUBLIC_KEY']}-{os.environ['LANGFUSE_SECRET_KEY']}"
     )
     # use a in memory langfuse client for testing, RAM util on ci/cd gets too high when we init many langfuse clients
-    if _langfuse_cache_key in litellm.in_memory_llm_clients_cache:
-        langfuse_client = litellm.in_memory_llm_clients_cache[_langfuse_cache_key]
+
+    _cached_client = litellm.in_memory_llm_clients_cache.get_cache(_langfuse_cache_key)
+    if _cached_client:
+        langfuse_client = _cached_client
     else:
         langfuse_client = langfuse.Langfuse(
             public_key=os.environ["LANGFUSE_PUBLIC_KEY"],
             secret_key=os.environ["LANGFUSE_SECRET_KEY"],
             host=None,
         )
-        litellm.in_memory_llm_clients_cache[_langfuse_cache_key] = langfuse_client
+        litellm.in_memory_llm_clients_cache.set_cache(
+            key=_langfuse_cache_key,
+            value=langfuse_client,
+        )
         print("NEW LANGFUSE CLIENT")
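The test updates above replace litellm.in_memory_llm_clients_cache = {} with a fresh InMemoryCache() so each test starts without clients cached by earlier tests. If a suite needs that reset everywhere, one possible approach (not part of this PR) is an autouse pytest fixture; the fixture name is illustrative:

```python
import pytest

import litellm
from litellm.caching import InMemoryCache


@pytest.fixture(autouse=True)
def fresh_llm_client_cache():
    # Give every test an empty client cache so clients built with one test's
    # credentials or mocks are never silently reused by the next test.
    litellm.in_memory_llm_clients_cache = InMemoryCache()
    yield
```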