(Perf / latency improvement) Reduce pass-through endpoint latency to ~50ms (from ~400ms before this PR) (#6874)

* use correct location for types

* fix types location

* perf improvement for pass through endpoints

* update lint check

* fix import

* fix ensure async clients test

* fix azure.py health check

* fix ollama
Author: Ishaan Jaff, 2024-11-22 18:47:26 -08:00 (committed by GitHub)
Parent: 772b2f9cd2
Commit: d81ae45827
9 changed files with 64 additions and 19 deletions
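
The theme across every file below is the same: stop constructing a fresh `httpx.AsyncClient()` per request and fetch a long-lived client from `get_async_httpx_client` instead. A new `AsyncClient` means a new connection pool, so every request pays TCP and TLS handshake costs from scratch; a shared client reuses pooled connections. A minimal sketch of the pattern (the cache and helper names below are hypothetical, not litellm's actual `get_async_httpx_client` implementation):

```python
# Minimal sketch of the client-reuse pattern this PR applies. Illustrative
# only: get_shared_async_client and _client_cache are hypothetical names,
# not litellm's real get_async_httpx_client implementation.
import httpx

_client_cache: dict[str, httpx.AsyncClient] = {}


def get_shared_async_client(key: str, timeout: float = 600.0) -> httpx.AsyncClient:
    """Return one long-lived AsyncClient per logical consumer.

    Reusing the client reuses its connection pool, so repeated requests
    skip TCP/TLS setup instead of paying it on every call.
    """
    if key not in _client_cache:
        _client_cache[key] = httpx.AsyncClient(timeout=timeout)
    return _client_cache[key]


# Before (per request, slow): client = httpx.AsyncClient()
# After (shared, fast):       client = get_shared_async_client("pass_through_endpoint")
```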

azure.py

@@ -1528,7 +1528,8 @@ class AzureChatCompletion(BaseLLM):
         prompt: Optional[str] = None,
     ) -> dict:
         client_session = (
-            litellm.aclient_session or httpx.AsyncClient()
+            litellm.aclient_session
+            or get_async_httpx_client(llm_provider=litellm.LlmProviders.AZURE).client
         )  # handle dall-e-2 calls
         if "gateway.ai.cloudflare.com" in api_base:

litellm/llms/custom_httpx/http_handler.py

@@ -8,8 +8,7 @@ from httpx import USE_CLIENT_DEFAULT, AsyncHTTPTransport, HTTPTransport
 import litellm
 from litellm.caching import InMemoryCache
-from .types import httpxSpecialProvider
+from litellm.types.llms.custom_http import *

 if TYPE_CHECKING:
     from litellm import LlmProviders
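
The replacement is a wildcard import: `from litellm.types.llms.custom_http import *` re-exports `httpxSpecialProvider` from its new home, so existing references to that name inside http_handler.py keep resolving without further edits.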

litellm/llms/custom_httpx/types.py  (deleted)

@@ -1,11 +0,0 @@
-from enum import Enum
-
-import litellm
-
-
-class httpxSpecialProvider(str, Enum):
-    LoggingCallback = "logging_callback"
-    GuardrailCallback = "guardrail_callback"
-    Caching = "caching"
-    Oauth2Check = "oauth2_check"
-    SecretManager = "secret_manager"

litellm/llms/ollama.py

@@ -14,6 +14,7 @@ import requests  # type: ignore
 import litellm
 from litellm import verbose_logger
+from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.utils import ModelInfo, ProviderField, StreamingChoices

@@ -456,7 +457,10 @@ def ollama_completion_stream(url, data, logging_obj):
 async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
     try:
-        client = httpx.AsyncClient()
+        _async_http_client = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.OLLAMA
+        )
+        client = _async_http_client.client
         async with client.stream(
             url=f"{url}", json=data, method="POST", timeout=litellm.request_timeout
         ) as response:
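
The streaming call itself is unchanged; only where the client comes from differs. For reference, streaming a POST over a long-lived `httpx.AsyncClient` looks like this (a standalone sketch; the URL and payload are placeholders):

```python
# Standalone sketch: streaming over a long-lived AsyncClient, mirroring the
# ollama_async_streaming change above. URL and payload are placeholders.
import asyncio

import httpx

shared_client = httpx.AsyncClient()  # created once, reused across requests


async def stream_generate() -> None:
    async with shared_client.stream(
        method="POST",
        url="http://localhost:11434/api/generate",  # placeholder Ollama endpoint
        json={"model": "llama3", "prompt": "hi"},
        timeout=600,
    ) as response:
        async for chunk in response.aiter_bytes():
            print(chunk)


asyncio.run(stream_generate())
```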

litellm/llms/ollama_chat.py

@@ -13,6 +13,7 @@ from pydantic import BaseModel
 import litellm
 from litellm import verbose_logger
+from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
 from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction
 from litellm.types.llms.openai import ChatCompletionAssistantToolCall
 from litellm.types.utils import StreamingChoices

@@ -445,7 +446,10 @@ async def ollama_async_streaming(
     url, api_key, data, model_response, encoding, logging_obj
 ):
     try:
-        client = httpx.AsyncClient()
+        _async_http_client = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.OLLAMA
+        )
+        client = _async_http_client.client
         _request = {
             "url": f"{url}",
             "json": data,

litellm/proxy/pass_through_endpoints/pass_through_endpoints.py

@@ -22,6 +22,7 @@ import litellm
 from litellm._logging import verbose_proxy_logger
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
 from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     ModelResponseIterator,
 )

@@ -35,6 +36,7 @@ from litellm.proxy._types import (
 )
 from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
 from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.custom_http import httpxSpecialProvider

 from .streaming_handler import PassThroughStreamingHandler
 from .success_handler import PassThroughEndpointLogging

@@ -363,8 +365,11 @@ async def pass_through_request(  # noqa: PLR0915
         data=_parsed_body,
         call_type="pass_through_endpoint",
     )
-    async_client = httpx.AsyncClient(timeout=600)
+    async_client_obj = get_async_httpx_client(
+        llm_provider=httpxSpecialProvider.PassThroughEndpoint,
+        params={"timeout": 600},
+    )
+    async_client = async_client_obj.client

    litellm_call_id = str(uuid.uuid4())
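
This is the hunk behind the headline number: the pass-through route used to build `httpx.AsyncClient(timeout=600)` inside the request handler, and `params={"timeout": 600}` preserves that timeout on the now-shared client. A quick way to see the cost difference yourself (illustrative sketch; example.com is a placeholder and actual timings depend on the network and TLS setup):

```python
# Micro-benchmark sketch of per-request vs shared clients. Illustrative
# only; the target URL is a placeholder and numbers vary by network.
import asyncio
import time

import httpx

URL = "https://example.com"  # placeholder target


async def per_request() -> float:
    start = time.perf_counter()
    async with httpx.AsyncClient() as client:  # new pool + handshake each call
        await client.get(URL)
    return time.perf_counter() - start


async def shared(client: httpx.AsyncClient) -> float:
    start = time.perf_counter()
    await client.get(URL)  # reuses the already-open pooled connection
    return time.perf_counter() - start


async def main() -> None:
    client = httpx.AsyncClient()
    await client.get(URL)  # warm the pool once
    print("per-request:", await per_request())
    print("shared:     ", await shared(client))
    await client.aclose()


asyncio.run(main())
```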

litellm/secret_managers/aws_secret_manager_v2.py

@@ -31,8 +31,8 @@ from litellm.llms.custom_httpx.http_handler import (
     _get_httpx_client,
     get_async_httpx_client,
 )
-from litellm.llms.custom_httpx.types import httpxSpecialProvider
 from litellm.proxy._types import KeyManagementSystem
+from litellm.types.llms.custom_http import httpxSpecialProvider


 class AWSSecretsManagerV2(BaseAWSLLM):

litellm/types/llms/custom_http.py  (new file)

@@ -0,0 +1,20 @@
+from enum import Enum
+
+import litellm
+
+
+class httpxSpecialProvider(str, Enum):
+    """
+    Httpx Clients can be created for these litellm internal providers
+
+    Example:
+    - langsmith logging would need a custom async httpx client
+    - pass through endpoint would need a custom async httpx client
+    """
+
+    LoggingCallback = "logging_callback"
+    GuardrailCallback = "guardrail_callback"
+    Caching = "caching"
+    Oauth2Check = "oauth2_check"
+    SecretManager = "secret_manager"
+    PassThroughEndpoint = "pass_through_endpoint"
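
The enum gives litellm-internal consumers (loggers, guardrails, the pass-through route) the same entry point to the client cache that real LLM providers use. A hedged usage sketch (the health-check URL is a placeholder):

```python
# Hedged usage sketch: internal consumers pass an httpxSpecialProvider
# member where LLM call sites pass litellm.LlmProviders. The URL is a
# placeholder.
import asyncio

from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.types.llms.custom_http import httpxSpecialProvider


async def check_upstream() -> int:
    handler = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.PassThroughEndpoint,
        params={"timeout": 600},
    )
    response = await handler.client.get("https://example.com/health")  # placeholder
    return response.status_code


asyncio.run(check_upstream())
```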

tests/code_coverage_tests/ensure_async_clients_test.py

@@ -5,9 +5,19 @@ ALLOWED_FILES = [
     # local files
     "../../litellm/__init__.py",
     "../../litellm/llms/custom_httpx/http_handler.py",
+    "../../litellm/router_utils/client_initalization_utils.py",
+    "../../litellm/llms/custom_httpx/http_handler.py",
+    "../../litellm/llms/huggingface_restapi.py",
+    "../../litellm/llms/base.py",
+    "../../litellm/llms/custom_httpx/httpx_handler.py",
     # when running on ci/cd
     "./litellm/__init__.py",
     "./litellm/llms/custom_httpx/http_handler.py",
+    "./litellm/router_utils/client_initalization_utils.py",
+    "./litellm/llms/custom_httpx/http_handler.py",
+    "./litellm/llms/huggingface_restapi.py",
+    "./litellm/llms/base.py",
+    "./litellm/llms/custom_httpx/httpx_handler.py",
 ]

 warning_msg = "this is a serious violation that can impact latency. Creating Async clients per request can add +500ms per request"

@@ -43,6 +53,19 @@ def check_for_async_http_handler(file_path):
                 raise ValueError(
                     f"found violation in file {file_path} line: {node.lineno}. Please use `get_async_httpx_client` instead. {warning_msg}"
                 )
+            # Check for attribute calls like httpx.AsyncClient()
+            elif isinstance(node.func, ast.Attribute):
+                full_name = ""
+                current = node.func
+                while isinstance(current, ast.Attribute):
+                    full_name = "." + current.attr + full_name
+                    current = current.value
+                if isinstance(current, ast.Name):
+                    full_name = current.id + full_name
+                if full_name.lower() in [name.lower() for name in target_names]:
+                    raise ValueError(
+                        f"found violation in file {file_path} line: {node.lineno}. Please use `get_async_httpx_client` instead. {warning_msg}"
+                    )
     return violations
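
The new `elif` handles qualified calls: the earlier check catches bare `AsyncClient(...)` names, while this branch walks an `ast.Attribute` chain such as `httpx.AsyncClient` back to its root `ast.Name` and rebuilds the dotted path before comparing. A standalone demo of that walk (`target_names` here is an assumption for the sketch; the test module defines its own list):

```python
# Standalone demo of the attribute-chain walk added above. target_names is
# an assumption for this sketch; the real test defines its own list.
import ast

target_names = ["httpx.AsyncClient"]

source = """
import httpx

def handler():
    client = httpx.AsyncClient()
"""

for node in ast.walk(ast.parse(source)):
    if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
        full_name = ""
        current = node.func
        # unwind: AsyncClient -> ".AsyncClient" -> "httpx.AsyncClient"
        while isinstance(current, ast.Attribute):
            full_name = "." + current.attr + full_name
            current = current.value
        if isinstance(current, ast.Name):
            full_name = current.id + full_name
        if full_name.lower() in [name.lower() for name in target_names]:
            print(f"violation at line {node.lineno}: {full_name}")
# prints: violation at line 5: httpx.AsyncClient
```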