fix get_async_httpx_client

This commit is contained in:
Ishaan Jaff 2024-11-21 11:18:07 -08:00
parent 81c0125737
commit e63ea48894
4 changed files with 35 additions and 12 deletions

View file

@ -133,7 +133,7 @@ use_client: bool = False
ssl_verify: Union[str, bool] = True
ssl_certificate: Optional[str] = None
disable_streaming_logging: bool = False
in_memory_llm_clients_cache: dict = {}
in_memory_llm_clients_cache: InMemoryCache = InMemoryCache()
safe_memory_mode: bool = False
enable_azure_ad_token_refresh: Optional[bool] = False
### DEFAULT AZURE API VERSION ###

View file

@ -18,6 +18,7 @@ import litellm
from litellm import LlmProviders
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
from litellm.secret_managers.main import get_secret_str
from litellm.types.utils import ProviderField
from litellm.utils import (
@ -562,8 +563,9 @@ class OpenAIChatCompletion(BaseLLM):
_cache_key = f"hashed_api_key={hashed_api_key},api_base={api_base},timeout={timeout},max_retries={max_retries},organization={organization},is_async={is_async}"
if _cache_key in litellm.in_memory_llm_clients_cache:
return litellm.in_memory_llm_clients_cache[_cache_key]
_cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key)
if _cached_client:
return _cached_client
if is_async:
_new_client: Union[OpenAI, AsyncOpenAI] = AsyncOpenAI(
api_key=api_key,
@ -584,7 +586,11 @@ class OpenAIChatCompletion(BaseLLM):
)
## SAVE CACHE KEY
litellm.in_memory_llm_clients_cache[_cache_key] = _new_client
litellm.in_memory_llm_clients_cache.set_cache(
key=_cache_key,
value=_new_client,
ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
)
return _new_client
else:

View file

@ -7,6 +7,7 @@ import httpx
from httpx import USE_CLIENT_DEFAULT, AsyncHTTPTransport, HTTPTransport
import litellm
from litellm.caching import InMemoryCache
from .types import httpxSpecialProvider
@ -26,6 +27,7 @@ headers = {
# https://www.python-httpx.org/advanced/timeouts
_DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
_DEFAULT_TTL_FOR_HTTPX_CLIENTS = 3600 # 1 hour, re-use the same httpx client for 1 hour
class AsyncHTTPHandler:
@ -476,8 +478,9 @@ def get_async_httpx_client(
pass
_cache_key_name = "async_httpx_client" + _params_key_name + llm_provider
if _cache_key_name in litellm.in_memory_llm_clients_cache:
return litellm.in_memory_llm_clients_cache[_cache_key_name]
_cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key_name)
if _cached_client:
return _cached_client
if params is not None:
_new_client = AsyncHTTPHandler(**params)
@ -485,7 +488,11 @@ def get_async_httpx_client(
_new_client = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
litellm.in_memory_llm_clients_cache[_cache_key_name] = _new_client
litellm.in_memory_llm_clients_cache.set_cache(
key=_cache_key_name,
value=_new_client,
ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
)
return _new_client
@ -505,13 +512,18 @@ def _get_httpx_client(params: Optional[dict] = None) -> HTTPHandler:
pass
_cache_key_name = "httpx_client" + _params_key_name
if _cache_key_name in litellm.in_memory_llm_clients_cache:
return litellm.in_memory_llm_clients_cache[_cache_key_name]
_cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key_name)
if _cached_client:
return _cached_client
if params is not None:
_new_client = HTTPHandler(**params)
else:
_new_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
litellm.in_memory_llm_clients_cache[_cache_key_name] = _new_client
litellm.in_memory_llm_clients_cache.set_cache(
key=_cache_key_name,
value=_new_client,
ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
)
return _new_client

View file

@ -14,6 +14,7 @@ from pydantic import BaseModel
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
from litellm.llms.prompt_templates.factory import (
convert_to_anthropic_image_obj,
convert_to_gemini_tool_call_invoke,
@ -93,11 +94,15 @@ def _get_client_cache_key(
def _get_client_from_cache(client_cache_key: str):
return litellm.in_memory_llm_clients_cache.get(client_cache_key, None)
return litellm.in_memory_llm_clients_cache.get_cache(client_cache_key)
def _set_client_in_cache(client_cache_key: str, vertex_llm_model: Any):
litellm.in_memory_llm_clients_cache[client_cache_key] = vertex_llm_model
litellm.in_memory_llm_clients_cache.set_cache(
key=client_cache_key,
value=vertex_llm_model,
ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
)
def completion( # noqa: PLR0915