diff --git a/litellm/__init__.py b/litellm/__init__.py
index f64923ef29..e0f9b59c9e 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -1017,6 +1017,7 @@ ALL_LITELLM_RESPONSE_TYPES = [
 from .llms.custom_llm import CustomLLM
 from .llms.openai_like.chat.handler import OpenAILikeChatConfig
+from .llms.aiohttp_openai.chat.transformation import AiohttpOpenAIChatConfig
 from .llms.galadriel.chat.transformation import GaladrielChatConfig
 from .llms.github.chat.transformation import GithubChatConfig
 from .llms.empower.chat.transformation import EmpowerChatConfig
diff --git a/litellm/llms/aiohttp_openai/chat/transformation.py b/litellm/llms/aiohttp_openai/chat/transformation.py
new file mode 100644
index 0000000000..bc961026e6
--- /dev/null
+++ b/litellm/llms/aiohttp_openai/chat/transformation.py
@@ -0,0 +1,52 @@
+"""
+*New config* for using aiohttp to make the request to the custom OpenAI-like provider
+
+This leads to 10x higher RPS than httpx
+https://github.com/BerriAI/litellm/issues/6592
+
+New config to ensure we introduce this without causing breaking changes for users
+"""
+
+from typing import TYPE_CHECKING, Any, List, Optional
+
+import httpx
+
+from litellm.llms.openai_like.chat.transformation import OpenAILikeChatConfig
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
+
+
+class AiohttpOpenAIChatConfig(OpenAILikeChatConfig):
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        return {}
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        return ModelResponse(**raw_response.json())
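A minimal sketch of what this config's transform_response amounts to, assuming an OpenAI-style chat-completion payload (the payload values, model name, and messages below are illustrative, not part of the patch):

    import httpx

    from litellm import AiohttpOpenAIChatConfig
    from litellm.types.utils import ModelResponse

    # Canned OpenAI-style response body (illustrative values)
    payload = {
        "id": "chatcmpl-123",
        "object": "chat.completion",
        "created": 1700000000,
        "model": "gpt-4o-mini",
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": "Hello!"},
                "finish_reason": "stop",
            }
        ],
        "usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7},
    }

    config = AiohttpOpenAIChatConfig()
    result = config.transform_response(
        model="gpt-4o-mini",
        raw_response=httpx.Response(status_code=200, json=payload),
        model_response=ModelResponse(),
        logging_obj=None,  # not used by this config's transform_response
        request_data={},
        messages=[{"role": "user", "content": "Hi"}],
        optional_params={},
        litellm_params={},
        encoding=None,
    )
    print(result.choices[0].message.content)  # -> "Hello!"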
diff --git a/litellm/llms/custom_httpx/aiohttp_handler.py b/litellm/llms/custom_httpx/aiohttp_handler.py
new file mode 100644
index 0000000000..ab6872e043
--- /dev/null
+++ b/litellm/llms/custom_httpx/aiohttp_handler.py
@@ -0,0 +1,395 @@
+import json
+from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
+
+import aiohttp
+import httpx  # type: ignore
+
+import litellm
+import litellm.litellm_core_utils
+import litellm.types
+import litellm.types.utils
+from litellm.llms.base_llm.chat.transformation import BaseConfig
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    _get_httpx_client,
+    get_async_httpx_client,
+)
+from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
+
+DEFAULT_TIMEOUT = 600
+
+
+class BaseLLMAIOHTTPHandler:
+
+    async def _make_common_async_call(
+        self,
+        async_httpx_client: AsyncHTTPHandler,
+        provider_config: BaseConfig,
+        api_base: str,
+        headers: dict,
+        data: dict,
+        timeout: Union[float, httpx.Timeout],
+        litellm_params: dict,
+        stream: bool = False,
+    ) -> aiohttp.ClientResponse:
+        """Common implementation across stream + non-stream calls. Meant to ensure consistent error-handling."""
+        max_retry_on_unprocessable_entity_error = (
+            provider_config.max_retry_on_unprocessable_entity_error
+        )
+
+        response: Optional[aiohttp.ClientResponse] = None
+        timeout_obj = aiohttp.ClientTimeout(
+            total=(
+                float(timeout) if isinstance(timeout, (int, float)) else DEFAULT_TIMEOUT
+            )
+        )
+
+        async with aiohttp.ClientSession(timeout=timeout_obj) as session:
+            for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
+                try:
+                    response = await session.post(
+                        url=api_base,
+                        headers=headers,
+                        json=data,
+                    )
+                    if not response.ok:
+                        response.raise_for_status()
+                except aiohttp.ClientResponseError as e:
+                    raise self._handle_error(e=e, provider_config=provider_config)
+                except Exception as e:
+                    raise self._handle_error(e=e, provider_config=provider_config)
+                break
+
+        if response is None:
+            raise provider_config.get_error_class(
+                error_message="No response from the API",
+                status_code=422,
+                headers={},
+            )
+
+        return response
+
+    def _make_common_sync_call(
+        self,
+        sync_httpx_client: HTTPHandler,
+        provider_config: BaseConfig,
+        api_base: str,
+        headers: dict,
+        data: dict,
+        timeout: Union[float, httpx.Timeout],
+        litellm_params: dict,
+        stream: bool = False,
+    ) -> httpx.Response:
+
+        max_retry_on_unprocessable_entity_error = (
+            provider_config.max_retry_on_unprocessable_entity_error
+        )
+
+        response: Optional[httpx.Response] = None
+
+        for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
+            try:
+                response = sync_httpx_client.post(
+                    url=api_base,
+                    headers=headers,
+                    data=json.dumps(data),
+                    timeout=timeout,
+                    stream=stream,
+                )
+            except httpx.HTTPStatusError as e:
+                hit_max_retry = i + 1 == max_retry_on_unprocessable_entity_error
+                should_retry = provider_config.should_retry_llm_api_inside_llm_translation_on_http_error(
+                    e=e, litellm_params=litellm_params
+                )
+                if should_retry and not hit_max_retry:
+                    data = (
+                        provider_config.transform_request_on_unprocessable_entity_error(
+                            e=e, request_data=data
+                        )
+                    )
+                    continue
+                else:
+                    raise self._handle_error(e=e, provider_config=provider_config)
+            except Exception as e:
+                raise self._handle_error(e=e, provider_config=provider_config)
+            break
+
+        if response is None:
+            raise provider_config.get_error_class(
+                error_message="No response from the API",
+                status_code=422,  # don't retry on this error
+                headers={},
+            )
+
+        return response
+
+    async def async_completion(
+        self,
+        custom_llm_provider: str,
+        provider_config: BaseConfig,
+        api_base: str,
+        headers: dict,
+        data: dict,
+        timeout: Union[float, httpx.Timeout],
+        model: str,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        messages: list,
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        client: Optional[AsyncHTTPHandler] = None,
+    ):
+        if client is None:
+            async_httpx_client = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders(custom_llm_provider)
+            )
+        else:
+            async_httpx_client = client
+
+        _response = await self._make_common_async_call(
+            async_httpx_client=async_httpx_client,
+            provider_config=provider_config,
+            api_base=api_base,
+            headers=headers,
+            data=data,
+            timeout=timeout,
+            litellm_params=litellm_params,
+            stream=False,
+        )
+        _json_response = await _response.json()
+
+        # cast to httpx.Response
+        # TODO: use this until we migrate fully to aiohttp
+        response = httpx.Response(
+            status_code=_response.status,
+            headers=_response.headers,
+            json=_json_response,
+        )
+        return provider_config.transform_response(
+            model=model,
+            raw_response=response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            request_data=data,
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            encoding=encoding,
+        )
+
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_llm_provider: str,
+        model_response: ModelResponse,
+        encoding,
+        logging_obj: LiteLLMLoggingObj,
+        optional_params: dict,
+        timeout: Union[float, httpx.Timeout],
+        litellm_params: dict,
+        acompletion: bool,
+        stream: Optional[bool] = False,
+        fake_stream: bool = False,
+        api_key: Optional[str] = None,
+        headers: Optional[dict] = {},
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+    ):
+        provider_config = ProviderConfigManager.get_provider_chat_config(
+            model=model, provider=litellm.LlmProviders(custom_llm_provider)
+        )
+        # get config from model, custom llm provider
+        headers = provider_config.validate_environment(
+            api_key=api_key,
+            headers=headers or {},
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+            api_base=api_base,
+        )
+
+        api_base = provider_config.get_complete_url(
+            api_base=api_base,
+            model=model,
+            optional_params=optional_params,
+            stream=stream,
+        )
+
+        data = provider_config.transform_request(
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            headers=headers,
+        )
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=messages,
+            api_key=api_key,
+            additional_args={
+                "complete_input_dict": data,
+                "api_base": api_base,
+                "headers": headers,
+            },
+        )
+
+        if acompletion is True:
+            return self.async_completion(
+                custom_llm_provider=custom_llm_provider,
+                provider_config=provider_config,
+                api_base=api_base,
+                headers=headers,
+                data=data,
+                timeout=timeout,
+                model=model,
+                model_response=model_response,
+                logging_obj=logging_obj,
+                api_key=api_key,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                encoding=encoding,
+                client=(
+                    client
+                    if client is not None and isinstance(client, AsyncHTTPHandler)
+                    else None
+                ),
+            )
+
+        if stream is True:
+            if fake_stream is not True:
+                data["stream"] = stream
+            completion_stream, headers = self.make_sync_call(
+                provider_config=provider_config,
+                api_base=api_base,
+                headers=headers,  # type: ignore
+                data=data,
+                model=model,
+                messages=messages,
+                logging_obj=logging_obj,
+                timeout=timeout,
+                fake_stream=fake_stream,
+                client=(
+                    client
+                    if client is not None and isinstance(client, HTTPHandler)
+                    else None
+                ),
+                litellm_params=litellm_params,
+            )
+            return CustomStreamWrapper(
+                completion_stream=completion_stream,
+                model=model,
+                custom_llm_provider=custom_llm_provider,
+                logging_obj=logging_obj,
+            )
+
+        if client is None or not isinstance(client, HTTPHandler):
+            sync_httpx_client = _get_httpx_client()
+        else:
+            sync_httpx_client = client
+
+        response = self._make_common_sync_call(
+            sync_httpx_client=sync_httpx_client,
+            provider_config=provider_config,
+            api_base=api_base,
+            headers=headers,
+            data=data,
+            timeout=timeout,
+            litellm_params=litellm_params,
+        )
+        return provider_config.transform_response(
+            model=model,
+            raw_response=response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            request_data=data,
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            encoding=encoding,
+        )
+
+    def make_sync_call(
+        self,
+        provider_config: BaseConfig,
+        api_base: str,
+        headers: dict,
+        data: dict,
+        model: str,
+        messages: list,
+        logging_obj,
+        litellm_params: dict,
+        timeout: Union[float, httpx.Timeout],
+        fake_stream: bool = False,
+        client: Optional[HTTPHandler] = None,
+    ) -> Tuple[Any, dict]:
+        if client is None or not isinstance(client, HTTPHandler):
+            sync_httpx_client = _get_httpx_client()
+        else:
+            sync_httpx_client = client
+        stream = True
+        if fake_stream is True:
+            stream = False
+
+        response = self._make_common_sync_call(
+            sync_httpx_client=sync_httpx_client,
+            provider_config=provider_config,
+            api_base=api_base,
+            headers=headers,
+            data=data,
+            timeout=timeout,
+            litellm_params=litellm_params,
+            stream=stream,
+        )
+
+        if fake_stream is True:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.json(), sync_stream=True
+            )
+        else:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.iter_lines(), sync_stream=True
+            )
+
+        # LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key="",
+            original_response="first stream response received",
+            additional_args={"complete_input_dict": data},
+        )
+
+        return completion_stream, dict(response.headers)
+
+    def _handle_error(self, e: Exception, provider_config: BaseConfig):
+        status_code = getattr(e, "status_code", 500)
+        error_headers = getattr(e, "headers", None)
+        error_text = getattr(e, "text", str(e))
+        error_response = getattr(e, "response", None)
+        if error_headers is None and error_response:
+            error_headers = getattr(error_response, "headers", None)
+        if error_response and hasattr(error_response, "text"):
+            error_text = getattr(error_response, "text", error_text)
+        if error_headers:
+            error_headers = dict(error_headers)
+        else:
+            error_headers = {}
+        raise provider_config.get_error_class(
+            error_message=error_text,
+            status_code=status_code,
+            headers=error_headers,
+        )
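The request path above reduces to an aiohttp POST whose JSON body is re-wrapped in an httpx.Response so downstream transform_response code keeps working. A standalone sketch of that pattern (the endpoint, headers, and payload are placeholders borrowed from the test at the end of this diff):

    import asyncio

    import aiohttp
    import httpx


    async def post_chat(api_base: str, headers: dict, data: dict) -> httpx.Response:
        # aiohttp does the wire call; the JSON body is bridged back into an httpx.Response
        timeout_obj = aiohttp.ClientTimeout(total=600)
        async with aiohttp.ClientSession(timeout=timeout_obj) as session:
            resp = await session.post(url=api_base, headers=headers, json=data)
            if not resp.ok:
                resp.raise_for_status()
            body = await resp.json()
        return httpx.Response(status_code=resp.status, headers=dict(resp.headers), json=body)


    response = asyncio.run(
        post_chat(
            api_base="https://exampleopenaiendpoint-production.up.railway.app/v1/chat/completions",
            headers={"Authorization": "Bearer fake-key"},
            data={"model": "fake-model", "messages": [{"role": "user", "content": "Hi"}]},
        )
    )
    print(response.json())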
diff --git a/litellm/main.py b/litellm/main.py
index 846eb9f2dd..e35b68abf8 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -115,6 +115,7 @@ from .llms.bedrock.embed.embedding import BedrockEmbedding
 from .llms.bedrock.image.image_handler import BedrockImageGeneration
 from .llms.codestral.completion.handler import CodestralTextCompletion
 from .llms.cohere.embed import handler as cohere_embed
+from .llms.custom_httpx.aiohttp_handler import BaseLLMAIOHTTPHandler
 from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
 from .llms.custom_llm import CustomLLM, custom_chat_llm_router
 from .llms.databricks.chat.handler import DatabricksChatCompletion
@@ -217,6 +218,7 @@ openai_like_embedding = OpenAILikeEmbeddingHandler()
 openai_like_chat_completion = OpenAILikeChatHandler()
 databricks_embedding = DatabricksEmbeddingHandler()
 base_llm_http_handler = BaseLLMHTTPHandler()
+base_llm_aiohttp_handler = BaseLLMAIOHTTPHandler()
 sagemaker_chat_completion = SagemakerChatHandler()
 
 ####### COMPLETION ENDPOINTS ################
@@ -474,6 +476,7 @@ async def acompletion(
             or custom_llm_provider == "clarifai"
             or custom_llm_provider == "watsonx"
             or custom_llm_provider == "cloudflare"
+            or custom_llm_provider == "aiohttp_openai"
             or custom_llm_provider in litellm.openai_compatible_providers
             or custom_llm_provider in litellm._custom_providers
         ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
@@ -2851,6 +2854,42 @@ def completion(  # type: ignore  # noqa: PLR0915
             )
             return response
         response = model_response
+    elif custom_llm_provider == "aiohttp_openai":
+        api_base = (
+            api_base  # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
+            or litellm.api_base
+            or get_secret("OPENAI_API_BASE")
+            or "https://api.openai.com/v1"
+        )
+        # set API KEY
+        api_key = (
+            api_key
+            or litellm.api_key  # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there
+            or litellm.openai_key
+            or get_secret("OPENAI_API_KEY")
+        )
+
+        headers = headers or litellm.headers
+
+        if extra_headers is not None:
+            optional_params["extra_headers"] = extra_headers
+        response = base_llm_aiohttp_handler.completion(
+            model=model,
+            messages=messages,
+            headers=headers,
+            model_response=model_response,
+            api_key=api_key,
+            api_base=api_base,
+            acompletion=acompletion,
+            logging_obj=logging,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            timeout=timeout,
+            client=client,
+            custom_llm_provider=custom_llm_provider,
+            encoding=encoding,
+            stream=stream,
+        )
     elif custom_llm_provider == "custom":
         url = litellm.api_base or api_base or ""
         if url is None or url == "":
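With the routing above, the new provider is reachable from the synchronous entrypoint as well. A minimal sketch, assuming any OpenAI-compatible endpoint (the api_base and api_key below are placeholders taken from the test file in this diff):

    import litellm

    response = litellm.completion(
        model="aiohttp_openai/fake-model",  # the "aiohttp_openai/" prefix selects the new handler
        messages=[{"role": "user", "content": "Hello, world!"}],
        api_base="https://exampleopenaiendpoint-production.up.railway.app/v1/chat/completions",
        api_key="fake-key",
    )
    print(response.choices[0].message.content)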
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index 6c375752dd..e80b764cf9 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -1,33 +1,9 @@
 model_list:
-  - model_name: "openai/*"
+  - model_name: "fake-openai-endpoint"
     litellm_params:
-      model: "openai/*"
-      api_key: os.environ/OPENAI_API_KEY
-  - model_name: "azure/*"
-    litellm_params:
-      model: azure/chatgpt-v-2
-      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
-      client_id: os.environ/AZURE_CLIENT_ID
-      azure_username: os.environ/AZURE_USERNAME
-      azure_password: os.environ/AZURE_PASSWORD
-litellm_settings:
-  callbacks: ["datadog"]
-
+      model: openai/any
+      api_base: https://exampleopenaiendpoint-production.up.railway.app
+      api_key: "ishaan"
 general_settings:
-  alerting: ["pagerduty"]
-  alerting_args:
-    failure_threshold: 4 # Number of requests failing in a window
-    failure_threshold_window_seconds: 10 # Window in seconds
-
-    # Requests hanging threshold
-    hanging_threshold_seconds: 0.0000001 # Number of seconds of waiting for a response before a request is considered hanging
-    hanging_threshold_window_seconds: 10 # Window in seconds
-  key_management_system: "hashicorp_vault"
-
-# For /fine_tuning/jobs endpoints
-finetune_settings:
-  - custom_llm_provider: "vertex_ai"
-    vertex_project: "adroit-crow-413218"
-    vertex_location: "us-central1"
-    vertex_credentials: "/Users/ishaanjaffer/Downloads/adroit-crow-413218-a956eef1a2a8.json"
+  master_key: sk-1234
\ No newline at end of file
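A sketch of driving the trimmed proxy config end to end, assuming the proxy is started with "litellm --config litellm/proxy/proxy_config.yaml" and listens on its default port (the base_url and port are assumptions):

    import openai

    # The proxy's master_key doubles as the client API key for this sketch
    client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
    response = client.chat.completions.create(
        model="fake-openai-endpoint",
        messages=[{"role": "user", "content": "Hello, world!"}],
    )
    print(response)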
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index a1ef3e6e56..623400ae45 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -1799,6 +1799,7 @@ class LlmProviders(str, Enum):
     GALADRIEL = "galadriel"
     INFINITY = "infinity"
     DEEPGRAM = "deepgram"
+    AIOHTTP_OPENAI = "aiohttp_openai"
 
 
 class LiteLLMLoggingBaseClass:
diff --git a/litellm/utils.py b/litellm/utils.py
index 1fd79cebae..68a58c6e22 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -6147,6 +6147,8 @@ class ProviderConfigManager:
             or litellm.LlmProviders.LITELLM_PROXY == provider
         ):
             return litellm.OpenAILikeChatConfig()
+        elif litellm.LlmProviders.AIOHTTP_OPENAI == provider:
+            return litellm.AiohttpOpenAIChatConfig()
         elif litellm.LlmProviders.HOSTED_VLLM == provider:
             return litellm.HostedVLLMChatConfig()
         elif litellm.LlmProviders.LM_STUDIO == provider:
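A quick sanity check of the new enum member and config mapping above (a sketch; the model string is arbitrary for this lookup):

    import litellm
    from litellm.utils import ProviderConfigManager

    config = ProviderConfigManager.get_provider_chat_config(
        model="fake-model",
        provider=litellm.LlmProviders.AIOHTTP_OPENAI,
    )
    assert isinstance(config, litellm.AiohttpOpenAIChatConfig)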
diff --git a/tests/llm_translation/test_aiohttp_openai.py b/tests/llm_translation/test_aiohttp_openai.py
new file mode 100644
index 0000000000..16155bacf5
--- /dev/null
+++ b/tests/llm_translation/test_aiohttp_openai.py
@@ -0,0 +1,24 @@
+import json
+import os
+import sys
+from datetime import datetime
+import pytest
+
+sys.path.insert(
+    0, os.path.abspath("../../")
+)  # Adds the parent directory to the system path
+
+import litellm
+
+
+@pytest.mark.asyncio
+async def test_aiohttp_openai():
+    litellm.set_verbose = True
+    response = await litellm.acompletion(
+        model="aiohttp_openai/fake-model",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+        api_base="https://exampleopenaiendpoint-production.up.railway.app/v1/chat/completions",
+        api_key="fake-key",
+    )
+    print(response)
+    print(response.model_dump_json(indent=4))