(perf) use aiohttp for custom_openai (#7514)

* use aiohttp handler

* BaseLLMAIOHTTPHandler

* use CustomOpenAIChatConfig

* CustomOpenAIChatConfig

* CustomOpenAIChatConfig

* fix linting

* AiohttpOpenAIChatConfig

* fix order

* aiohttp_openai
Author: Ishaan Jaff
Date: 2025-01-02 22:15:17 -08:00 (committed by GitHub)
parent 2d57581307
commit 3a454ee2ce
8 changed files with 519 additions and 29 deletions
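For reference, a minimal standalone usage sketch of the new "aiohttp_openai/" provider prefix, adapted from the test added at the end of this commit (the model name, endpoint, and key are the same placeholder values used there):

import asyncio

import litellm


async def main():
    # "aiohttp_openai/<model>" routes the request through BaseLLMAIOHTTPHandler
    response = await litellm.acompletion(
        model="aiohttp_openai/fake-model",
        messages=[{"role": "user", "content": "Hello, world!"}],
        api_base="https://exampleopenaiendpoint-production.up.railway.app/v1/chat/completions",
        api_key="fake-key",
    )
    print(response.model_dump_json(indent=4))


asyncio.run(main())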


@@ -1017,6 +1017,7 @@ ALL_LITELLM_RESPONSE_TYPES = [
from .llms.custom_llm import CustomLLM
from .llms.openai_like.chat.handler import OpenAILikeChatConfig
from .llms.aiohttp_openai.chat.transformation import AiohttpOpenAIChatConfig
from .llms.galadriel.chat.transformation import GaladrielChatConfig
from .llms.github.chat.transformation import GithubChatConfig
from .llms.empower.chat.transformation import EmpowerChatConfig


@@ -0,0 +1,52 @@
"""
*New config* for using aiohttp to make the request to the custom OpenAI-like provider
This leads to 10x higher RPS than httpx
https://github.com/BerriAI/litellm/issues/6592
New config to ensure we introduce this without causing breaking changes for users
"""
from typing import TYPE_CHECKING, Any, List, Optional
import httpx
from litellm.llms.openai_like.chat.transformation import OpenAILikeChatConfig
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class AiohttpOpenAIChatConfig(OpenAILikeChatConfig):
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
return {}
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
return ModelResponse(**raw_response.json())


@@ -0,0 +1,395 @@
import json
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
import aiohttp
import httpx # type: ignore
import litellm
import litellm.litellm_core_utils
import litellm.types
import litellm.types.utils
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
DEFAULT_TIMEOUT = 600
class BaseLLMAIOHTTPHandler:
async def _make_common_async_call(
self,
async_httpx_client: AsyncHTTPHandler,
provider_config: BaseConfig,
api_base: str,
headers: dict,
data: dict,
timeout: Union[float, httpx.Timeout],
litellm_params: dict,
stream: bool = False,
) -> aiohttp.ClientResponse:
"""Common implementation across stream + non-stream calls. Meant to ensure consistent error-handling."""
max_retry_on_unprocessable_entity_error = (
provider_config.max_retry_on_unprocessable_entity_error
)
response: Optional[aiohttp.ClientResponse] = None
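# aiohttp needs its own ClientTimeout; non-numeric timeouts (e.g. an httpx.Timeout object) fall back to DEFAULT_TIMEOUT seconds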
timeout_obj = aiohttp.ClientTimeout(
total=(
float(timeout) if isinstance(timeout, (int, float)) else DEFAULT_TIMEOUT
)
)
async with aiohttp.ClientSession(timeout=timeout_obj) as session:
for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
try:
response = await session.post(
url=api_base,
headers=headers,
json=data,
)
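# aiohttp does not raise on 4xx/5xx by default; raise_for_status() converts error statuses into ClientResponseError, handled below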
if not response.ok:
response.raise_for_status()
except aiohttp.ClientResponseError as e:
raise self._handle_error(e=e, provider_config=provider_config)
except Exception as e:
raise self._handle_error(e=e, provider_config=provider_config)
break
if response is None:
raise provider_config.get_error_class(
error_message="No response from the API",
status_code=422,
headers={},
)
return response
def _make_common_sync_call(
self,
sync_httpx_client: HTTPHandler,
provider_config: BaseConfig,
api_base: str,
headers: dict,
data: dict,
timeout: Union[float, httpx.Timeout],
litellm_params: dict,
stream: bool = False,
) -> httpx.Response:
max_retry_on_unprocessable_entity_error = (
provider_config.max_retry_on_unprocessable_entity_error
)
response: Optional[httpx.Response] = None
for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
try:
response = sync_httpx_client.post(
url=api_base,
headers=headers,
data=json.dumps(data),
timeout=timeout,
stream=stream,
)
except httpx.HTTPStatusError as e:
hit_max_retry = i + 1 == max_retry_on_unprocessable_entity_error
should_retry = provider_config.should_retry_llm_api_inside_llm_translation_on_http_error(
e=e, litellm_params=litellm_params
)
if should_retry and not hit_max_retry:
data = (
provider_config.transform_request_on_unprocessable_entity_error(
e=e, request_data=data
)
)
continue
else:
raise self._handle_error(e=e, provider_config=provider_config)
except Exception as e:
raise self._handle_error(e=e, provider_config=provider_config)
break
if response is None:
raise provider_config.get_error_class(
error_message="No response from the API",
status_code=422, # don't retry on this error
headers={},
)
return response
async def async_completion(
self,
custom_llm_provider: str,
provider_config: BaseConfig,
api_base: str,
headers: dict,
data: dict,
timeout: Union[float, httpx.Timeout],
model: str,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
messages: list,
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
client: Optional[AsyncHTTPHandler] = None,
):
if client is None:
async_httpx_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders(custom_llm_provider)
)
else:
async_httpx_client = client
_response = await self._make_common_async_call(
async_httpx_client=async_httpx_client,
provider_config=provider_config,
api_base=api_base,
headers=headers,
data=data,
timeout=timeout,
litellm_params=litellm_params,
stream=False,
)
_json_response = await _response.json()
# Wrap the aiohttp response in an httpx.Response so the existing transform_response path keeps working.
# TODO: remove this shim once the handler is fully migrated to aiohttp.
response = httpx.Response(
status_code=_response.status,
headers=_response.headers,
json=_json_response,
)
return provider_config.transform_response(
model=model,
raw_response=response,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
request_data=data,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding,
)
def completion(
self,
model: str,
messages: list,
api_base: str,
custom_llm_provider: str,
model_response: ModelResponse,
encoding,
logging_obj: LiteLLMLoggingObj,
optional_params: dict,
timeout: Union[float, httpx.Timeout],
litellm_params: dict,
acompletion: bool,
stream: Optional[bool] = False,
fake_stream: bool = False,
api_key: Optional[str] = None,
headers: Optional[dict] = {},
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
):
provider_config = ProviderConfigManager.get_provider_chat_config(
model=model, provider=litellm.LlmProviders(custom_llm_provider)
)
# get config from model, custom llm provider
headers = provider_config.validate_environment(
api_key=api_key,
headers=headers or {},
model=model,
messages=messages,
optional_params=optional_params,
api_base=api_base,
)
api_base = provider_config.get_complete_url(
api_base=api_base,
model=model,
optional_params=optional_params,
stream=stream,
)
data = provider_config.transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
)
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"api_base": api_base,
"headers": headers,
},
)
if acompletion is True:
return self.async_completion(
custom_llm_provider=custom_llm_provider,
provider_config=provider_config,
api_base=api_base,
headers=headers,
data=data,
timeout=timeout,
model=model,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding,
client=(
client
if client is not None and isinstance(client, AsyncHTTPHandler)
else None
),
)
if stream is True:
if fake_stream is not True:
data["stream"] = stream
completion_stream, headers = self.make_sync_call(
provider_config=provider_config,
api_base=api_base,
headers=headers, # type: ignore
data=data,
model=model,
messages=messages,
logging_obj=logging_obj,
timeout=timeout,
fake_stream=fake_stream,
client=(
client
if client is not None and isinstance(client, HTTPHandler)
else None
),
litellm_params=litellm_params,
)
return CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider=custom_llm_provider,
logging_obj=logging_obj,
)
if client is None or not isinstance(client, HTTPHandler):
sync_httpx_client = _get_httpx_client()
else:
sync_httpx_client = client
response = self._make_common_sync_call(
sync_httpx_client=sync_httpx_client,
provider_config=provider_config,
api_base=api_base,
headers=headers,
data=data,
timeout=timeout,
litellm_params=litellm_params,
)
return provider_config.transform_response(
model=model,
raw_response=response,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
request_data=data,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding,
)
def make_sync_call(
self,
provider_config: BaseConfig,
api_base: str,
headers: dict,
data: dict,
model: str,
messages: list,
logging_obj,
litellm_params: dict,
timeout: Union[float, httpx.Timeout],
fake_stream: bool = False,
client: Optional[HTTPHandler] = None,
) -> Tuple[Any, dict]:
if client is None or not isinstance(client, HTTPHandler):
sync_httpx_client = _get_httpx_client()
else:
sync_httpx_client = client
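# fake_stream: make a regular (non-streaming) request and replay the JSON body through the model response iterator below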
stream = True
if fake_stream is True:
stream = False
response = self._make_common_sync_call(
sync_httpx_client=sync_httpx_client,
provider_config=provider_config,
api_base=api_base,
headers=headers,
data=data,
timeout=timeout,
litellm_params=litellm_params,
stream=stream,
)
if fake_stream is True:
completion_stream = provider_config.get_model_response_iterator(
streaming_response=response.json(), sync_stream=True
)
else:
completion_stream = provider_config.get_model_response_iterator(
streaming_response=response.iter_lines(), sync_stream=True
)
# LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response="first stream response received",
additional_args={"complete_input_dict": data},
)
return completion_stream, dict(response.headers)
def _handle_error(self, e: Exception, provider_config: BaseConfig):
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_text = getattr(e, "text", str(e))
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
if error_response and hasattr(error_response, "text"):
error_text = getattr(error_response, "text", error_text)
if error_headers:
error_headers = dict(error_headers)
else:
error_headers = {}
raise provider_config.get_error_class(
error_message=error_text,
status_code=status_code,
headers=error_headers,
)


@@ -115,6 +115,7 @@ from .llms.bedrock.embed.embedding import BedrockEmbedding
from .llms.bedrock.image.image_handler import BedrockImageGeneration
from .llms.codestral.completion.handler import CodestralTextCompletion
from .llms.cohere.embed import handler as cohere_embed
from .llms.custom_httpx.aiohttp_handler import BaseLLMAIOHTTPHandler
from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
from .llms.custom_llm import CustomLLM, custom_chat_llm_router
from .llms.databricks.chat.handler import DatabricksChatCompletion
@@ -217,6 +218,7 @@ openai_like_embedding = OpenAILikeEmbeddingHandler()
openai_like_chat_completion = OpenAILikeChatHandler()
databricks_embedding = DatabricksEmbeddingHandler()
base_llm_http_handler = BaseLLMHTTPHandler()
base_llm_aiohttp_handler = BaseLLMAIOHTTPHandler()
sagemaker_chat_completion = SagemakerChatHandler()
####### COMPLETION ENDPOINTS ################
@@ -474,6 +476,7 @@ async def acompletion(
or custom_llm_provider == "clarifai"
or custom_llm_provider == "watsonx"
or custom_llm_provider == "cloudflare"
or custom_llm_provider == "aiohttp_openai"
or custom_llm_provider in litellm.openai_compatible_providers
or custom_llm_provider in litellm._custom_providers
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
@@ -2851,6 +2854,42 @@ def completion( # type: ignore # noqa: PLR0915
)
return response
response = model_response
elif custom_llm_provider == "aiohttp_openai":
api_base = (
api_base # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
or litellm.api_base
or get_secret("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
# set API KEY
api_key = (
api_key
or litellm.api_key # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or get_secret("OPENAI_API_KEY")
)
headers = headers or litellm.headers
if extra_headers is not None:
optional_params["extra_headers"] = extra_headers
response = base_llm_aiohttp_handler.completion(
model=model,
messages=messages,
headers=headers,
model_response=model_response,
api_key=api_key,
api_base=api_base,
acompletion=acompletion,
logging_obj=logging,
optional_params=optional_params,
litellm_params=litellm_params,
timeout=timeout,
client=client,
custom_llm_provider=custom_llm_provider,
encoding=encoding,
stream=stream,
)
elif custom_llm_provider == "custom":
url = litellm.api_base or api_base or ""
if url is None or url == "":


@@ -1,33 +1,9 @@
 model_list:
-  - model_name: "openai/*"
-    litellm_params:
-      model: "openai/*"
-      api_key: os.environ/OPENAI_API_KEY
-  - model_name: "azure/*"
-    litellm_params:
-      model: azure/chatgpt-v-2
-      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
-      client_id: os.environ/AZURE_CLIENT_ID
-      azure_username: os.environ/AZURE_USERNAME
-      azure_password: os.environ/AZURE_PASSWORD
-
-litellm_settings:
-  callbacks: ["datadog"]
+  - model_name: "fake-openai-endpoint"
+    litellm_params:
+      model: openai/any
+      api_base: https://exampleopenaiendpoint-production.up.railway.app
+      api_key: "ishaan"
 
 general_settings:
-  alerting: ["pagerduty"]
-  alerting_args:
-    failure_threshold: 4 # Number of requests failing in a window
-    failure_threshold_window_seconds: 10 # Window in seconds
-    # Requests hanging threshold
-    hanging_threshold_seconds: 0.0000001 # Number of seconds of waiting for a response before a request is considered hanging
-    hanging_threshold_window_seconds: 10 # Window in seconds
-  key_management_system: "hashicorp_vault"
-
-# For /fine_tuning/jobs endpoints
-finetune_settings:
-  - custom_llm_provider: "vertex_ai"
-    vertex_project: "adroit-crow-413218"
-    vertex_location: "us-central1"
-    vertex_credentials: "/Users/ishaanjaffer/Downloads/adroit-crow-413218-a956eef1a2a8.json"
+  master_key: sk-1234


@@ -1799,6 +1799,7 @@ class LlmProviders(str, Enum):
GALADRIEL = "galadriel"
INFINITY = "infinity"
DEEPGRAM = "deepgram"
AIOHTTP_OPENAI = "aiohttp_openai"
class LiteLLMLoggingBaseClass:


@@ -6147,6 +6147,8 @@ class ProviderConfigManager:
or litellm.LlmProviders.LITELLM_PROXY == provider
):
return litellm.OpenAILikeChatConfig()
elif litellm.LlmProviders.AIOHTTP_OPENAI == provider:
return litellm.AiohttpOpenAIChatConfig()
elif litellm.LlmProviders.HOSTED_VLLM == provider:
return litellm.HostedVLLMChatConfig()
elif litellm.LlmProviders.LM_STUDIO == provider:


@@ -0,0 +1,24 @@
import json
import os
import sys
from datetime import datetime
import pytest
sys.path.insert(
0, os.path.abspath("../../")
) # Adds the parent directory to the system path
import litellm
@pytest.mark.asyncio
async def test_aiohttp_openai():
litellm.set_verbose = True
response = await litellm.acompletion(
model="aiohttp_openai/fake-model",
messages=[{"role": "user", "content": "Hello, world!"}],
api_base="https://exampleopenaiendpoint-production.up.railway.app/v1/chat/completions",
api_key="fake-key",
)
print(response)
print(response.model_dump_json(indent=4))