pass llm provider when creating async httpx clients

Ishaan Jaff 2024-09-10 11:51:42 -07:00
parent 87bac7c026
commit 421b857714
14 changed files with 57 additions and 19 deletions
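
In short, every call site that creates an async httpx client now passes a provider: a litellm.LlmProviders member for LLM traffic, or an httpxSpecialProvider member for internal callers (caching, guardrails, OAuth2 token checks). A minimal sketch of the new call pattern, based on the hunks below (illustrative, not taken verbatim from any one file):

    import litellm
    from litellm.llms.custom_httpx.http_handler import (
        get_async_httpx_client,
        httpxSpecialProvider,
    )

    # LLM call sites pass their provider enum member
    anthropic_client = get_async_httpx_client(
        llm_provider=litellm.LlmProviders.ANTHROPIC
    )

    # non-LLM call sites (caching, guardrails, oauth2) use httpxSpecialProvider
    cache_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.Caching
    )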

View file

@@ -1244,6 +1244,7 @@ class QdrantSemanticCache(BaseCache):
         from litellm.llms.custom_httpx.http_handler import (
             _get_httpx_client,
             get_async_httpx_client,
+            httpxSpecialProvider,
         )
 
         if collection_name is None:
@@ -1290,7 +1291,9 @@ class QdrantSemanticCache(BaseCache):
         self.headers = headers
 
         self.sync_client = _get_httpx_client()
-        self.async_client = get_async_httpx_client()
+        self.async_client = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.Caching
+        )
 
         if quantization_config is None:
             print_verbose(

View file

@@ -918,7 +918,9 @@ class AnthropicChatCompletion(BaseLLM):
         headers={},
         client=None,
     ) -> Union[ModelResponse, CustomStreamWrapper]:
-        async_handler = get_async_httpx_client()
+        async_handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.ANTHROPIC
+        )
 
         try:
             response = await async_handler.post(

View file

@@ -209,7 +209,9 @@ async def make_call(
 ):
     try:
         if client is None:
-            client = get_async_httpx_client()  # Create a new client if none provided
+            client = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders.BEDROCK
+            )  # Create a new client if none provided
 
         response = await client.post(
             api_base,
@@ -1041,7 +1043,7 @@ class BedrockLLM(BaseAWSLLM):
                 if isinstance(timeout, float) or isinstance(timeout, int):
                     timeout = httpx.Timeout(timeout)
                 _params["timeout"] = timeout
-            client = get_async_httpx_client(_params)  # type: ignore
+            client = get_async_httpx_client(_params, llm_provider=litellm.LlmProviders.BEDROCK)  # type: ignore
         else:
             client = client  # type: ignore
@@ -1498,7 +1500,7 @@ class BedrockConverseLLM(BaseAWSLLM):
                 if isinstance(timeout, float) or isinstance(timeout, int):
                     timeout = httpx.Timeout(timeout)
                 _params["timeout"] = timeout
-            client = get_async_httpx_client(_params)  # type: ignore
+            client = get_async_httpx_client(_params, llm_provider=litellm.LlmProviders.BEDROCK)  # type: ignore
         else:
             client = client  # type: ignore

View file

@@ -130,7 +130,7 @@ class BedrockEmbedding(BaseAWSLLM):
                 if isinstance(timeout, float) or isinstance(timeout, int):
                     timeout = httpx.Timeout(timeout)
                 _params["timeout"] = timeout
-            client = get_async_httpx_client(_params)  # type: ignore
+            client = get_async_httpx_client(_params, llm_provider=litellm.LlmProviders.BEDROCK)  # type: ignore
         else:
             client = client

View file

@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Union
 import httpx
 from pydantic import BaseModel
 
+import litellm
 from litellm.llms.base import BaseLLM
 from litellm.llms.custom_httpx.http_handler import (
     _get_httpx_client,
@@ -65,7 +66,7 @@ class CohereRerank(BaseLLM):
         api_key: str,
         api_base: str,
     ) -> RerankResponse:
-        client = get_async_httpx_client()
+        client = get_async_httpx_client(llm_provider=litellm.LlmProviders.COHERE)
 
         response = await client.post(
             api_base,

View file

@@ -7,6 +7,8 @@ import httpx
 
 import litellm
 
+from .types import httpxSpecialProvider
+
 try:
     from litellm._version import version
 except:
@@ -378,7 +380,10 @@ class HTTPHandler:
         pass
 
 
-def get_async_httpx_client(params: Optional[dict] = None) -> AsyncHTTPHandler:
+def get_async_httpx_client(
+    llm_provider: Union[litellm.LlmProviders, httpxSpecialProvider],
+    params: Optional[dict] = None,
+) -> AsyncHTTPHandler:
     """
     Retrieves the async HTTP client from the cache
     If not present, creates a new client
@@ -393,7 +398,7 @@ def get_async_httpx_client(params: Optional[dict] = None) -> AsyncHTTPHandler:
         except Exception:
             pass
 
-    _cache_key_name = "async_httpx_client" + _params_key_name
+    _cache_key_name = "async_httpx_client" + _params_key_name + llm_provider
     if _cache_key_name in litellm.in_memory_llm_clients_cache:
         return litellm.in_memory_llm_clients_cache[_cache_key_name]

View file

@@ -0,0 +1,10 @@
+from enum import Enum
+
+import litellm
+
+
+class httpxSpecialProvider(str, Enum):
+    LoggingCallback = "logging_callback"
+    GuardrailCallback = "guardrail_callback"
+    Caching = "caching"
+    Oauth2Check = "oauth2_check"
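
The members map onto the non-provider call sites touched in this commit: Caching for the Qdrant semantic cache client, GuardrailCallback for the Bedrock guardrail handler, and Oauth2Check for the proxy's OAuth2 token check; LoggingCallback is presumably reserved for logging integrations. A guardrail, for example, requests its client like this (a sketch mirroring the guardrail hunk further down):

    from litellm.llms.custom_httpx.http_handler import (
        get_async_httpx_client,
        httpxSpecialProvider,
    )

    # e.g. inside a guardrail's __init__
    async_handler = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.GuardrailCallback
    )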

View file

@@ -565,8 +565,8 @@ class SagemakerLLM(BaseAWSLLM):
     ):
         try:
             if client is None:
-                client = (
-                    get_async_httpx_client()
-                )  # Create a new client if none provided
+                client = get_async_httpx_client(
+                    llm_provider=litellm.LlmProviders.SAGEMAKER
+                )  # Create a new client if none provided
             response = await client.post(
                 api_base,
@@ -673,7 +673,9 @@ class SagemakerLLM(BaseAWSLLM):
         model_id: Optional[str],
     ):
         timeout = 300.0
-        async_handler = get_async_httpx_client()
+        async_handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.SAGEMAKER
+        )
 
         async_transform_prompt = asyncify(self._transform_prompt)

View file

@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Union
 import httpx
 from pydantic import BaseModel
 
+import litellm
 from litellm.llms.base import BaseLLM
 from litellm.llms.custom_httpx.http_handler import (
     _get_httpx_client,
@@ -77,7 +78,9 @@ class TogetherAIRerank(BaseLLM):
         request_data_dict: Dict[str, Any],
         api_key: str,
     ) -> RerankResponse:
-        client = get_async_httpx_client()  # Use async client
+        client = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.TOGETHER_AI
+        )  # Use async client
 
         response = await client.post(
             "https://api.together.xyz/v1/rerank",

View file

@@ -1293,7 +1293,9 @@ class VertexLLM(BaseLLM):
             _async_client_params = {}
             if timeout:
                 _async_client_params["timeout"] = timeout
-            client = get_async_httpx_client(params=_async_client_params)
+            client = get_async_httpx_client(
+                params=_async_client_params, llm_provider=litellm.LlmProviders.VERTEX_AI
+            )
         ## LOGGING
         logging_obj.pre_call(
             input=messages,

View file

@@ -4,6 +4,7 @@ from typing import Any, Coroutine, Literal, Optional, TypedDict, Union
 
 import httpx
 
+import litellm
 from litellm._logging import verbose_logger
 from litellm.llms.base import BaseLLM
 from litellm.llms.custom_httpx.http_handler import (
@@ -178,7 +179,9 @@ class VertexTextToSpeechAPI(VertexLLM):
     ) -> HttpxBinaryResponseContent:
         import base64
 
-        async_handler = get_async_httpx_client()
+        async_handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.VERTEX_AI
+        )
 
         response = await async_handler.post(
             url=url,

View file

@@ -20,7 +20,10 @@ async def check_oauth2_token(token: str) -> UserAPIKeyAuth:
     import httpx
 
     from litellm._logging import verbose_proxy_logger
-    from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
+    from litellm.llms.custom_httpx.http_handler import (
+        get_async_httpx_client,
+        httpxSpecialProvider,
+    )
     from litellm.proxy._types import CommonProxyErrors
     from litellm.proxy.proxy_server import premium_user
@@ -40,7 +43,7 @@ async def check_oauth2_token(token: str) -> UserAPIKeyAuth:
     if not token_info_endpoint:
         raise ValueError("OAUTH_TOKEN_INFO_ENDPOINT environment variable is not set")
 
-    client = get_async_httpx_client()
+    client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
     headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
 
     try:

View file

@@ -34,6 +34,7 @@ from litellm.llms.base_aws_llm import BaseAWSLLM
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     get_async_httpx_client,
+    httpxSpecialProvider,
 )
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
@@ -55,7 +56,9 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
         guardrailVersion: Optional[str] = None,
         **kwargs,
     ):
-        self.async_handler = get_async_httpx_client()
+        self.async_handler = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.GuardrailCallback
+        )
         self.guardrailIdentifier = guardrailIdentifier
         self.guardrailVersion = guardrailVersion

View file

@@ -22,7 +22,6 @@ import litellm  # noqa: E401
 from litellm._logging import verbose_proxy_logger
 from litellm.caching import DualCache
 from litellm.integrations.custom_guardrail import CustomGuardrail
-from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.utils import (
     EmbeddingResponse,