Merge pull request #5620 from BerriAI/litellm_use_helper_to_get_httpx_clients

[Feat-Perf] Use common helper to get async httpx clients for all providers
Ishaan Jaff 2024-09-10 15:03:11 -07:00 committed by GitHub
commit 2d4be4cf1f
27 changed files with 221 additions and 117 deletions


@@ -107,9 +107,9 @@ export LITELLM_SALT_KEY="sk-1234"
 |--------------|-------|
 | Avg latency | `50ms` |
 | Median latency | `51ms` |
-| `/chat/completions` Requests/second | `35` |
-| `/chat/completions` Requests/minute | `2100` |
-| `/chat/completions` Requests/hour | `126K` |
+| `/chat/completions` Requests/second | `100` |
+| `/chat/completions` Requests/minute | `6000` |
+| `/chat/completions` Requests/hour | `360K` |
 
 ### Verifying Debugging logs are off


@@ -26,7 +26,10 @@ from typing import List
 from datetime import datetime
 import aiohttp, asyncio
 from litellm._logging import verbose_proxy_logger
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
 import httpx
 import json
 from litellm.types.guardrails import GuardrailEventHooks
@@ -40,8 +43,8 @@ class AporiaGuardrail(CustomGuardrail):
     def __init__(
         self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
     ):
-        self.async_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
-        )
+        self.async_handler = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.GuardrailCallback
+        )
         self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"]
         self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"]


@@ -24,6 +24,7 @@ from litellm.proxy._types import (
 )
 import httpx
 import dotenv
+from enum import Enum
 
 litellm_mode = os.getenv("LITELLM_MODE", "DEV")  # "PRODUCTION", "DEV"
 if litellm_mode == "DEV":
@@ -678,62 +679,66 @@ model_list = (
     + gemini_models
 )
 
-provider_list: List = [
-    "openai",
-    "custom_openai",
-    "text-completion-openai",
-    "cohere",
-    "cohere_chat",
-    "clarifai",
-    "anthropic",
-    "replicate",
-    "huggingface",
-    "together_ai",
-    "openrouter",
-    "vertex_ai",
-    "vertex_ai_beta",
-    "palm",
-    "gemini",
-    "ai21",
-    "baseten",
-    "azure",
-    "azure_text",
-    "azure_ai",
-    "sagemaker",
-    "sagemaker_chat",
-    "bedrock",
-    "vllm",
-    "nlp_cloud",
-    "petals",
-    "oobabooga",
-    "ollama",
-    "ollama_chat",
-    "deepinfra",
-    "perplexity",
-    "anyscale",
-    "mistral",
-    "groq",
-    "nvidia_nim",
-    "cerebras",
-    "ai21_chat",
-    "volcengine",
-    "codestral",
-    "text-completion-codestral",
-    "deepseek",
-    "maritalk",
-    "voyage",
-    "cloudflare",
-    "xinference",
-    "fireworks_ai",
-    "friendliai",
-    "watsonx",
-    "triton",
-    "predibase",
-    "databricks",
-    "empower",
-    "github",
-    "custom",  # custom apis
-]
+
+class LlmProviders(str, Enum):
+    OPENAI = "openai"
+    CUSTOM_OPENAI = "custom_openai"
+    TEXT_COMPLETION_OPENAI = "text-completion-openai"
+    COHERE = "cohere"
+    COHERE_CHAT = "cohere_chat"
+    CLARIFAI = "clarifai"
+    ANTHROPIC = "anthropic"
+    REPLICATE = "replicate"
+    HUGGINGFACE = "huggingface"
+    TOGETHER_AI = "together_ai"
+    OPENROUTER = "openrouter"
+    VERTEX_AI = "vertex_ai"
+    VERTEX_AI_BETA = "vertex_ai_beta"
+    PALM = "palm"
+    GEMINI = "gemini"
+    AI21 = "ai21"
+    BASETEN = "baseten"
+    AZURE = "azure"
+    AZURE_TEXT = "azure_text"
+    AZURE_AI = "azure_ai"
+    SAGEMAKER = "sagemaker"
+    SAGEMAKER_CHAT = "sagemaker_chat"
+    BEDROCK = "bedrock"
+    VLLM = "vllm"
+    NLP_CLOUD = "nlp_cloud"
+    PETALS = "petals"
+    OOBABOOGA = "oobabooga"
+    OLLAMA = "ollama"
+    OLLAMA_CHAT = "ollama_chat"
+    DEEPINFRA = "deepinfra"
+    PERPLEXITY = "perplexity"
+    ANYSCALE = "anyscale"
+    MISTRAL = "mistral"
+    GROQ = "groq"
+    NVIDIA_NIM = "nvidia_nim"
+    CEREBRAS = "cerebras"
+    AI21_CHAT = "ai21_chat"
+    VOLCENGINE = "volcengine"
+    CODESTRAL = "codestral"
+    TEXT_COMPLETION_CODESTRAL = "text-completion-codestral"
+    DEEPSEEK = "deepseek"
+    MARITALK = "maritalk"
+    VOYAGE = "voyage"
+    CLOUDFLARE = "cloudflare"
+    XINFERENCE = "xinference"
+    FIREWORKS_AI = "fireworks_ai"
+    FRIENDLIAI = "friendliai"
+    WATSONX = "watsonx"
+    TRITON = "triton"
+    PREDIBASE = "predibase"
+    DATABRICKS = "databricks"
+    EMPOWER = "empower"
+    GITHUB = "github"
+    CUSTOM = "custom"  # custom apis
+
+
+provider_list: List[Union[LlmProviders, str]] = list(LlmProviders)
 
 models_by_provider: dict = {
     "openai": open_ai_chat_completion_models + open_ai_text_completion_models,


@@ -1242,8 +1242,9 @@ class QdrantSemanticCache(BaseCache):
         import os
         from litellm.llms.custom_httpx.http_handler import (
-            _get_async_httpx_client,
             _get_httpx_client,
+            get_async_httpx_client,
+            httpxSpecialProvider,
         )
 
         if collection_name is None:
@@ -1290,7 +1291,9 @@ class QdrantSemanticCache(BaseCache):
         self.headers = headers
 
         self.sync_client = _get_httpx_client()
-        self.async_client = _get_async_httpx_client()
+        self.async_client = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.Caching
+        )
 
         if quantization_config is None:
             print_verbose(


@@ -17,10 +17,17 @@ from pydantic import BaseModel
 import litellm
 from litellm import verbose_logger
 from litellm.integrations.custom_logger import CustomLogger
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
 from litellm.utils import get_formatted_prompt
 
-global_braintrust_http_handler = AsyncHTTPHandler()
+global_braintrust_http_handler = get_async_httpx_client(
+    llm_provider=httpxSpecialProvider.LoggingCallback
+)
 global_braintrust_sync_http_handler = HTTPHandler()
 
 API_BASE = "https://api.braintrustdata.com/v1"


@@ -8,7 +8,10 @@ from pydantic import BaseModel, Field
 import litellm
 from litellm._logging import verbose_logger
 from litellm.integrations.custom_logger import CustomLogger
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
 
 
 # from here: https://docs.rungalileo.io/galileo/gen-ai-studio-products/galileo-observe/how-to/logging-data-via-restful-apis#structuring-your-records
@@ -39,8 +42,8 @@ class GalileoObserve(CustomLogger):
         self.base_url = os.getenv("GALILEO_BASE_URL", None)
         self.project_id = os.getenv("GALILEO_PROJECT_ID", None)
         self.headers = None
-        self.async_httpx_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
-        )
+        self.async_httpx_handler = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.LoggingCallback
+        )
         pass


@@ -13,15 +13,18 @@ from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.logging_utils import (
     convert_litellm_response_object_to_dict,
 )
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
 
 
 class GCSBucketBase(CustomLogger):
     def __init__(self, bucket_name: Optional[str] = None) -> None:
         from litellm.proxy.proxy_server import premium_user
 
-        self.async_httpx_client = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
-        )
+        self.async_httpx_client = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.LoggingCallback
+        )
         self.path_service_account_json = os.getenv("GCS_PATH_SERVICE_ACCOUNT", None)
         self.BUCKET_NAME = bucket_name or os.getenv("GCS_BUCKET_NAME", None)


@@ -13,7 +13,11 @@ import httpx
 import litellm
 from litellm import verbose_logger
 from litellm.integrations.custom_logger import CustomLogger
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    HTTPHandler,
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
 
 
 def get_utc_datetime():
@@ -30,7 +34,9 @@ class LagoLogger(CustomLogger):
     def __init__(self) -> None:
         super().__init__()
         self.validate_environment()
-        self.async_http_handler = AsyncHTTPHandler()
+        self.async_http_handler = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.LoggingCallback
+        )
         self.sync_http_handler = HTTPHandler()
 
     def validate_environment(self):


@@ -16,7 +16,11 @@ from pydantic import BaseModel  # type: ignore
 import litellm
 from litellm._logging import verbose_logger
 from litellm.integrations.custom_logger import CustomLogger
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
 
 
 class LangsmithInputs(BaseModel):
@@ -61,8 +65,8 @@ class LangsmithLogger(CustomLogger):
         self.langsmith_base_url = os.getenv(
             "LANGSMITH_BASE_URL", "https://api.smith.langchain.com"
         )
-        self.async_httpx_client = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
-        )
+        self.async_httpx_client = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.LoggingCallback
+        )
 
     def _prepare_log_data(self, kwargs, response_obj, start_time, end_time):


@@ -12,7 +12,12 @@ import httpx
 import litellm
 from litellm import verbose_logger
 from litellm.integrations.custom_logger import CustomLogger
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
 
 
 def get_utc_datetime():
@@ -29,7 +34,9 @@ class OpenMeterLogger(CustomLogger):
     def __init__(self) -> None:
        super().__init__()
        self.validate_environment()
-       self.async_http_handler = AsyncHTTPHandler()
+       self.async_http_handler = get_async_httpx_client(
+           llm_provider=httpxSpecialProvider.LoggingCallback
+       )
        self.sync_http_handler = HTTPHandler()
 
     def validate_environment(self):


@@ -8,11 +8,16 @@ import time
 import litellm
 from litellm._logging import verbose_logger
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
 
 PROMETHEUS_URL = litellm.get_secret("PROMETHEUS_URL")
 PROMETHEUS_SELECTED_INSTANCE = litellm.get_secret("PROMETHEUS_SELECTED_INSTANCE")
 
-async_http_handler = AsyncHTTPHandler()
+async_http_handler = get_async_httpx_client(
+    llm_provider=httpxSpecialProvider.LoggingCallback
+)
 
 
 async def get_metric_from_prometheus(


@@ -25,7 +25,11 @@ from litellm._logging import verbose_logger, verbose_proxy_logger
 from litellm.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.litellm_logging import Logging
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
 from litellm.proxy._types import (
     AlertType,
     CallInfo,
@@ -187,7 +191,9 @@ class SlackAlerting(CustomLogger):
         self.alerting = alerting
         self.alert_types = alert_types
         self.internal_usage_cache = internal_usage_cache or DualCache()
-        self.async_http_handler = AsyncHTTPHandler()
+        self.async_http_handler = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.LoggingCallback
+        )
         self.alert_to_webhook_url = alert_to_webhook_url
         self.is_running = False
         self.alerting_args = SlackAlertingArgs(**alerting_args)


@@ -25,8 +25,8 @@ from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
-    _get_async_httpx_client,
     _get_httpx_client,
+    get_async_httpx_client,
 )
 from litellm.types.llms.anthropic import (
     AnthopicMessagesAssistantMessageParam,
@@ -918,7 +918,9 @@ class AnthropicChatCompletion(BaseLLM):
         headers={},
         client=None,
     ) -> Union[ModelResponse, CustomStreamWrapper]:
-        async_handler = _get_async_httpx_client()
+        async_handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.ANTHROPIC
+        )
 
         try:
             response = await async_handler.post(


@@ -35,8 +35,8 @@ from litellm.litellm_core_utils.litellm_logging import Logging
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
-    _get_async_httpx_client,
     _get_httpx_client,
+    get_async_httpx_client,
 )
 from litellm.types.llms.bedrock import *
 from litellm.types.llms.openai import (
@@ -209,7 +209,9 @@ async def make_call(
 ):
     try:
         if client is None:
-            client = _get_async_httpx_client()  # Create a new client if none provided
+            client = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders.BEDROCK
+            )  # Create a new client if none provided
 
         response = await client.post(
             api_base,
@@ -1041,7 +1043,7 @@ class BedrockLLM(BaseAWSLLM):
             if isinstance(timeout, float) or isinstance(timeout, int):
                 timeout = httpx.Timeout(timeout)
             _params["timeout"] = timeout
-            client = _get_async_httpx_client(_params)  # type: ignore
+            client = get_async_httpx_client(params=_params, llm_provider=litellm.LlmProviders.BEDROCK)  # type: ignore
         else:
             client = client  # type: ignore
@@ -1498,7 +1500,9 @@ class BedrockConverseLLM(BaseAWSLLM):
             if isinstance(timeout, float) or isinstance(timeout, int):
                 timeout = httpx.Timeout(timeout)
             _params["timeout"] = timeout
-            client = _get_async_httpx_client(_params)  # type: ignore
+            client = get_async_httpx_client(
+                params=_params, llm_provider=litellm.LlmProviders.BEDROCK
+            )
         else:
             client = client  # type: ignore


@@ -15,8 +15,8 @@ from litellm.llms.cohere.embed import embedding as cohere_embedding
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
-    _get_async_httpx_client,
     _get_httpx_client,
+    get_async_httpx_client,
 )
 from litellm.secret_managers.main import get_secret
 from litellm.types.llms.bedrock import AmazonEmbeddingRequest, CohereEmbeddingRequest
@@ -130,7 +130,9 @@ class BedrockEmbedding(BaseAWSLLM):
             if isinstance(timeout, float) or isinstance(timeout, int):
                 timeout = httpx.Timeout(timeout)
             _params["timeout"] = timeout
-            client = _get_async_httpx_client(_params)  # type: ignore
+            client = get_async_httpx_client(
+                params=_params, llm_provider=litellm.LlmProviders.BEDROCK
+            )
         else:
             client = client


@@ -9,10 +9,11 @@ from typing import Any, Dict, List, Optional, Union
 import httpx
 from pydantic import BaseModel
 
+import litellm
 from litellm.llms.base import BaseLLM
 from litellm.llms.custom_httpx.http_handler import (
-    _get_async_httpx_client,
     _get_httpx_client,
+    get_async_httpx_client,
 )
 from litellm.rerank_api.types import RerankRequest, RerankResponse
@@ -65,7 +66,7 @@ class CohereRerank(BaseLLM):
         api_key: str,
         api_base: str,
     ) -> RerankResponse:
-        client = _get_async_httpx_client()
+        client = get_async_httpx_client(llm_provider=litellm.LlmProviders.COHERE)
 
         response = await client.post(
             api_base,


@@ -1,12 +1,19 @@
 import asyncio
 import os
 import traceback
-from typing import Any, Mapping, Optional, Union
+from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
 
 import httpx
 
 import litellm
 
+from .types import httpxSpecialProvider
+
+if TYPE_CHECKING:
+    from litellm import LlmProviders
+else:
+    LlmProviders = Any
+
 try:
     from litellm._version import version
 except:
@@ -378,7 +385,10 @@ class HTTPHandler:
         pass
 
 
-def _get_async_httpx_client(params: Optional[dict] = None) -> AsyncHTTPHandler:
+def get_async_httpx_client(
+    llm_provider: Union[LlmProviders, httpxSpecialProvider],
+    params: Optional[dict] = None,
+) -> AsyncHTTPHandler:
     """
     Retrieves the async HTTP client from the cache
     If not present, creates a new client
@@ -393,7 +403,7 @@ def _get_async_httpx_client(params: Optional[dict] = None) -> AsyncHTTPHandler:
         except Exception:
             pass
 
-    _cache_key_name = "async_httpx_client" + _params_key_name
+    _cache_key_name = "async_httpx_client" + _params_key_name + llm_provider
     if _cache_key_name in litellm.in_memory_llm_clients_cache:
         return litellm.in_memory_llm_clients_cache[_cache_key_name]
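
The new signature folds the calling provider into the cache key, so each provider (or special caller such as a logging callback) reuses one long-lived client instead of constructing a fresh `AsyncHTTPHandler` per request. A simplified standalone sketch of the caching pattern (using a plain dict and `httpx.AsyncClient` directly; the names here are illustrative, not the library's):

```python
from typing import Optional

import httpx

_clients_cache: dict = {}


def get_cached_async_client(
    llm_provider: str, params: Optional[dict] = None
) -> httpx.AsyncClient:
    # Mirror the diff's key construction: params fingerprint + provider name.
    params_key = str(sorted((params or {}).items()))
    cache_key = "async_httpx_client" + params_key + llm_provider

    # Reuse the existing client (and its connection pool) when one exists.
    if cache_key in _clients_cache:
        return _clients_cache[cache_key]

    # Otherwise create the client once and cache it for all later callers.
    client = httpx.AsyncClient(**(params or {}))
    _clients_cache[cache_key] = client
    return client


# Every call site for a given provider now shares one pooled client:
assert get_cached_async_client("anthropic") is get_cached_async_client("anthropic")
```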


@@ -0,0 +1,10 @@
+from enum import Enum
+
+import litellm
+
+
+class httpxSpecialProvider(str, Enum):
+    LoggingCallback = "logging_callback"
+    GuardrailCallback = "guardrail_callback"
+    Caching = "caching"
+    Oauth2Check = "oauth2_check"
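
Callers that are not LLM providers (logging callbacks, guardrails, the cache layer, the OAuth2 token check) pass one of these values as `llm_provider` instead of an `LlmProviders` member, so their clients are cached under their own keys. For example, mirroring the call sites elsewhere in this diff:

```python
from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
    httpxSpecialProvider,
)

# One shared client for all logging callbacks...
logging_client = get_async_httpx_client(
    llm_provider=httpxSpecialProvider.LoggingCallback
)

# ...and a separate one for guardrail hooks.
guardrail_client = get_async_httpx_client(
    llm_provider=httpxSpecialProvider.GuardrailCallback
)
```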


@@ -19,8 +19,8 @@ from litellm.litellm_core_utils.asyncify import asyncify
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
-    _get_async_httpx_client,
     _get_httpx_client,
+    get_async_httpx_client,
 )
 from litellm.types.llms.openai import (
     ChatCompletionToolCallChunk,
@@ -565,8 +565,8 @@ class SagemakerLLM(BaseAWSLLM):
     ):
         try:
             if client is None:
-                client = (
-                    _get_async_httpx_client()
-                )  # Create a new client if none provided
+                client = get_async_httpx_client(
+                    llm_provider=litellm.LlmProviders.SAGEMAKER
+                )  # Create a new client if none provided
             response = await client.post(
                 api_base,
@@ -673,7 +673,9 @@ class SagemakerLLM(BaseAWSLLM):
         model_id: Optional[str],
     ):
         timeout = 300.0
-        async_handler = _get_async_httpx_client()
+        async_handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.SAGEMAKER
+        )
 
         async_transform_prompt = asyncify(self._transform_prompt)


@@ -9,10 +9,11 @@ from typing import Any, Dict, List, Optional, Union
 import httpx
 from pydantic import BaseModel
 
+import litellm
 from litellm.llms.base import BaseLLM
 from litellm.llms.custom_httpx.http_handler import (
-    _get_async_httpx_client,
     _get_httpx_client,
+    get_async_httpx_client,
 )
 from litellm.rerank_api.types import RerankRequest, RerankResponse
@@ -77,7 +78,9 @@ class TogetherAIRerank(BaseLLM):
         request_data_dict: Dict[str, Any],
         api_key: str,
     ) -> RerankResponse:
-        client = _get_async_httpx_client()  # Use async client
+        client = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.TOGETHER_AI
+        )  # Use async client
 
         response = await client.post(
             "https://api.together.xyz/v1/rerank",


@@ -22,7 +22,7 @@ from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
-    _get_async_httpx_client,
+    get_async_httpx_client,
 )
 from litellm.llms.prompt_templates.factory import (
     convert_url_to_base64,
@@ -1294,7 +1294,9 @@ class VertexLLM(BaseLLM):
         if timeout:
             _async_client_params["timeout"] = timeout
         if client is None or not isinstance(client, AsyncHTTPHandler):
-            client = _get_async_httpx_client(params=_async_client_params)
+            client = get_async_httpx_client(
+                params=_async_client_params, llm_provider=litellm.LlmProviders.VERTEX_AI
+            )
         else:
             client = client  # type: ignore
         ## LOGGING


@@ -4,13 +4,14 @@ from typing import Any, Coroutine, Literal, Optional, TypedDict, Union
 import httpx
 
+import litellm
 from litellm._logging import verbose_logger
 from litellm.llms.base import BaseLLM
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
-    _get_async_httpx_client,
     _get_httpx_client,
+    get_async_httpx_client,
 )
 from litellm.llms.OpenAI.openai import HttpxBinaryResponseContent
 from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
@@ -178,7 +179,9 @@ class VertexTextToSpeechAPI(VertexLLM):
     ) -> HttpxBinaryResponseContent:
         import base64
 
-        async_handler = _get_async_httpx_client()
+        async_handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.VERTEX_AI
+        )
 
         response = await async_handler.post(
             url=url,


@@ -20,7 +20,10 @@ async def check_oauth2_token(token: str) -> UserAPIKeyAuth:
     import httpx
 
     from litellm._logging import verbose_proxy_logger
-    from litellm.llms.custom_httpx.http_handler import _get_async_httpx_client
+    from litellm.llms.custom_httpx.http_handler import (
+        get_async_httpx_client,
+        httpxSpecialProvider,
+    )
     from litellm.proxy._types import CommonProxyErrors
     from litellm.proxy.proxy_server import premium_user
@@ -40,7 +43,7 @@ async def check_oauth2_token(token: str) -> UserAPIKeyAuth:
     if not token_info_endpoint:
         raise ValueError("OAUTH_TOKEN_INFO_ENDPOINT environment variable is not set")
 
-    client = _get_async_httpx_client()
+    client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
     headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
 
     try:


@@ -30,7 +30,11 @@ from litellm.integrations.custom_guardrail import CustomGuardrail
 from litellm.litellm_core_utils.logging_utils import (
     convert_litellm_response_object_to_str,
 )
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
 from litellm.types.guardrails import GuardrailEventHooks
@@ -44,8 +48,8 @@ class AporiaGuardrail(CustomGuardrail):
     def __init__(
         self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
     ):
-        self.async_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
-        )
+        self.async_handler = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.GuardrailCallback
+        )
         self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"]
         self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"]


@@ -33,7 +33,8 @@ from litellm.litellm_core_utils.logging_utils import (
 from litellm.llms.base_aws_llm import BaseAWSLLM
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
-    _get_async_httpx_client,
+    get_async_httpx_client,
+    httpxSpecialProvider,
 )
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
@@ -55,7 +56,9 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
         guardrailVersion: Optional[str] = None,
         **kwargs,
     ):
-        self.async_handler = _get_async_httpx_client()
+        self.async_handler = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.GuardrailCallback
+        )
         self.guardrailIdentifier = guardrailIdentifier
         self.guardrailVersion = guardrailVersion


@@ -21,7 +21,11 @@ from fastapi import HTTPException
 import litellm
 from litellm._logging import verbose_proxy_logger
 from litellm.integrations.custom_guardrail import CustomGuardrail
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
 from litellm.secret_managers.main import get_secret
@@ -50,8 +54,8 @@ class lakeraAI_Moderation(CustomGuardrail):
         api_key: Optional[str] = None,
         **kwargs,
     ):
-        self.async_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
-        )
+        self.async_handler = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.GuardrailCallback
+        )
         self.lakera_api_key = api_key or os.environ["LAKERA_API_KEY"]
         self.moderation_check = moderation_check


@@ -22,7 +22,6 @@ import litellm  # noqa: E401
 from litellm._logging import verbose_proxy_logger
 from litellm.caching import DualCache
 from litellm.integrations.custom_guardrail import CustomGuardrail
-from litellm.llms.custom_httpx.http_handler import _get_async_httpx_client
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.utils import (
     EmbeddingResponse,