Merge pull request #5620 from BerriAI/litellm_use_helper_to_get_httpx_clients

[Feat-Perf] Use common helper to get async httpx clients for all providers
This commit is contained in:
Ishaan Jaff 2024-09-10 15:03:11 -07:00 committed by GitHub
commit 2d4be4cf1f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
27 changed files with 221 additions and 117 deletions

View file

@ -107,9 +107,9 @@ export LITELLM_SALT_KEY="sk-1234"
|--------------|-------|
| Avg latency | `50ms` |
| Median latency | `51ms` |
| `/chat/completions` Requests/second | `35` |
| `/chat/completions` Requests/minute | `2100` |
| `/chat/completions` Requests/hour | `126K` |
| `/chat/completions` Requests/second | `100` |
| `/chat/completions` Requests/minute | `6000` |
| `/chat/completions` Requests/hour | `360K` |
### Verifying Debugging logs are off

View file

@ -26,7 +26,10 @@ from typing import List
from datetime import datetime
import aiohttp, asyncio
from litellm._logging import verbose_proxy_logger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
import httpx
import json
from litellm.types.guardrails import GuardrailEventHooks
@ -40,8 +43,8 @@ class AporiaGuardrail(CustomGuardrail):
def __init__(
self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
):
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
self.async_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.GuardrailCallback
)
self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"]
self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"]

View file

@ -24,6 +24,7 @@ from litellm.proxy._types import (
)
import httpx
import dotenv
from enum import Enum
litellm_mode = os.getenv("LITELLM_MODE", "DEV") # "PRODUCTION", "DEV"
if litellm_mode == "DEV":
@ -678,62 +679,66 @@ model_list = (
+ gemini_models
)
provider_list: List = [
"openai",
"custom_openai",
"text-completion-openai",
"cohere",
"cohere_chat",
"clarifai",
"anthropic",
"replicate",
"huggingface",
"together_ai",
"openrouter",
"vertex_ai",
"vertex_ai_beta",
"palm",
"gemini",
"ai21",
"baseten",
"azure",
"azure_text",
"azure_ai",
"sagemaker",
"sagemaker_chat",
"bedrock",
"vllm",
"nlp_cloud",
"petals",
"oobabooga",
"ollama",
"ollama_chat",
"deepinfra",
"perplexity",
"anyscale",
"mistral",
"groq",
"nvidia_nim",
"cerebras",
"ai21_chat",
"volcengine",
"codestral",
"text-completion-codestral",
"deepseek",
"maritalk",
"voyage",
"cloudflare",
"xinference",
"fireworks_ai",
"friendliai",
"watsonx",
"triton",
"predibase",
"databricks",
"empower",
"github",
"custom", # custom apis
]
class LlmProviders(str, Enum):
OPENAI = "openai"
CUSTOM_OPENAI = "custom_openai"
TEXT_COMPLETION_OPENAI = "text-completion-openai"
COHERE = "cohere"
COHERE_CHAT = "cohere_chat"
CLARIFAI = "clarifai"
ANTHROPIC = "anthropic"
REPLICATE = "replicate"
HUGGINGFACE = "huggingface"
TOGETHER_AI = "together_ai"
OPENROUTER = "openrouter"
VERTEX_AI = "vertex_ai"
VERTEX_AI_BETA = "vertex_ai_beta"
PALM = "palm"
GEMINI = "gemini"
AI21 = "ai21"
BASETEN = "baseten"
AZURE = "azure"
AZURE_TEXT = "azure_text"
AZURE_AI = "azure_ai"
SAGEMAKER = "sagemaker"
SAGEMAKER_CHAT = "sagemaker_chat"
BEDROCK = "bedrock"
VLLM = "vllm"
NLP_CLOUD = "nlp_cloud"
PETALS = "petals"
OOBABOOGA = "oobabooga"
OLLAMA = "ollama"
OLLAMA_CHAT = "ollama_chat"
DEEPINFRA = "deepinfra"
PERPLEXITY = "perplexity"
ANYSCALE = "anyscale"
MISTRAL = "mistral"
GROQ = "groq"
NVIDIA_NIM = "nvidia_nim"
CEREBRAS = "cerebras"
AI21_CHAT = "ai21_chat"
VOLCENGINE = "volcengine"
CODESTRAL = "codestral"
TEXT_COMPLETION_CODESTRAL = "text-completion-codestral"
DEEPSEEK = "deepseek"
MARITALK = "maritalk"
VOYAGE = "voyage"
CLOUDFLARE = "cloudflare"
XINFERENCE = "xinference"
FIREWORKS_AI = "fireworks_ai"
FRIENDLIAI = "friendliai"
WATSONX = "watsonx"
TRITON = "triton"
PREDIBASE = "predibase"
DATABRICKS = "databricks"
EMPOWER = "empower"
GITHUB = "github"
CUSTOM = "custom"
provider_list: List[Union[LlmProviders, str]] = list(LlmProviders)
models_by_provider: dict = {
"openai": open_ai_chat_completion_models + open_ai_text_completion_models,

View file

@ -1242,8 +1242,9 @@ class QdrantSemanticCache(BaseCache):
import os
from litellm.llms.custom_httpx.http_handler import (
_get_async_httpx_client,
_get_httpx_client,
get_async_httpx_client,
httpxSpecialProvider,
)
if collection_name is None:
@ -1290,7 +1291,9 @@ class QdrantSemanticCache(BaseCache):
self.headers = headers
self.sync_client = _get_httpx_client()
self.async_client = _get_async_httpx_client()
self.async_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.Caching
)
if quantization_config is None:
print_verbose(

View file

@ -17,10 +17,17 @@ from pydantic import BaseModel
import litellm
from litellm import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.utils import get_formatted_prompt
global_braintrust_http_handler = AsyncHTTPHandler()
global_braintrust_http_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback
)
global_braintrust_sync_http_handler = HTTPHandler()
API_BASE = "https://api.braintrustdata.com/v1"

