forked from phoenix/litellm-mirror

pass llm provider when creating async httpx clients

parent 87bac7c026 · commit 421b857714

14 changed files with 57 additions and 19 deletions
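The pattern repeated across every hunk below is the same: each call site that previously created an anonymous async client now identifies itself when asking for one. A minimal sketch of the before/after, using only names that appear in the diff (import paths follow the files touched here):

```python
import litellm
from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
    httpxSpecialProvider,
)

# Before: an anonymous client, keyed in the cache only by its params.
# client = get_async_httpx_client()

# After: an LLM integration names its provider...
client = get_async_httpx_client(llm_provider=litellm.LlmProviders.ANTHROPIC)

# ...and non-provider callers (caching, guardrails, OAuth checks) use the
# special-provider enum added in litellm/llms/custom_httpx/types.py.
cache_client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Caching)
```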
```diff
@@ -1244,6 +1244,7 @@ class QdrantSemanticCache(BaseCache):
         from litellm.llms.custom_httpx.http_handler import (
             _get_httpx_client,
             get_async_httpx_client,
+            httpxSpecialProvider,
         )
 
         if collection_name is None:
```
```diff
@@ -1290,7 +1291,9 @@ class QdrantSemanticCache(BaseCache):
         self.headers = headers
 
         self.sync_client = _get_httpx_client()
-        self.async_client = get_async_httpx_client()
+        self.async_client = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.Caching
+        )
 
         if quantization_config is None:
             print_verbose(
```
```diff
@@ -918,7 +918,9 @@ class AnthropicChatCompletion(BaseLLM):
         headers={},
         client=None,
     ) -> Union[ModelResponse, CustomStreamWrapper]:
-        async_handler = get_async_httpx_client()
+        async_handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.ANTHROPIC
+        )
 
         try:
             response = await async_handler.post(
```
```diff
@@ -209,7 +209,9 @@ async def make_call(
 ):
     try:
         if client is None:
-            client = get_async_httpx_client()  # Create a new client if none provided
+            client = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders.BEDROCK
+            )  # Create a new client if none provided
 
         response = await client.post(
             api_base,
```
```diff
@@ -1041,7 +1043,7 @@ class BedrockLLM(BaseAWSLLM):
             if isinstance(timeout, float) or isinstance(timeout, int):
                 timeout = httpx.Timeout(timeout)
             _params["timeout"] = timeout
-            client = get_async_httpx_client(_params)  # type: ignore
+            client = get_async_httpx_client(params=_params, llm_provider=litellm.LlmProviders.BEDROCK)  # type: ignore
         else:
             client = client  # type: ignore
```
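Note on the Bedrock call sites: with the new signature, `llm_provider` is the first positional parameter, so the existing `_params` must now be passed by keyword; the hunks here are rendered that way. A short sketch of the combined call (the timeout value is illustrative, not from the diff):

```python
import httpx
import litellm
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client

_params = {"timeout": httpx.Timeout(600.0, connect=5.0)}  # illustrative value
client = get_async_httpx_client(
    params=_params, llm_provider=litellm.LlmProviders.BEDROCK
)
```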
```diff
@@ -1498,7 +1500,7 @@ class BedrockConverseLLM(BaseAWSLLM):
             if isinstance(timeout, float) or isinstance(timeout, int):
                 timeout = httpx.Timeout(timeout)
             _params["timeout"] = timeout
-            client = get_async_httpx_client(_params)  # type: ignore
+            client = get_async_httpx_client(params=_params, llm_provider=litellm.LlmProviders.BEDROCK)  # type: ignore
         else:
             client = client  # type: ignore
```
```diff
@@ -130,7 +130,7 @@ class BedrockEmbedding(BaseAWSLLM):
             if isinstance(timeout, float) or isinstance(timeout, int):
                 timeout = httpx.Timeout(timeout)
             _params["timeout"] = timeout
-            client = get_async_httpx_client(_params)  # type: ignore
+            client = get_async_httpx_client(params=_params, llm_provider=litellm.LlmProviders.BEDROCK)  # type: ignore
         else:
             client = client
```
```diff
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Union
 import httpx
 from pydantic import BaseModel
 
+import litellm
 from litellm.llms.base import BaseLLM
 from litellm.llms.custom_httpx.http_handler import (
     _get_httpx_client,
```
```diff
@@ -65,7 +66,7 @@ class CohereRerank(BaseLLM):
         api_key: str,
         api_base: str,
     ) -> RerankResponse:
-        client = get_async_httpx_client()
+        client = get_async_httpx_client(llm_provider=litellm.LlmProviders.COHERE)
 
         response = await client.post(
             api_base,
```
```diff
@@ -7,6 +7,8 @@ import httpx
 
 import litellm
 
+from .types import httpxSpecialProvider
+
 try:
     from litellm._version import version
 except:
```
```diff
@@ -378,7 +380,10 @@ class HTTPHandler:
             pass
 
 
-def get_async_httpx_client(params: Optional[dict] = None) -> AsyncHTTPHandler:
+def get_async_httpx_client(
+    llm_provider: Union[litellm.LlmProviders, httpxSpecialProvider],
+    params: Optional[dict] = None,
+) -> AsyncHTTPHandler:
     """
     Retrieves the async HTTP client from the cache
     If not present, creates a new client
```
```diff
@@ -393,7 +398,7 @@ def get_async_httpx_client(params: Optional[dict] = None) -> AsyncHTTPHandler:
     except Exception:
         pass
 
-    _cache_key_name = "async_httpx_client" + _params_key_name
+    _cache_key_name = "async_httpx_client" + _params_key_name + llm_provider
     if _cache_key_name in litellm.in_memory_llm_clients_cache:
         return litellm.in_memory_llm_clients_cache[_cache_key_name]
 
```
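Appending `llm_provider` to the key means clients are now cached per caller rather than shared across all callers with identical params; the concatenation works because both `LlmProviders` and `httpxSpecialProvider` are str-based enums. A self-contained sketch of the key behavior (the enum here is a simplified stand-in, not litellm's full provider list):

```python
from enum import Enum


class LlmProviders(str, Enum):  # simplified stand-in for litellm.LlmProviders
    ANTHROPIC = "anthropic"
    BEDROCK = "bedrock"


def cache_key(params_key: str, llm_provider: str) -> str:
    # Mirrors the new key construction: str-enum members concatenate as strings.
    return "async_httpx_client" + params_key + llm_provider


assert cache_key("", LlmProviders.BEDROCK) == "async_httpx_clientbedrock"
assert cache_key("", LlmProviders.BEDROCK) != cache_key("", LlmProviders.ANTHROPIC)
```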
litellm/llms/custom_httpx/types.py (new file, 10 lines)
```diff
@@ -0,0 +1,10 @@
+from enum import Enum
+
+import litellm
+
+
+class httpxSpecialProvider(str, Enum):
+    LoggingCallback = "logging_callback"
+    GuardrailCallback = "guardrail_callback"
+    Caching = "caching"
+    Oauth2Check = "oauth2_check"
```
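Since `httpxSpecialProvider` subclasses `str`, its members slot into the same string-keyed cache as the provider enum, and the oauth2 hunk below imports it via `http_handler`. A usage sketch (not part of the commit):

```python
from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
    httpxSpecialProvider,
)

# Non-provider subsystems identify themselves with the enum:
client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
assert httpxSpecialProvider.Oauth2Check == "oauth2_check"  # str-enum equality
```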
```diff
@@ -565,8 +565,8 @@ class SagemakerLLM(BaseAWSLLM):
     ):
         try:
             if client is None:
-                client = (
-                    get_async_httpx_client()
-                )  # Create a new client if none provided
+                client = get_async_httpx_client(
+                    llm_provider=litellm.LlmProviders.SAGEMAKER
+                )  # Create a new client if none provided
             response = await client.post(
                 api_base,
```
```diff
@@ -673,7 +673,9 @@ class SagemakerLLM(BaseAWSLLM):
         model_id: Optional[str],
     ):
         timeout = 300.0
-        async_handler = get_async_httpx_client()
+        async_handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.SAGEMAKER
+        )
 
         async_transform_prompt = asyncify(self._transform_prompt)
 
```
```diff
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Union
 import httpx
 from pydantic import BaseModel
 
+import litellm
 from litellm.llms.base import BaseLLM
 from litellm.llms.custom_httpx.http_handler import (
     _get_httpx_client,
```
```diff
@@ -77,7 +78,9 @@ class TogetherAIRerank(BaseLLM):
         request_data_dict: Dict[str, Any],
         api_key: str,
     ) -> RerankResponse:
-        client = get_async_httpx_client()  # Use async client
+        client = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.TOGETHER_AI
+        )  # Use async client
 
         response = await client.post(
             "https://api.together.xyz/v1/rerank",
```
```diff
@@ -1293,7 +1293,9 @@ class VertexLLM(BaseLLM):
             _async_client_params = {}
             if timeout:
                 _async_client_params["timeout"] = timeout
-            client = get_async_httpx_client(params=_async_client_params)
+            client = get_async_httpx_client(
+                params=_async_client_params, llm_provider=litellm.LlmProviders.VERTEX_AI
+            )
         ## LOGGING
         logging_obj.pre_call(
             input=messages,
```
```diff
@@ -4,6 +4,7 @@ from typing import Any, Coroutine, Literal, Optional, TypedDict, Union
 
 import httpx
 
+import litellm
 from litellm._logging import verbose_logger
 from litellm.llms.base import BaseLLM
 from litellm.llms.custom_httpx.http_handler import (
```
```diff
@@ -178,7 +179,9 @@ class VertexTextToSpeechAPI(VertexLLM):
     ) -> HttpxBinaryResponseContent:
         import base64
 
-        async_handler = get_async_httpx_client()
+        async_handler = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.VERTEX_AI
+        )
 
         response = await async_handler.post(
             url=url,
```
```diff
@@ -20,7 +20,10 @@ async def check_oauth2_token(token: str) -> UserAPIKeyAuth:
     import httpx
 
     from litellm._logging import verbose_proxy_logger
-    from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
+    from litellm.llms.custom_httpx.http_handler import (
+        get_async_httpx_client,
+        httpxSpecialProvider,
+    )
     from litellm.proxy._types import CommonProxyErrors
     from litellm.proxy.proxy_server import premium_user
 
```
```diff
@@ -40,7 +43,7 @@ async def check_oauth2_token(token: str) -> UserAPIKeyAuth:
     if not token_info_endpoint:
         raise ValueError("OAUTH_TOKEN_INFO_ENDPOINT environment variable is not set")
 
-    client = get_async_httpx_client()
+    client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
     headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
 
     try:
```
```diff
@@ -34,6 +34,7 @@ from litellm.llms.base_aws_llm import BaseAWSLLM
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     get_async_httpx_client,
+    httpxSpecialProvider,
 )
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
```
```diff
@@ -55,7 +56,9 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
         guardrailVersion: Optional[str] = None,
         **kwargs,
     ):
-        self.async_handler = get_async_httpx_client()
+        self.async_handler = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.GuardrailCallback
+        )
         self.guardrailIdentifier = guardrailIdentifier
         self.guardrailVersion = guardrailVersion
 
```
```diff
@@ -22,7 +22,6 @@ import litellm  # noqa: E401
 from litellm._logging import verbose_proxy_logger
 from litellm.caching import DualCache
 from litellm.integrations.custom_guardrail import CustomGuardrail
-from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.utils import (
     EmbeddingResponse,
```