mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)

commit 3a454ee2ce (parent 2d57581307)

(perf) use aiohttp for custom_openai (#7514)

* use aiohttp handler
* BaseLLMAIOHTTPHandler
* use CustomOpenAIChatConfig
* CustomOpenAIChatConfig
* CustomOpenAIChatConfig
* fix linting
* AiohttpOpenAIChatConfig
* fix order
* aiohttp_openai

8 changed files with 519 additions and 29 deletions
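Before the diff, a quick orientation: this PR adds an "aiohttp_openai/" provider prefix that routes chat completions through a new aiohttp-based handler instead of the httpx path. A minimal usage sketch (model name, endpoint, and key below are placeholders, mirroring the test added at the end of this diff):

import asyncio

import litellm


async def main():
    # The "aiohttp_openai/<model>" prefix selects the new aiohttp-backed
    # BaseLLMAIOHTTPHandler added in this commit.
    response = await litellm.acompletion(
        model="aiohttp_openai/my-model",  # placeholder model name
        messages=[{"role": "user", "content": "Hello, world!"}],
        api_base="https://example.com/v1/chat/completions",  # assumed OpenAI-compatible endpoint
        api_key="fake-key",  # placeholder
    )
    print(response)


asyncio.run(main())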
@@ -1017,6 +1017,7 @@ ALL_LITELLM_RESPONSE_TYPES = [
 from .llms.custom_llm import CustomLLM
 from .llms.openai_like.chat.handler import OpenAILikeChatConfig
+from .llms.aiohttp_openai.chat.transformation import AiohttpOpenAIChatConfig
 from .llms.galadriel.chat.transformation import GaladrielChatConfig
 from .llms.github.chat.transformation import GithubChatConfig
 from .llms.empower.chat.transformation import EmpowerChatConfig
litellm/llms/aiohttp_openai/chat/transformation.py (new file, 52 lines)

@@ -0,0 +1,52 @@
"""
*New config* for using aiohttp to make the request to the custom OpenAI-like provider

This leads to 10x higher RPS than httpx
https://github.com/BerriAI/litellm/issues/6592

New config to ensure we introduce this without causing breaking changes for users
"""

from typing import TYPE_CHECKING, Any, List, Optional

import httpx

from litellm.llms.openai_like.chat.transformation import OpenAILikeChatConfig
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any


class AiohttpOpenAIChatConfig(OpenAILikeChatConfig):
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        return {}

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        return ModelResponse(**raw_response.json())
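Note on the config above: validate_environment returns an empty header dict, and transform_response hydrates a ModelResponse directly from the raw JSON, so any provider served via aiohttp_openai is assumed to return an OpenAI-shaped chat-completion body. An illustrative sketch of that shape (every value below is made up, not real output):

# Illustrative OpenAI-style body that ModelResponse(**raw_response.json())
# can hydrate; field values are examples only.
example_payload = {
    "id": "chatcmpl-123",
    "object": "chat.completion",
    "created": 1700000000,
    "model": "my-model",
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "Hello!"},
            "finish_reason": "stop",
        }
    ],
    "usage": {"prompt_tokens": 9, "completion_tokens": 3, "total_tokens": 12},
}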
litellm/llms/custom_httpx/aiohttp_handler.py (new file, 395 lines)

@@ -0,0 +1,395 @@
import json
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union

import aiohttp  # Add this import
import httpx  # type: ignore

import litellm
import litellm.litellm_core_utils
import litellm.types
import litellm.types.utils
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
    _get_httpx_client,
    get_async_httpx_client,
)
from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any

DEFAULT_TIMEOUT = 600


class BaseLLMAIOHTTPHandler:

    async def _make_common_async_call(
        self,
        async_httpx_client: AsyncHTTPHandler,
        provider_config: BaseConfig,
        api_base: str,
        headers: dict,
        data: dict,
        timeout: Union[float, httpx.Timeout],
        litellm_params: dict,
        stream: bool = False,
    ) -> aiohttp.ClientResponse:
        """Common implementation across stream + non-stream calls. Meant to ensure consistent error-handling."""
        max_retry_on_unprocessable_entity_error = (
            provider_config.max_retry_on_unprocessable_entity_error
        )

        response: Optional[aiohttp.ClientResponse] = None
        timeout_obj = aiohttp.ClientTimeout(
            total=(
                float(timeout) if isinstance(timeout, (int, float)) else DEFAULT_TIMEOUT
            )
        )

        async with aiohttp.ClientSession(timeout=timeout_obj) as session:
            for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
                try:
                    response = await session.post(
                        url=api_base,
                        headers=headers,
                        json=data,
                    )
                    if not response.ok:
                        response.raise_for_status()
                except aiohttp.ClientResponseError as e:
                    raise self._handle_error(e=e, provider_config=provider_config)
                except Exception as e:
                    raise self._handle_error(e=e, provider_config=provider_config)
                break

        if response is None:
            raise provider_config.get_error_class(
                error_message="No response from the API",
                status_code=422,
                headers={},
            )

        return response

    def _make_common_sync_call(
        self,
        sync_httpx_client: HTTPHandler,
        provider_config: BaseConfig,
        api_base: str,
        headers: dict,
        data: dict,
        timeout: Union[float, httpx.Timeout],
        litellm_params: dict,
        stream: bool = False,
    ) -> httpx.Response:

        max_retry_on_unprocessable_entity_error = (
            provider_config.max_retry_on_unprocessable_entity_error
        )

        response: Optional[httpx.Response] = None

        for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
            try:
                response = sync_httpx_client.post(
                    url=api_base,
                    headers=headers,
                    data=json.dumps(data),
                    timeout=timeout,
                    stream=stream,
                )
            except httpx.HTTPStatusError as e:
                hit_max_retry = i + 1 == max_retry_on_unprocessable_entity_error
                should_retry = provider_config.should_retry_llm_api_inside_llm_translation_on_http_error(
                    e=e, litellm_params=litellm_params
                )
                if should_retry and not hit_max_retry:
                    data = (
                        provider_config.transform_request_on_unprocessable_entity_error(
                            e=e, request_data=data
                        )
                    )
                    continue
                else:
                    raise self._handle_error(e=e, provider_config=provider_config)
            except Exception as e:
                raise self._handle_error(e=e, provider_config=provider_config)
            break

        if response is None:
            raise provider_config.get_error_class(
                error_message="No response from the API",
                status_code=422,  # don't retry on this error
                headers={},
            )

        return response

    async def async_completion(
        self,
        custom_llm_provider: str,
        provider_config: BaseConfig,
        api_base: str,
        headers: dict,
        data: dict,
        timeout: Union[float, httpx.Timeout],
        model: str,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        messages: list,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        client: Optional[AsyncHTTPHandler] = None,
    ):
        if client is None:
            async_httpx_client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders(custom_llm_provider)
            )
        else:
            async_httpx_client = client

        _response = await self._make_common_async_call(
            async_httpx_client=async_httpx_client,
            provider_config=provider_config,
            api_base=api_base,
            headers=headers,
            data=data,
            timeout=timeout,
            litellm_params=litellm_params,
            stream=False,
        )
        _json_response = await _response.json()

        # cast to httpx.Response
        # Todo - use this until we migrate fully to aiohttp
        response = httpx.Response(
            status_code=_response.status,
            headers=_response.headers,
            json=_json_response,
        )
        return provider_config.transform_response(
            model=model,
            raw_response=response,
            model_response=model_response,
            logging_obj=logging_obj,
            api_key=api_key,
            request_data=data,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
        )

    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_llm_provider: str,
        model_response: ModelResponse,
        encoding,
        logging_obj: LiteLLMLoggingObj,
        optional_params: dict,
        timeout: Union[float, httpx.Timeout],
        litellm_params: dict,
        acompletion: bool,
        stream: Optional[bool] = False,
        fake_stream: bool = False,
        api_key: Optional[str] = None,
        headers: Optional[dict] = {},
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    ):
        provider_config = ProviderConfigManager.get_provider_chat_config(
            model=model, provider=litellm.LlmProviders(custom_llm_provider)
        )
        # get config from model, custom llm provider
        headers = provider_config.validate_environment(
            api_key=api_key,
            headers=headers or {},
            model=model,
            messages=messages,
            optional_params=optional_params,
            api_base=api_base,
        )

        api_base = provider_config.get_complete_url(
            api_base=api_base,
            model=model,
            optional_params=optional_params,
            stream=stream,
        )

        data = provider_config.transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )

        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key=api_key,
            additional_args={
                "complete_input_dict": data,
                "api_base": api_base,
                "headers": headers,
            },
        )

        if acompletion is True:
            return self.async_completion(
                custom_llm_provider=custom_llm_provider,
                provider_config=provider_config,
                api_base=api_base,
                headers=headers,
                data=data,
                timeout=timeout,
                model=model,
                model_response=model_response,
                logging_obj=logging_obj,
                api_key=api_key,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=encoding,
                client=(
                    client
                    if client is not None and isinstance(client, AsyncHTTPHandler)
                    else None
                ),
            )

        if stream is True:
            if fake_stream is not True:
                data["stream"] = stream
            completion_stream, headers = self.make_sync_call(
                provider_config=provider_config,
                api_base=api_base,
                headers=headers,  # type: ignore
                data=data,
                model=model,
                messages=messages,
                logging_obj=logging_obj,
                timeout=timeout,
                fake_stream=fake_stream,
                client=(
                    client
                    if client is not None and isinstance(client, HTTPHandler)
                    else None
                ),
                litellm_params=litellm_params,
            )
            return CustomStreamWrapper(
                completion_stream=completion_stream,
                model=model,
                custom_llm_provider=custom_llm_provider,
                logging_obj=logging_obj,
            )

        if client is None or not isinstance(client, HTTPHandler):
            sync_httpx_client = _get_httpx_client()
        else:
            sync_httpx_client = client

        response = self._make_common_sync_call(
            sync_httpx_client=sync_httpx_client,
            provider_config=provider_config,
            api_base=api_base,
            headers=headers,
            data=data,
            timeout=timeout,
            litellm_params=litellm_params,
        )
        return provider_config.transform_response(
            model=model,
            raw_response=response,
            model_response=model_response,
            logging_obj=logging_obj,
            api_key=api_key,
            request_data=data,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
        )

    def make_sync_call(
        self,
        provider_config: BaseConfig,
        api_base: str,
        headers: dict,
        data: dict,
        model: str,
        messages: list,
        logging_obj,
        litellm_params: dict,
        timeout: Union[float, httpx.Timeout],
        fake_stream: bool = False,
        client: Optional[HTTPHandler] = None,
    ) -> Tuple[Any, dict]:
        if client is None or not isinstance(client, HTTPHandler):
            sync_httpx_client = _get_httpx_client()
        else:
            sync_httpx_client = client
        stream = True
        if fake_stream is True:
            stream = False

        response = self._make_common_sync_call(
            sync_httpx_client=sync_httpx_client,
            provider_config=provider_config,
            api_base=api_base,
            headers=headers,
            data=data,
            timeout=timeout,
            litellm_params=litellm_params,
            stream=stream,
        )

        if fake_stream is True:
            completion_stream = provider_config.get_model_response_iterator(
                streaming_response=response.json(), sync_stream=True
            )
        else:
            completion_stream = provider_config.get_model_response_iterator(
                streaming_response=response.iter_lines(), sync_stream=True
            )

        # LOGGING
        logging_obj.post_call(
            input=messages,
            api_key="",
            original_response="first stream response received",
            additional_args={"complete_input_dict": data},
        )

        return completion_stream, dict(response.headers)

    def _handle_error(self, e: Exception, provider_config: BaseConfig):
        status_code = getattr(e, "status_code", 500)
        error_headers = getattr(e, "headers", None)
        error_text = getattr(e, "text", str(e))
        error_response = getattr(e, "response", None)
        if error_headers is None and error_response:
            error_headers = getattr(error_response, "headers", None)
        if error_response and hasattr(error_response, "text"):
            error_text = getattr(error_response, "text", error_text)
        if error_headers:
            error_headers = dict(error_headers)
        else:
            error_headers = {}
        raise provider_config.get_error_class(
            error_message=error_text,
            status_code=status_code,
            headers=error_headers,
        )
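One detail of _make_common_async_call worth flagging: only plain numeric timeouts reach aiohttp; an httpx.Timeout object falls back to DEFAULT_TIMEOUT (600 seconds). A standalone sketch of that conversion (the helper name is ours, not part of the handler):

from typing import Union

import aiohttp
import httpx

DEFAULT_TIMEOUT = 600


def to_client_timeout(timeout: Union[float, httpx.Timeout]) -> aiohttp.ClientTimeout:
    # Mirrors the handler's logic: numeric timeouts are forwarded;
    # anything else (e.g. an httpx.Timeout) falls back to the default.
    total = float(timeout) if isinstance(timeout, (int, float)) else DEFAULT_TIMEOUT
    return aiohttp.ClientTimeout(total=total)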
@@ -115,6 +115,7 @@ from .llms.bedrock.embed.embedding import BedrockEmbedding
 from .llms.bedrock.image.image_handler import BedrockImageGeneration
 from .llms.codestral.completion.handler import CodestralTextCompletion
 from .llms.cohere.embed import handler as cohere_embed
+from .llms.custom_httpx.aiohttp_handler import BaseLLMAIOHTTPHandler
 from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
 from .llms.custom_llm import CustomLLM, custom_chat_llm_router
 from .llms.databricks.chat.handler import DatabricksChatCompletion
@@ -217,6 +218,7 @@ openai_like_embedding = OpenAILikeEmbeddingHandler()
 openai_like_chat_completion = OpenAILikeChatHandler()
 databricks_embedding = DatabricksEmbeddingHandler()
 base_llm_http_handler = BaseLLMHTTPHandler()
+base_llm_aiohttp_handler = BaseLLMAIOHTTPHandler()
 sagemaker_chat_completion = SagemakerChatHandler()
 ####### COMPLETION ENDPOINTS ################
@@ -474,6 +476,7 @@ async def acompletion(
         or custom_llm_provider == "clarifai"
         or custom_llm_provider == "watsonx"
         or custom_llm_provider == "cloudflare"
+        or custom_llm_provider == "aiohttp_openai"
         or custom_llm_provider in litellm.openai_compatible_providers
         or custom_llm_provider in litellm._custom_providers
     ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
@@ -2851,6 +2854,42 @@ def completion(  # type: ignore # noqa: PLR0915
                 )
                 return response
             response = model_response
+        elif custom_llm_provider == "aiohttp_openai":
+            api_base = (
+                api_base  # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
+                or litellm.api_base
+                or get_secret("OPENAI_API_BASE")
+                or "https://api.openai.com/v1"
+            )
+            # set API KEY
+            api_key = (
+                api_key
+                or litellm.api_key  # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there
+                or litellm.openai_key
+                or get_secret("OPENAI_API_KEY")
+            )
+
+            headers = headers or litellm.headers
+
+            if extra_headers is not None:
+                optional_params["extra_headers"] = extra_headers
+            response = base_llm_aiohttp_handler.completion(
+                model=model,
+                messages=messages,
+                headers=headers,
+                model_response=model_response,
+                api_key=api_key,
+                api_base=api_base,
+                acompletion=acompletion,
+                logging_obj=logging,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                timeout=timeout,
+                client=client,
+                custom_llm_provider=custom_llm_provider,
+                encoding=encoding,
+                stream=stream,
+            )
         elif custom_llm_provider == "custom":
             url = litellm.api_base or api_base or ""
             if url is None or url == "":
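The api_base/api_key resolution in the new branch follows LiteLLM's usual fallback chain. Sketched in isolation, with os.environ.get standing in for litellm's get_secret lookup (the helper name here is illustrative):

import os


def resolve_api_base(api_base=None, litellm_api_base=None):
    # Resolution order mirroring the diff: explicit argument ->
    # module-level litellm.api_base -> OPENAI_API_BASE env var -> OpenAI default.
    return (
        api_base
        or litellm_api_base
        or os.environ.get("OPENAI_API_BASE")
        or "https://api.openai.com/v1"
    )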
@@ -1,33 +1,9 @@
 model_list:
-  - model_name: "openai/*"
+  - model_name: "fake-openai-endpoint"
     litellm_params:
-      model: "openai/*"
-      api_key: os.environ/OPENAI_API_KEY
-  - model_name: "azure/*"
-    litellm_params:
-      model: azure/chatgpt-v-2
-      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
-      client_id: os.environ/AZURE_CLIENT_ID
-      azure_username: os.environ/AZURE_USERNAME
-      azure_password: os.environ/AZURE_PASSWORD
-
-litellm_settings:
-  callbacks: ["datadog"]
+      model: openai/any
+      api_base: https://exampleopenaiendpoint-production.up.railway.app
+      api_key: "ishaan"
 
 general_settings:
-  alerting: ["pagerduty"]
-  alerting_args:
-    failure_threshold: 4 # Number of requests failing in a window
-    failure_threshold_window_seconds: 10 # Window in seconds
-
-    # Requests hanging threshold
-    hanging_threshold_seconds: 0.0000001 # Number of seconds of waiting for a response before a request is considered hanging
-    hanging_threshold_window_seconds: 10 # Window in seconds
-  key_management_system: "hashicorp_vault"
-
-# For /fine_tuning/jobs endpoints
-finetune_settings:
-  - custom_llm_provider: "vertex_ai"
-    vertex_project: "adroit-crow-413218"
-    vertex_location: "us-central1"
-    vertex_credentials: "/Users/ishaanjaffer/Downloads/adroit-crow-413218-a956eef1a2a8.json"
+  master_key: sk-1234
@@ -1799,6 +1799,7 @@ class LlmProviders(str, Enum):
     GALADRIEL = "galadriel"
     INFINITY = "infinity"
     DEEPGRAM = "deepgram"
+    AIOHTTP_OPENAI = "aiohttp_openai"


 class LiteLLMLoggingBaseClass:
@@ -6147,6 +6147,8 @@ class ProviderConfigManager:
             or litellm.LlmProviders.LITELLM_PROXY == provider
         ):
             return litellm.OpenAILikeChatConfig()
+        elif litellm.LlmProviders.AIOHTTP_OPENAI == provider:
+            return litellm.AiohttpOpenAIChatConfig()
         elif litellm.LlmProviders.HOSTED_VLLM == provider:
             return litellm.HostedVLLMChatConfig()
         elif litellm.LlmProviders.LM_STUDIO == provider:
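With the enum value and the dispatch branch above in place, the provider-to-config mapping can be exercised directly; a small sketch assuming the imports this diff adds to litellm's top level:

import litellm
from litellm.utils import ProviderConfigManager

# Per the new elif branch, AIOHTTP_OPENAI resolves to the aiohttp config class.
config = ProviderConfigManager.get_provider_chat_config(
    model="my-model",  # placeholder; dispatch here keys off the provider
    provider=litellm.LlmProviders.AIOHTTP_OPENAI,
)
assert isinstance(config, litellm.AiohttpOpenAIChatConfig)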
tests/llm_translation/test_aiohttp_openai.py (new file, 24 lines)

@@ -0,0 +1,24 @@
import json
import os
import sys
from datetime import datetime

import pytest

sys.path.insert(
    0, os.path.abspath("../../")
)  # Adds the parent directory to the system path

import litellm


@pytest.mark.asyncio
async def test_aiohttp_openai():
    litellm.set_verbose = True
    response = await litellm.acompletion(
        model="aiohttp_openai/fake-model",
        messages=[{"role": "user", "content": "Hello, world!"}],
        api_base="https://exampleopenaiendpoint-production.up.railway.app/v1/chat/completions",
        api_key="fake-key",
    )
    print(response)
    print(response.model_dump_json(indent=4))
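Worth noting about the test: it is a smoke test of the async path against a mock OpenAI-compatible endpoint; it prints the response rather than asserting on specific fields.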