Compare commits

23 commits

Author           SHA1        Message                                                    Date
Ishaan Jaff      f6f5529621  Merge branch 'main' into litellm_fix_async_http_handler   2024-11-21 19:02:54 -08:00
Ishaan Jaff      d03455a72c  fix import                                                 2024-11-21 13:11:06 -08:00
Ishaan Jaff      9067a5031b  fix langfuse import                                        2024-11-21 12:48:17 -08:00
Ishaan Jaff      45130c2d4c  fix tests using in_memory_llm_clients_cache                2024-11-21 12:41:09 -08:00
Ishaan Jaff      e63ea48894  fix get_async_httpx_client                                 2024-11-21 11:18:07 -08:00
Ishaan Jaff      81c0125737  fix check_for_async_http_handler                           2024-11-21 10:45:57 -08:00
Ishaan Jaff      ce0061d136  add check for AsyncClient                                  2024-11-21 10:39:34 -08:00
Krrish Dholakia  e8f47e96c3  test: cleanup mistral model                                2024-11-21 10:32:08 -08:00
Ishaan Jaff      bb75af618f  fix check_for_async_http_handler                           2024-11-21 10:30:16 -08:00
Ishaan Jaff      d4dc8e60b6  fix make_async_azure_httpx_request                         2024-11-21 10:27:08 -08:00
Ishaan Jaff      89d76d1eb7  fix get_async_httpx_client                                 2024-11-21 10:26:18 -08:00
Ishaan Jaff      398e6d0ac6  fix get_async_httpx_client                                 2024-11-21 10:24:18 -08:00
Ishaan Jaff      0a10b1ef1c  fix get_async_httpx_client vertex                          2024-11-21 10:22:30 -08:00
Ishaan Jaff      f7f9e8c41f  fix dbricks get_async_httpx_client                         2024-11-21 10:21:06 -08:00
Ishaan Jaff      0ee9f0fa44  fix vertex fine tuning                                     2024-11-21 10:20:16 -08:00
Ishaan Jaff      6af0494483  fix anthropic use get_async_httpx_client                   2024-11-21 10:18:26 -08:00
Ishaan Jaff      fb5cc97387  fix PREDIBASE                                              2024-11-21 10:17:18 -08:00
Ishaan Jaff      4d56249eb9  add test_no_async_http_handler_usage                       2024-11-21 10:16:07 -08:00
Ishaan Jaff      77232f9bc4  fix HUGGINGFACE                                            2024-11-21 09:46:04 -08:00
Ishaan Jaff      2719f7fcbf  fix CLARIFAI                                               2024-11-21 09:43:04 -08:00
Ishaan Jaff      3d3d651b89  fix REPLICATE                                              2024-11-21 09:42:01 -08:00
Ishaan Jaff      fdaee84b82  fix TEXT_COMPLETION_CODESTRAL                              2024-11-21 09:40:26 -08:00
Ishaan Jaff      0420b07c13  fix triton                                                 2024-11-21 09:39:48 -08:00
26 changed files with 288 additions and 62 deletions

View file

@@ -771,6 +771,7 @@ jobs:
- run: python ./tests/code_coverage_tests/litellm_logging_code_coverage.py
- run: python ./tests/documentation_tests/test_env_keys.py
- run: python ./tests/documentation_tests/test_api_docs.py
- run: python ./tests/code_coverage_tests/ensure_async_clients_test.py
- run: helm lint ./deploy/charts/litellm-helm
db_migration_disable_update_check:

View file

@@ -133,7 +133,7 @@ use_client: bool = False
ssl_verify: Union[str, bool] = True
ssl_certificate: Optional[str] = None
disable_streaming_logging: bool = False
in_memory_llm_clients_cache: dict = {}
in_memory_llm_clients_cache: InMemoryCache = InMemoryCache()
safe_memory_mode: bool = False
enable_azure_ad_token_refresh: Optional[bool] = False
### DEFAULT AZURE API VERSION ###
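
The hunk above replaces the module-level client cache dict with an InMemoryCache, so cached clients can carry a TTL and expire instead of living for the whole process. A rough sketch of the access pattern this implies, mirroring the get_cache/set_cache calls in the hunks further down (the cache key and cached object here are placeholders, not values from the diff):

import litellm

_cache_key = "async_httpx_client-example"  # placeholder key, not a real cache key
_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key)
if _client is None:
    _client = object()  # placeholder for a real httpx / OpenAI client
    litellm.in_memory_llm_clients_cache.set_cache(
        key=_cache_key,
        value=_client,
        ttl=3600,  # matches _DEFAULT_TTL_FOR_HTTPX_CLIENTS introduced below
    )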

View file

@@ -12,7 +12,11 @@ from typing_extensions import overload
import litellm
from litellm.caching.caching import DualCache
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
get_async_httpx_client,
)
from litellm.types.utils import EmbeddingResponse
from litellm.utils import (
CustomStreamWrapper,
@@ -977,7 +981,10 @@ class AzureChatCompletion(BaseLLM):
else:
_params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
async_handler = AsyncHTTPHandler(**_params) # type: ignore
async_handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.AZURE,
params=_params,
)
else:
async_handler = client # type: ignore

View file

@@ -18,6 +18,7 @@ import litellm
from litellm import LlmProviders
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
from litellm.secret_managers.main import get_secret_str
from litellm.types.utils import ProviderField
from litellm.utils import (
@@ -562,8 +563,9 @@ class OpenAIChatCompletion(BaseLLM):
_cache_key = f"hashed_api_key={hashed_api_key},api_base={api_base},timeout={timeout},max_retries={max_retries},organization={organization},is_async={is_async}"
if _cache_key in litellm.in_memory_llm_clients_cache:
return litellm.in_memory_llm_clients_cache[_cache_key]
_cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key)
if _cached_client:
return _cached_client
if is_async:
_new_client: Union[OpenAI, AsyncOpenAI] = AsyncOpenAI(
api_key=api_key,
@@ -584,7 +586,11 @@
)
## SAVE CACHE KEY
litellm.in_memory_llm_clients_cache[_cache_key] = _new_client
litellm.in_memory_llm_clients_cache.set_cache(
key=_cache_key,
value=_new_client,
ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
)
return _new_client
else:

View file

@@ -13,7 +13,11 @@ import httpx
import requests
import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
get_async_httpx_client,
)
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from ..base import BaseLLM
@@ -162,7 +166,10 @@ class AnthropicTextCompletion(BaseLLM):
client=None,
):
if client is None:
client = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.ANTHROPIC,
params={"timeout": httpx.Timeout(timeout=600.0, connect=5.0)},
)
response = await client.post(api_base, headers=headers, data=json.dumps(data))
@@ -198,7 +205,10 @@
client=None,
):
if client is None:
client = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.ANTHROPIC,
params={"timeout": httpx.Timeout(timeout=600.0, connect=5.0)},
)
response = await client.post(api_base, headers=headers, data=json.dumps(data))

View file

@@ -74,7 +74,10 @@ class AzureAIEmbedding(OpenAIChatCompletion):
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
) -> EmbeddingResponse:
if client is None or not isinstance(client, AsyncHTTPHandler):
client = AsyncHTTPHandler(timeout=timeout, concurrent_limit=1)
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.AZURE_AI,
params={"timeout": timeout},
)
url = "{}/images/embeddings".format(api_base)

View file

@@ -9,7 +9,10 @@ import httpx
import requests
import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
)
from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
from .prompt_templates.factory import custom_prompt, prompt_factory
@@ -185,7 +188,10 @@ async def async_completion(
headers={},
):
async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
async_handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.CLARIFAI,
params={"timeout": 600.0},
)
response = await async_handler.post(
url=model, headers=headers, data=json.dumps(data)
)

View file

@@ -11,7 +11,11 @@ import requests # type: ignore
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
get_async_httpx_client,
)
from litellm.types.llms.bedrock import CohereEmbeddingRequest
from litellm.utils import Choices, Message, ModelResponse, Usage
@@ -71,7 +75,10 @@ async def async_embedding(
)
## COMPLETION CALL
if client is None:
client = AsyncHTTPHandler(concurrent_limit=1, timeout=timeout)
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.COHERE,
params={"timeout": timeout},
)
try:
response = await client.post(api_base, headers=headers, data=json.dumps(data))

View file

@@ -7,6 +7,7 @@ import httpx
from httpx import USE_CLIENT_DEFAULT, AsyncHTTPTransport, HTTPTransport
import litellm
from litellm.caching import InMemoryCache
from .types import httpxSpecialProvider
@@ -26,6 +27,7 @@ headers = {
# https://www.python-httpx.org/advanced/timeouts
_DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
_DEFAULT_TTL_FOR_HTTPX_CLIENTS = 3600 # 1 hour, re-use the same httpx client for 1 hour
class AsyncHTTPHandler:
@@ -476,8 +478,9 @@ def get_async_httpx_client(
pass
_cache_key_name = "async_httpx_client" + _params_key_name + llm_provider
if _cache_key_name in litellm.in_memory_llm_clients_cache:
return litellm.in_memory_llm_clients_cache[_cache_key_name]
_cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key_name)
if _cached_client:
return _cached_client
if params is not None:
_new_client = AsyncHTTPHandler(**params)
@@ -485,7 +488,11 @@
_new_client = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
litellm.in_memory_llm_clients_cache[_cache_key_name] = _new_client
litellm.in_memory_llm_clients_cache.set_cache(
key=_cache_key_name,
value=_new_client,
ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
)
return _new_client
@@ -505,13 +512,18 @@ def _get_httpx_client(params: Optional[dict] = None) -> HTTPHandler:
pass
_cache_key_name = "httpx_client" + _params_key_name
if _cache_key_name in litellm.in_memory_llm_clients_cache:
return litellm.in_memory_llm_clients_cache[_cache_key_name]
_cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key_name)
if _cached_client:
return _cached_client
if params is not None:
_new_client = HTTPHandler(**params)
else:
_new_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
litellm.in_memory_llm_clients_cache[_cache_key_name] = _new_client
litellm.in_memory_llm_clients_cache.set_cache(
key=_cache_key_name,
value=_new_client,
ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
)
return _new_client
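
Together, the two hunks above are the core of this change: get_async_httpx_client and _get_httpx_client now store their handlers in the InMemoryCache with a one-hour TTL, and the provider hunks below are rewritten to go through them. A minimal before/after sketch of the call-site pattern (the provider and timeout values are illustrative, copied from the Anthropic hunk earlier):

import httpx
import litellm
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client

# Before this PR: a fresh async client was built per request
# client = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))

# After this PR: a cached client keyed on provider + params
client = get_async_httpx_client(
    llm_provider=litellm.LlmProviders.ANTHROPIC,  # illustrative provider
    params={"timeout": httpx.Timeout(timeout=600.0, connect=5.0)},
)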

View file

@@ -393,7 +393,10 @@ class DatabricksChatCompletion(BaseLLM):
if timeout is None:
timeout = httpx.Timeout(timeout=600.0, connect=5.0)
self.async_handler = AsyncHTTPHandler(timeout=timeout)
self.async_handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.DATABRICKS,
params={"timeout": timeout},
)
try:
response = await self.async_handler.post(
@@ -610,7 +613,10 @@
response = None
try:
if client is None or isinstance(client, AsyncHTTPHandler):
self.async_client = AsyncHTTPHandler(timeout=timeout) # type: ignore
self.async_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.DATABRICKS,
params={"timeout": timeout},
)
else:
self.async_client = client

View file

@@ -5,9 +5,14 @@ from typing import Any, Coroutine, Literal, Optional, Union
import httpx
from openai.types.fine_tuning.fine_tuning_job import FineTuningJob, Hyperparameters
import litellm
from litellm._logging import verbose_logger
from litellm.llms.base import BaseLLM
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
get_async_httpx_client,
)
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexLLM,
)
@@ -26,8 +31,9 @@ class VertexFineTuningAPI(VertexLLM):
def __init__(self) -> None:
super().__init__()
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
self.async_handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.VERTEX_AI,
params={"timeout": 600.0},
)
def convert_response_created_at(self, response: ResponseTuningJob):

View file

@@ -263,7 +263,11 @@ def get_hf_task_for_model(model: str) -> Tuple[hf_tasks, str]:
return "text-generation-inference", model # default to tgi
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
get_async_httpx_client,
)
def get_hf_task_embedding_for_model(
@@ -301,7 +305,9 @@ async def async_get_hf_task_embedding_for_model(
task_type, hf_tasks_embeddings
)
)
http_client = AsyncHTTPHandler(concurrent_limit=1)
http_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.HUGGINGFACE,
)
model_info = await http_client.get(url=api_base)
@@ -1067,7 +1073,9 @@ class Huggingface(BaseLLM):
)
## COMPLETION CALL
if client is None:
client = AsyncHTTPHandler(concurrent_limit=1)
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.HUGGINGFACE,
)
response = await client.post(api_base, headers=headers, data=json.dumps(data))

View file

@@ -45,7 +45,10 @@ class OpenAILikeEmbeddingHandler(OpenAILikeBase):
response = None
try:
if client is None or isinstance(client, AsyncHTTPHandler):
self.async_client = AsyncHTTPHandler(timeout=timeout) # type: ignore
self.async_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.OPENAI,
params={"timeout": timeout},
)
else:
self.async_client = client

View file

@@ -19,7 +19,10 @@ import litellm.litellm_core_utils
import litellm.litellm_core_utils.litellm_logging
from litellm import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
)
from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
from .base import BaseLLM
@@ -549,7 +552,10 @@ class PredibaseChatCompletion(BaseLLM):
headers={},
) -> ModelResponse:
async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=timeout))
async_handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.PREDIBASE,
params={"timeout": timeout},
)
try:
response = await async_handler.post(
api_base, headers=headers, data=json.dumps(data)

View file

@@ -9,7 +9,10 @@ import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
)
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from .prompt_templates.factory import custom_prompt, prompt_factory
@@ -325,7 +328,7 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
async def async_handle_prediction_response_streaming(
prediction_url, api_token, print_verbose
):
http_handler = AsyncHTTPHandler(concurrent_limit=1)
http_handler = get_async_httpx_client(llm_provider=litellm.LlmProviders.REPLICATE)
previous_output = ""
output_string = ""
@@ -560,7 +563,9 @@ async def async_completion(
logging_obj,
print_verbose,
) -> Union[ModelResponse, CustomStreamWrapper]:
http_handler = AsyncHTTPHandler(concurrent_limit=1)
http_handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.REPLICATE,
)
prediction_url = await async_start_prediction(
version_id,
input_data,

View file

@@ -18,7 +18,10 @@ import litellm
from litellm import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
)
from litellm.types.llms.databricks import GenericStreamingChunk
from litellm.utils import (
Choices,
@@ -479,8 +482,9 @@ class CodestralTextCompletion(BaseLLM):
headers={},
) -> TextCompletionResponse:
async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=timeout), concurrent_limit=1
async_handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.TEXT_COMPLETION_CODESTRAL,
params={"timeout": timeout},
)
try:

View file

@@ -8,7 +8,11 @@ import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
get_async_httpx_client,
)
from litellm.utils import (
Choices,
CustomStreamWrapper,
@@ -50,8 +54,8 @@ class TritonChatCompletion(BaseLLM):
logging_obj: Any,
api_key: Optional[str] = None,
) -> EmbeddingResponse:
async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
async_handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.TRITON, params={"timeout": 600.0}
)
response = await async_handler.post(url=api_base, data=json.dumps(data))
@@ -261,7 +265,9 @@
model_response,
type_of_model,
) -> ModelResponse:
handler = AsyncHTTPHandler()
handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.TRITON, params={"timeout": 600.0}
)
if stream:
return self._ahandle_stream( # type: ignore
handler, api_base, data_for_triton, model, logging_obj

View file

@@ -1026,7 +1026,9 @@ async def make_call(
logging_obj,
):
if client is None:
client = AsyncHTTPHandler() # Create a new client if none provided
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.VERTEX_AI,
)
try:
response = await client.post(api_base, headers=headers, data=data, stream=True)

View file

@@ -7,8 +7,13 @@ from typing import Any, List, Literal, Optional, Union
import httpx
import litellm
from litellm import EmbeddingResponse
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
get_async_httpx_client,
)
from litellm.types.llms.openai import EmbeddingInput
from litellm.types.llms.vertex_ai import (
VertexAIBatchEmbeddingsRequestBody,
@@ -150,7 +155,10 @@ class GoogleBatchEmbeddings(VertexLLM):
else:
_params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
async_handler: AsyncHTTPHandler = AsyncHTTPHandler(**_params) # type: ignore
async_handler: AsyncHTTPHandler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.VERTEX_AI,
params={"timeout": timeout},
)
else:
async_handler = client # type: ignore

View file

@@ -5,7 +5,11 @@ import httpx
from openai.types.image import Image
import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
get_async_httpx_client,
)
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexLLM,
)
@@ -156,7 +160,10 @@ class VertexImageGeneration(VertexLLM):
else:
_params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
self.async_handler = AsyncHTTPHandler(**_params) # type: ignore
self.async_handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.VERTEX_AI,
params={"timeout": timeout},
)
else:
self.async_handler = client # type: ignore

View file

@@ -5,7 +5,11 @@ import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
get_async_httpx_client,
)
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexAIError,
VertexLLM,
@@ -172,7 +176,10 @@ class VertexMultimodalEmbedding(VertexLLM):
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
client = AsyncHTTPHandler(**_params) # type: ignore
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.VERTEX_AI,
params={"timeout": timeout},
)
else:
client = client # type: ignore

View file

@@ -14,6 +14,7 @@ from pydantic import BaseModel
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
from litellm.llms.prompt_templates.factory import (
convert_to_anthropic_image_obj,
convert_to_gemini_tool_call_invoke,
@@ -93,11 +94,15 @@ def _get_client_cache_key(
def _get_client_from_cache(client_cache_key: str):
return litellm.in_memory_llm_clients_cache.get(client_cache_key, None)
return litellm.in_memory_llm_clients_cache.get_cache(client_cache_key)
def _set_client_in_cache(client_cache_key: str, vertex_llm_model: Any):
litellm.in_memory_llm_clients_cache[client_cache_key] = vertex_llm_model
litellm.in_memory_llm_clients_cache.set_cache(
key=client_cache_key,
value=vertex_llm_model,
ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
)
def completion( # noqa: PLR0915

View file

@@ -24,7 +24,10 @@ import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.watsonx import WatsonXAIEndpoint
from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason
@@ -710,10 +713,13 @@ class RequestManager:
if stream:
request_params["stream"] = stream
try:
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(
timeout=request_params.pop("timeout", 600.0), connect=5.0
),
self.async_handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.WATSONX,
params={
"timeout": httpx.Timeout(
timeout=request_params.pop("timeout", 600.0), connect=5.0
),
},
)
if "json" in request_params:
request_params["data"] = json.dumps(request_params.pop("json", {}))

View file

@@ -0,0 +1,88 @@
import ast
import os

ALLOWED_FILES = [
    # local files
    "../../litellm/__init__.py",
    "../../litellm/llms/custom_httpx/http_handler.py",
    # when running on ci/cd
    "./litellm/__init__.py",
    "./litellm/llms/custom_httpx/http_handler.py",
]

warning_msg = "this is a serious violation that can impact latency. Creating Async clients per request can add +500ms per request"


def check_for_async_http_handler(file_path):
    """
    Checks if AsyncHttpHandler is instantiated in the given file.
    Returns a list of line numbers where AsyncHttpHandler is used.
    """
    print("..checking file=", file_path)
    if file_path in ALLOWED_FILES:
        return []
    with open(file_path, "r") as file:
        try:
            tree = ast.parse(file.read())
        except SyntaxError:
            print(f"Warning: Syntax error in file {file_path}")
            return []
    violations = []
    target_names = [
        "AsyncHttpHandler",
        "AsyncHTTPHandler",
        "AsyncClient",
        "httpx.AsyncClient",
    ]  # Add variations here
    for node in ast.walk(tree):
        if isinstance(node, ast.Call):
            if isinstance(node.func, ast.Name) and node.func.id.lower() in [
                name.lower() for name in target_names
            ]:
                raise ValueError(
                    f"found violation in file {file_path} line: {node.lineno}. Please use `get_async_httpx_client` instead. {warning_msg}"
                )
    return violations


def scan_directory_for_async_handler(base_dir):
    """
    Scans all Python files in the directory tree for AsyncHttpHandler usage.
    Returns a dict of files and line numbers where violations were found.
    """
    violations = {}
    for root, _, files in os.walk(base_dir):
        for file in files:
            if file.endswith(".py"):
                file_path = os.path.join(root, file)
                file_violations = check_for_async_http_handler(file_path)
                if file_violations:
                    violations[file_path] = file_violations
    return violations


def test_no_async_http_handler_usage():
    """
    Test to ensure AsyncHttpHandler is not used anywhere in the codebase.
    """
    base_dir = "./litellm"  # Adjust this path as needed
    # base_dir = "../../litellm"  # LOCAL TESTING
    violations = scan_directory_for_async_handler(base_dir)

    if violations:
        violation_messages = []
        for file_path, line_numbers in violations.items():
            violation_messages.append(
                f"Found AsyncHttpHandler in {file_path} at lines: {line_numbers}"
            )
        raise AssertionError(
            "AsyncHttpHandler usage detected:\n" + "\n".join(violation_messages)
        )


if __name__ == "__main__":
    test_no_async_http_handler_usage()
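
The check above is purely name-based: it walks the AST and compares call names case-insensitively against the target list, raising on the first hit outside the allow-listed files. A small, self-contained illustration of that detection step (the sample source string is made up for demonstration):

import ast

sample_source = "client = AsyncHTTPHandler(timeout=600)"  # made-up snippet to scan
tree = ast.parse(sample_source)
for node in ast.walk(tree):
    # same shape of check as check_for_async_http_handler: direct-name calls only
    if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
        if node.func.id.lower() in ("asynchttphandler", "asyncclient"):
            print(f"violation at line {node.lineno}: {node.func.id}")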

View file

@@ -8,6 +8,7 @@ import traceback
from dotenv import load_dotenv
from openai.types.image import Image
from litellm.caching import InMemoryCache
logging.basicConfig(level=logging.DEBUG)
load_dotenv()
@@ -107,7 +108,7 @@ class TestVertexImageGeneration(BaseImageGenTest):
# comment this when running locally
load_vertex_ai_credentials()
litellm.in_memory_llm_clients_cache = {}
litellm.in_memory_llm_clients_cache = InMemoryCache()
return {
"model": "vertex_ai/imagegeneration@006",
"vertex_ai_project": "adroit-crow-413218",
@@ -118,13 +119,13 @@
class TestBedrockSd3(BaseImageGenTest):
def get_base_image_generation_call_args(self) -> dict:
litellm.in_memory_llm_clients_cache = {}
litellm.in_memory_llm_clients_cache = InMemoryCache()
return {"model": "bedrock/stability.sd3-large-v1:0"}
class TestBedrockSd1(BaseImageGenTest):
def get_base_image_generation_call_args(self) -> dict:
litellm.in_memory_llm_clients_cache = {}
litellm.in_memory_llm_clients_cache = InMemoryCache()
return {"model": "bedrock/stability.sd3-large-v1:0"}
@@ -181,7 +182,7 @@ def test_image_generation_azure_dall_e_3():
@pytest.mark.asyncio
async def test_aimage_generation_bedrock_with_optional_params():
try:
litellm.in_memory_llm_clients_cache = {}
litellm.in_memory_llm_clients_cache = InMemoryCache()
response = await litellm.aimage_generation(
prompt="A cute baby sea otter",
model="bedrock/stability.stable-diffusion-xl-v1",

View file

@@ -12,6 +12,7 @@ sys.path.insert(0, os.path.abspath("../.."))
import litellm
from litellm import completion
from litellm.caching import InMemoryCache
litellm.num_retries = 3
litellm.success_callback = ["langfuse"]
@@ -29,15 +30,20 @@ def langfuse_client():
f"{os.environ['LANGFUSE_PUBLIC_KEY']}-{os.environ['LANGFUSE_SECRET_KEY']}"
)
# use a in memory langfuse client for testing, RAM util on ci/cd gets too high when we init many langfuse clients
if _langfuse_cache_key in litellm.in_memory_llm_clients_cache:
langfuse_client = litellm.in_memory_llm_clients_cache[_langfuse_cache_key]
_cached_client = litellm.in_memory_llm_clients_cache.get_cache(_langfuse_cache_key)
if _cached_client:
langfuse_client = _cached_client
else:
langfuse_client = langfuse.Langfuse(
public_key=os.environ["LANGFUSE_PUBLIC_KEY"],
secret_key=os.environ["LANGFUSE_SECRET_KEY"],
host=None,
)
litellm.in_memory_llm_clients_cache[_langfuse_cache_key] = langfuse_client
litellm.in_memory_llm_clients_cache.set_cache(
key=_langfuse_cache_key,
value=langfuse_client,
)
print("NEW LANGFUSE CLIENT")