View file

@ -8,7 +8,10 @@ from pydantic import BaseModel, Field
import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
# from here: https://docs.rungalileo.io/galileo/gen-ai-studio-products/galileo-observe/how-to/logging-data-via-restful-apis#structuring-your-records
@ -39,8 +42,8 @@ class GalileoObserve(CustomLogger):
self.base_url = os.getenv("GALILEO_BASE_URL", None)
self.project_id = os.getenv("GALILEO_PROJECT_ID", None)
self.headers = None
self.async_httpx_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
self.async_httpx_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback
)
pass

View file

@ -13,15 +13,18 @@ from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.logging_utils import (
convert_litellm_response_object_to_dict,
)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
class GCSBucketBase(CustomLogger):
def __init__(self, bucket_name: Optional[str] = None) -> None:
from litellm.proxy.proxy_server import premium_user
self.async_httpx_client = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
self.async_httpx_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback
)
self.path_service_account_json = os.getenv("GCS_PATH_SERVICE_ACCOUNT", None)
self.BUCKET_NAME = bucket_name or os.getenv("GCS_BUCKET_NAME", None)

View file

@ -13,7 +13,11 @@ import httpx
import litellm
from litellm import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.custom_httpx.http_handler import (
HTTPHandler,
get_async_httpx_client,
httpxSpecialProvider,
)
def get_utc_datetime():
@ -30,7 +34,9 @@ class LagoLogger(CustomLogger):
def __init__(self) -> None:
super().__init__()
self.validate_environment()
self.async_http_handler = AsyncHTTPHandler()
self.async_http_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback
)
self.sync_http_handler = HTTPHandler()
def validate_environment(self):

View file

@ -16,7 +16,11 @@ from pydantic import BaseModel # type: ignore
import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
httpxSpecialProvider,
)
class LangsmithInputs(BaseModel):
@ -61,8 +65,8 @@ class LangsmithLogger(CustomLogger):
self.langsmith_base_url = os.getenv(
"LANGSMITH_BASE_URL", "https://api.smith.langchain.com"
)
self.async_httpx_client = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
self.async_httpx_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback
)
def _prepare_log_data(self, kwargs, response_obj, start_time, end_time):

View file

@ -12,7 +12,12 @@ import httpx
import litellm
from litellm import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
get_async_httpx_client,
httpxSpecialProvider,
)
def get_utc_datetime():
@ -29,7 +34,9 @@ class OpenMeterLogger(CustomLogger):
def __init__(self) -> None:
super().__init__()
self.validate_environment()
self.async_http_handler = AsyncHTTPHandler()
self.async_http_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback
)
self.sync_http_handler = HTTPHandler()
def validate_environment(self):

View file

