diff --git a/litellm/llms/AzureOpenAI/azure.py b/litellm/llms/AzureOpenAI/azure.py index f6a1790b6..24303ef2f 100644 --- a/litellm/llms/AzureOpenAI/azure.py +++ b/litellm/llms/AzureOpenAI/azure.py @@ -1528,7 +1528,8 @@ class AzureChatCompletion(BaseLLM): prompt: Optional[str] = None, ) -> dict: client_session = ( - litellm.aclient_session or httpx.AsyncClient() + litellm.aclient_session + or get_async_httpx_client(llm_provider=litellm.LlmProviders.AZURE).client ) # handle dall-e-2 calls if "gateway.ai.cloudflare.com" in api_base: diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py index f1b78ea63..f5c4f694d 100644 --- a/litellm/llms/custom_httpx/http_handler.py +++ b/litellm/llms/custom_httpx/http_handler.py @@ -8,8 +8,7 @@ from httpx import USE_CLIENT_DEFAULT, AsyncHTTPTransport, HTTPTransport import litellm from litellm.caching import InMemoryCache - -from .types import httpxSpecialProvider +from litellm.types.llms.custom_http import * if TYPE_CHECKING: from litellm import LlmProviders diff --git a/litellm/llms/custom_httpx/types.py b/litellm/llms/custom_httpx/types.py deleted file mode 100644 index 8e6ad0eda..000000000 --- a/litellm/llms/custom_httpx/types.py +++ /dev/null @@ -1,11 +0,0 @@ -from enum import Enum - -import litellm - - -class httpxSpecialProvider(str, Enum): - LoggingCallback = "logging_callback" - GuardrailCallback = "guardrail_callback" - Caching = "caching" - Oauth2Check = "oauth2_check" - SecretManager = "secret_manager" diff --git a/litellm/llms/ollama.py b/litellm/llms/ollama.py index 896b93be5..e9dd2b53f 100644 --- a/litellm/llms/ollama.py +++ b/litellm/llms/ollama.py @@ -14,6 +14,7 @@ import requests # type: ignore import litellm from litellm import verbose_logger +from litellm.llms.custom_httpx.http_handler import get_async_httpx_client from litellm.secret_managers.main import get_secret_str from litellm.types.utils import ModelInfo, ProviderField, StreamingChoices @@ -456,7 +457,10 @@ def ollama_completion_stream(url, data, logging_obj): async def ollama_async_streaming(url, data, model_response, encoding, logging_obj): try: - client = httpx.AsyncClient() + _async_http_client = get_async_httpx_client( + llm_provider=litellm.LlmProviders.OLLAMA + ) + client = _async_http_client.client async with client.stream( url=f"{url}", json=data, method="POST", timeout=litellm.request_timeout ) as response: diff --git a/litellm/llms/ollama_chat.py b/litellm/llms/ollama_chat.py index 536f766e0..ce0df139d 100644 --- a/litellm/llms/ollama_chat.py +++ b/litellm/llms/ollama_chat.py @@ -13,6 +13,7 @@ from pydantic import BaseModel import litellm from litellm import verbose_logger +from litellm.llms.custom_httpx.http_handler import get_async_httpx_client from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction from litellm.types.llms.openai import ChatCompletionAssistantToolCall from litellm.types.utils import StreamingChoices @@ -445,7 +446,10 @@ async def ollama_async_streaming( url, api_key, data, model_response, encoding, logging_obj ): try: - client = httpx.AsyncClient() + _async_http_client = get_async_httpx_client( + llm_provider=litellm.LlmProviders.OLLAMA + ) + client = _async_http_client.client _request = { "url": f"{url}", "json": data, diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index f60fd0166..0fd174440 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -22,6 +22,7 @@ import litellm from litellm._logging import verbose_proxy_logger from litellm.integrations.custom_logger import CustomLogger from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.llms.custom_httpx.http_handler import get_async_httpx_client from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( ModelResponseIterator, ) @@ -35,6 +36,7 @@ from litellm.proxy._types import ( ) from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.secret_managers.main import get_secret_str +from litellm.types.llms.custom_http import httpxSpecialProvider from .streaming_handler import PassThroughStreamingHandler from .success_handler import PassThroughEndpointLogging @@ -363,8 +365,11 @@ async def pass_through_request( # noqa: PLR0915 data=_parsed_body, call_type="pass_through_endpoint", ) - - async_client = httpx.AsyncClient(timeout=600) + async_client_obj = get_async_httpx_client( + llm_provider=httpxSpecialProvider.PassThroughEndpoint, + params={"timeout": 600}, + ) + async_client = async_client_obj.client litellm_call_id = str(uuid.uuid4()) diff --git a/litellm/secret_managers/aws_secret_manager_v2.py b/litellm/secret_managers/aws_secret_manager_v2.py index 69add6f23..32653f57d 100644 --- a/litellm/secret_managers/aws_secret_manager_v2.py +++ b/litellm/secret_managers/aws_secret_manager_v2.py @@ -31,8 +31,8 @@ from litellm.llms.custom_httpx.http_handler import ( _get_httpx_client, get_async_httpx_client, ) -from litellm.llms.custom_httpx.types import httpxSpecialProvider from litellm.proxy._types import KeyManagementSystem +from litellm.types.llms.custom_http import httpxSpecialProvider class AWSSecretsManagerV2(BaseAWSLLM): diff --git a/litellm/types/llms/custom_http.py b/litellm/types/llms/custom_http.py new file mode 100644 index 000000000..f43daff2a --- /dev/null +++ b/litellm/types/llms/custom_http.py @@ -0,0 +1,20 @@ +from enum import Enum + +import litellm + + +class httpxSpecialProvider(str, Enum): + """ + Httpx Clients can be created for these litellm internal providers + + Example: + - langsmith logging would need a custom async httpx client + - pass through endpoint would need a custom async httpx client + """ + + LoggingCallback = "logging_callback" + GuardrailCallback = "guardrail_callback" + Caching = "caching" + Oauth2Check = "oauth2_check" + SecretManager = "secret_manager" + PassThroughEndpoint = "pass_through_endpoint" diff --git a/tests/code_coverage_tests/ensure_async_clients_test.py b/tests/code_coverage_tests/ensure_async_clients_test.py index a509e5509..0565de9b3 100644 --- a/tests/code_coverage_tests/ensure_async_clients_test.py +++ b/tests/code_coverage_tests/ensure_async_clients_test.py @@ -5,9 +5,19 @@ ALLOWED_FILES = [ # local files "../../litellm/__init__.py", "../../litellm/llms/custom_httpx/http_handler.py", + "../../litellm/router_utils/client_initalization_utils.py", + "../../litellm/llms/custom_httpx/http_handler.py", + "../../litellm/llms/huggingface_restapi.py", + "../../litellm/llms/base.py", + "../../litellm/llms/custom_httpx/httpx_handler.py", # when running on ci/cd "./litellm/__init__.py", "./litellm/llms/custom_httpx/http_handler.py", + "./litellm/router_utils/client_initalization_utils.py", + "./litellm/llms/custom_httpx/http_handler.py", + "./litellm/llms/huggingface_restapi.py", + "./litellm/llms/base.py", + "./litellm/llms/custom_httpx/httpx_handler.py", ] warning_msg = "this is a serious violation that can impact latency. Creating Async clients per request can add +500ms per request" @@ -43,6 +53,19 @@ def check_for_async_http_handler(file_path): raise ValueError( f"found violation in file {file_path} line: {node.lineno}. Please use `get_async_httpx_client` instead. {warning_msg}" ) + # Check for attribute calls like httpx.AsyncClient() + elif isinstance(node.func, ast.Attribute): + full_name = "" + current = node.func + while isinstance(current, ast.Attribute): + full_name = "." + current.attr + full_name + current = current.value + if isinstance(current, ast.Name): + full_name = current.id + full_name + if full_name.lower() in [name.lower() for name in target_names]: + raise ValueError( + f"found violation in file {file_path} line: {node.lineno}. Please use `get_async_httpx_client` instead. {warning_msg}" + ) return violations