@ -8,11 +8,16 @@ import time
import litellm
from litellm._logging import verbose_logger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
PROMETHEUS_URL = litellm.get_secret("PROMETHEUS_URL")
PROMETHEUS_SELECTED_INSTANCE = litellm.get_secret("PROMETHEUS_SELECTED_INSTANCE")
async_http_handler = AsyncHTTPHandler()
async_http_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback
)
async def get_metric_from_prometheus(

View file

@ -25,7 +25,11 @@ from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.proxy._types import (
AlertType,
CallInfo,
@ -187,7 +191,9 @@ class SlackAlerting(CustomLogger):
self.alerting = alerting
self.alert_types = alert_types
self.internal_usage_cache = internal_usage_cache or DualCache()
self.async_http_handler = AsyncHTTPHandler()
self.async_http_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback
)
self.alert_to_webhook_url = alert_to_webhook_url
self.is_running = False
self.alerting_args = SlackAlertingArgs(**alerting_args)

View file

@ -25,8 +25,8 @@ from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_async_httpx_client,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.llms.anthropic import (
AnthopicMessagesAssistantMessageParam,
@ -918,7 +918,9 @@ class AnthropicChatCompletion(BaseLLM):
headers={},
client=None,
) -> Union[ModelResponse, CustomStreamWrapper]:
async_handler = _get_async_httpx_client()
async_handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.ANTHROPIC
)
try:
response = await async_handler.post(

View file

@ -35,8 +35,8 @@ from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_async_httpx_client,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.llms.bedrock import *
from litellm.types.llms.openai import (
@ -209,7 +209,9 @@ async def make_call(
):
try:
if client is None:
client = _get_async_httpx_client() # Create a new client if none provided
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.BEDROCK
) # Create a new client if none provided
response = await client.post(
api_base,
@ -1041,7 +1043,7 @@ class BedrockLLM(BaseAWSLLM):
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
client = _get_async_httpx_client(_params) # type: ignore
client = get_async_httpx_client(params=_params, llm_provider=litellm.LlmProviders.BEDROCK) # type: ignore
else:
client = client # type: ignore
@ -1498,7 +1500,9 @@ class BedrockConverseLLM(BaseAWSLLM):
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
client = _get_async_httpx_client(_params) # type: ignore
client = get_async_httpx_client(
params=_params, llm_provider=litellm.LlmProviders.BEDROCK
)
else:
client = client # type: ignore

View file

@ -15,8 +15,8 @@ from litellm.llms.cohere.embed import embedding as cohere_embedding
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_async_httpx_client,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.secret_managers.main import get_secret
from litellm.types.llms.bedrock import AmazonEmbeddingRequest, CohereEmbeddingRequest
@ -130,7 +130,9 @@ class BedrockEmbedding(BaseAWSLLM):
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
client = _get_async_httpx_client(_params) # type: ignore
client = get_async_httpx_client(
params=_params, llm_provider=litellm.LlmProviders.BEDROCK
)
else:
client = client

View file

@ -9,10 +9,11 @@ from typing import Any, Dict, List, Optional, Union
import httpx
from pydantic import BaseModel
import litellm
from litellm.llms.base import BaseLLM
from litellm.llms.custom_httpx.http_handler import (
_get_async_httpx_client,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.rerank_api.types import RerankRequest, RerankResponse
@ -65,7 +66,7 @@ class CohereRerank(BaseLLM):
api_key: str,
api_base: str,
) -> RerankResponse:
client = _get_async_httpx_client()
client = get_async_httpx_client(llm_provider=litellm.LlmProviders.COHERE)
response = await client.post(
api_base,

View file

@ -1,12 +1,19 @@
import asyncio
import os
import traceback
from typing import Any, Mapping, Optional, Union
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
import httpx
import litellm
from .types import httpxSpecialProvider
if TYPE_CHECKING:
from litellm import LlmProviders
else:
LlmProviders = Any
try:
from litellm._version import version
except:
@ -378,7 +385,10 @@ class HTTPHandler:
pass
def _get_async_httpx_client(params: Optional[dict] = None) -> AsyncHTTPHandler:
def get_async_httpx_client(
llm_provider: Union[LlmProviders, httpxSpecialProvider],
params: Optional[dict] = None,
) -> AsyncHTTPHandler:
"""
Retrieves the async HTTP client from the cache
If not present, creates a new client
@ -393,7 +403,7 @@ def _get_async_httpx_client(params: Optional[dict] = None) -> AsyncHTTPHandler:
except Exception:
pass
_cache_key_name = "async_httpx_client" + _params_key_name
_cache_key_name = "async_httpx_client" + _params_key_name + llm_provider
if _cache_key_name in litellm.in_memory_llm_clients_cache:
return litellm.in_memory_llm_clients_cache[_cache_key_name]

View file

@ -0,0 +1,10 @@
from enum import Enum
import litellm
class httpxSpecialProvider(str, Enum):
LoggingCallback = "logging_callback"
GuardrailCallback = "guardrail_callback"
Caching = "caching"
Oauth2Check = "oauth2_check"

View file

@ -19,8 +19,8 @@ from litellm.litellm_core_utils.asyncify import asyncify
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_async_httpx_client,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.llms.openai import (
ChatCompletionToolCallChunk,
@ -565,8 +565,8 @@ class SagemakerLLM(BaseAWSLLM):
):
try:
if client is None:
client = (
_get_async_httpx_client()
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.SAGEMAKER
) # Create a new client if none provided
response = await client.post(
api_base,
@ -673,7 +673,9 @@ class SagemakerLLM(BaseAWSLLM):
model_id: Optional[str],
):
timeout = 300.0
async_handler = _get_async_httpx_client()
async_handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.SAGEMAKER
)
async_transform_prompt = asyncify(self._transform_prompt)

View file

@ -9,10 +9,11 @@ from typing import Any, Dict, List, Optional, Union
import httpx
from pydantic import BaseModel
import litellm
from litellm.llms.base import BaseLLM
from litellm.llms.custom_httpx.http_handler import (
_get_async_httpx_client,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.rerank_api.types import RerankRequest, RerankResponse
@ -77,7 +78,9 @@ class TogetherAIRerank(BaseLLM):
request_data_dict: Dict[str, Any],
api_key: str,
) -> RerankResponse:
client = _get_async_httpx_client() # Use async client
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.TOGETHER_AI
) # Use async client
response = await client.post(
"https://api.together.xyz/v1/rerank",

View file

@ -22,7 +22,7 @@ from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_async_httpx_client,
get_async_httpx_client,
)
from litellm.llms.prompt_templates.factory import (
convert_url_to_base64,
@ -1294,7 +1294,9 @@ class VertexLLM(BaseLLM):
if timeout:
_async_client_params["timeout"] = timeout
if client is None or not isinstance(client, AsyncHTTPHandler):
client = _get_async_httpx_client(params=_async_client_params)
client = get_async_httpx_client(
params=_async_client_params, llm_provider=litellm.LlmProviders.VERTEX_AI
)
else:
client = client # type: ignore
## LOGGING

View file

@ -4,13 +4,14 @@ from typing import Any, Coroutine, Literal, Optional, TypedDict, Union
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.llms.base import BaseLLM
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_async_httpx_client,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.llms.OpenAI.openai import HttpxBinaryResponseContent
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
@ -178,7 +179,9 @@ class VertexTextToSpeechAPI(VertexLLM):
) -> HttpxBinaryResponseContent:
import base64
async_handler = _get_async_httpx_client()
async_handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.VERTEX_AI
)
response = await async_handler.post(
url=url,

View file

@ -20,7 +20,10 @@ async def check_oauth2_token(token: str) -> UserAPIKeyAuth:
import httpx
from litellm._logging import verbose_proxy_logger
from litellm.llms.custom_httpx.http_handler import _get_async_httpx_client
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.proxy._types import CommonProxyErrors
from litellm.proxy.proxy_server import premium_user
@ -40,7 +43,7 @@ async def check_oauth2_token(token: str) -> UserAPIKeyAuth:
if not token_info_endpoint:
raise ValueError("OAUTH_TOKEN_INFO_ENDPOINT environment variable is not set")
client = _get_async_httpx_client()
client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
try:

View file

@ -30,7 +30,11 @@ from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.litellm_core_utils.logging_utils import (
convert_litellm_response_object_to_str,
)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
from litellm.types.guardrails import GuardrailEventHooks
@ -44,8 +48,8 @@ class AporiaGuardrail(CustomGuardrail):
def __init__(
self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
):
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
self.async_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.GuardrailCallback
)
self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"]
self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"]

View file

@ -33,7 +33,8 @@ from litellm.litellm_core_utils.logging_utils import (
from litellm.llms.base_aws_llm import BaseAWSLLM
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
_get_async_httpx_client,
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
@ -55,7 +56,9 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
guardrailVersion: Optional[str] = None,
**kwargs,
):
self.async_handler = _get_async_httpx_client()
self.async_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.GuardrailCallback
)
self.guardrailIdentifier = guardrailIdentifier
self.guardrailVersion = guardrailVersion

View file

@ -21,7 +21,11 @@ from fastapi import HTTPException
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
from litellm.secret_managers.main import get_secret
@ -50,8 +54,8 @@ class lakeraAI_Moderation(CustomGuardrail):
api_key: Optional[str] = None,
**kwargs,
):
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
self.async_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.GuardrailCallback
)
self.lakera_api_key = api_key or os.environ["LAKERA_API_KEY"]
self.moderation_check = moderation_check

View file

@ -22,7 +22,6 @@ import litellm # noqa: E401
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.llms.custom_httpx.http_handler import _get_async_httpx_client
from litellm.proxy._types import UserAPIKeyAuth
from litellm.utils import (
EmbeddingResponse,