forked from phoenix/litellm-mirror
(Perf / latency improvement) improve pass through endpoint latency to ~50ms (before PR was 400ms) (#6874)
* use correct location for types
* fix types location
* perf improvement for pass through endpoints
* update lint check
* fix import
* fix ensure async clients test
* fix azure.py health check
* fix ollama
This commit is contained in: parent 772b2f9cd2, commit d81ae45827
9 changed files with 64 additions and 19 deletions
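Every hunk below makes the same substitution: stop constructing a fresh httpx.AsyncClient() on each call and instead reuse a client obtained from get_async_httpx_client(...). A brand-new client starts with an empty connection pool, so every request pays TCP and TLS setup again, while a shared client keeps its pool warm. The sketch below is illustrative only; the hypothetical get_cached_async_client stands in for litellm's get_async_httpx_client, whose real implementation is not part of this diff.

# Minimal sketch (not litellm's implementation) of the reuse idea: a
# module-level cache hands back the same httpx.AsyncClient per provider,
# so connection pools survive across requests instead of being rebuilt.
from typing import Dict, Optional

import httpx

_client_cache: Dict[str, httpx.AsyncClient] = {}


def get_cached_async_client(provider: str, timeout: Optional[float] = None) -> httpx.AsyncClient:
    """Return a shared AsyncClient for `provider`, creating it only once."""
    cache_key = f"{provider}:{timeout}"
    if cache_key not in _client_cache:
        _client_cache[cache_key] = httpx.AsyncClient(timeout=timeout)
    return _client_cache[cache_key]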
@@ -1528,7 +1528,8 @@ class AzureChatCompletion(BaseLLM):
         prompt: Optional[str] = None,
     ) -> dict:
         client_session = (
-            litellm.aclient_session or httpx.AsyncClient()
+            litellm.aclient_session
+            or get_async_httpx_client(llm_provider=litellm.LlmProviders.AZURE).client
         )  # handle dall-e-2 calls

         if "gateway.ai.cloudflare.com" in api_base:
@@ -8,8 +8,7 @@ from httpx import USE_CLIENT_DEFAULT, AsyncHTTPTransport, HTTPTransport

 import litellm
 from litellm.caching import InMemoryCache
-
-from .types import httpxSpecialProvider
+from litellm.types.llms.custom_http import *

 if TYPE_CHECKING:
     from litellm import LlmProviders
@@ -1,11 +0,0 @@
-from enum import Enum
-
-import litellm
-
-
-class httpxSpecialProvider(str, Enum):
-    LoggingCallback = "logging_callback"
-    GuardrailCallback = "guardrail_callback"
-    Caching = "caching"
-    Oauth2Check = "oauth2_check"
-    SecretManager = "secret_manager"
@@ -14,6 +14,7 @@ import requests  # type: ignore

 import litellm
 from litellm import verbose_logger
+from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.utils import ModelInfo, ProviderField, StreamingChoices

@@ -456,7 +457,10 @@ def ollama_completion_stream(url, data, logging_obj):

 async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
     try:
-        client = httpx.AsyncClient()
+        _async_http_client = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.OLLAMA
+        )
+        client = _async_http_client.client
         async with client.stream(
             url=f"{url}", json=data, method="POST", timeout=litellm.request_timeout
         ) as response:
@@ -13,6 +13,7 @@ from pydantic import BaseModel

 import litellm
 from litellm import verbose_logger
+from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
 from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction
 from litellm.types.llms.openai import ChatCompletionAssistantToolCall
 from litellm.types.utils import StreamingChoices

@@ -445,7 +446,10 @@ async def ollama_async_streaming(
     url, api_key, data, model_response, encoding, logging_obj
 ):
     try:
-        client = httpx.AsyncClient()
+        _async_http_client = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.OLLAMA
+        )
+        client = _async_http_client.client
         _request = {
             "url": f"{url}",
             "json": data,
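In both Ollama files only the construction of `client` changes; the `async with client.stream(...)` call site is untouched, which suggests the handler's `.client` attribute behaves like a regular httpx.AsyncClient. A minimal sketch of that unchanged streaming pattern (the stream_ollama helper is illustrative, not part of litellm):

import httpx


async def stream_ollama(client: httpx.AsyncClient, url: str, data: dict) -> None:
    # The call site is identical whether `client` was built per request or
    # fetched from a shared handler; only the construction cost differs.
    async with client.stream(method="POST", url=url, json=data, timeout=600) as response:
        async for line in response.aiter_lines():
            print(line)  # one chunk of the streamed Ollama response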
@@ -22,6 +22,7 @@ import litellm
 from litellm._logging import verbose_proxy_logger
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
 from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
     ModelResponseIterator,
 )

@@ -35,6 +36,7 @@ from litellm.proxy._types import (
 )
 from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
 from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.custom_http import httpxSpecialProvider

 from .streaming_handler import PassThroughStreamingHandler
 from .success_handler import PassThroughEndpointLogging

@@ -363,8 +365,11 @@ async def pass_through_request(  # noqa: PLR0915
         data=_parsed_body,
         call_type="pass_through_endpoint",
     )

-    async_client = httpx.AsyncClient(timeout=600)
+    async_client_obj = get_async_httpx_client(
+        llm_provider=httpxSpecialProvider.PassThroughEndpoint,
+        params={"timeout": 600},
+    )
+    async_client = async_client_obj.client

     litellm_call_id = str(uuid.uuid4())
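This is the hot path behind the numbers in the title (roughly 400 ms before, about 50 ms after). A rough, illustrative way to observe the kind of overhead being removed, not a litellm benchmark, is to time per-request client construction against a single shared client:

import asyncio
import time

import httpx

URL = "https://example.com"  # hypothetical stand-in endpoint


async def per_request_clients(n: int) -> float:
    start = time.perf_counter()
    for _ in range(n):
        # new connection pool and TLS handshake on every iteration
        async with httpx.AsyncClient(timeout=600) as client:
            await client.get(URL)
    return time.perf_counter() - start


async def shared_client(n: int) -> float:
    start = time.perf_counter()
    async with httpx.AsyncClient(timeout=600) as client:
        for _ in range(n):
            await client.get(URL)  # connections reused from the pool
    return time.perf_counter() - start


print(asyncio.run(per_request_clients(10)), asyncio.run(shared_client(10)))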
@@ -31,8 +31,8 @@ from litellm.llms.custom_httpx.http_handler import (
     _get_httpx_client,
     get_async_httpx_client,
 )
-from litellm.llms.custom_httpx.types import httpxSpecialProvider
 from litellm.proxy._types import KeyManagementSystem
+from litellm.types.llms.custom_http import httpxSpecialProvider


 class AWSSecretsManagerV2(BaseAWSLLM):
litellm/types/llms/custom_http.py (new file, 20 lines)

@@ -0,0 +1,20 @@
+from enum import Enum
+
+import litellm
+
+
+class httpxSpecialProvider(str, Enum):
+    """
+    Httpx Clients can be created for these litellm internal providers
+
+    Example:
+    - langsmith logging would need a custom async httpx client
+    - pass through endpoint would need a custom async httpx client
+    """
+
+    LoggingCallback = "logging_callback"
+    GuardrailCallback = "guardrail_callback"
+    Caching = "caching"
+    Oauth2Check = "oauth2_check"
+    SecretManager = "secret_manager"
+    PassThroughEndpoint = "pass_through_endpoint"
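The enum now lives under litellm/types/llms/custom_http.py and gains a PassThroughEndpoint member, which the pass-through hunk above passes as the llm_provider argument in place of an actual LLM provider:

from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.types.llms.custom_http import httpxSpecialProvider

# Same pattern as the pass_through_request change above: internal features pass
# an httpxSpecialProvider member where model calls pass litellm.LlmProviders.
async_client_obj = get_async_httpx_client(
    llm_provider=httpxSpecialProvider.PassThroughEndpoint,
    params={"timeout": 600},
)
async_client = async_client_obj.client  # shared underlying async client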
@@ -5,9 +5,19 @@ ALLOWED_FILES = [
     # local files
     "../../litellm/__init__.py",
     "../../litellm/llms/custom_httpx/http_handler.py",
+    "../../litellm/router_utils/client_initalization_utils.py",
+    "../../litellm/llms/custom_httpx/http_handler.py",
+    "../../litellm/llms/huggingface_restapi.py",
+    "../../litellm/llms/base.py",
+    "../../litellm/llms/custom_httpx/httpx_handler.py",
     # when running on ci/cd
     "./litellm/__init__.py",
     "./litellm/llms/custom_httpx/http_handler.py",
+    "./litellm/router_utils/client_initalization_utils.py",
+    "./litellm/llms/custom_httpx/http_handler.py",
+    "./litellm/llms/huggingface_restapi.py",
+    "./litellm/llms/base.py",
+    "./litellm/llms/custom_httpx/httpx_handler.py",
 ]

 warning_msg = "this is a serious violation that can impact latency. Creating Async clients per request can add +500ms per request"

@@ -43,6 +53,19 @@ def check_for_async_http_handler(file_path):
                 raise ValueError(
                     f"found violation in file {file_path} line: {node.lineno}. Please use `get_async_httpx_client` instead. {warning_msg}"
                 )
+            # Check for attribute calls like httpx.AsyncClient()
+            elif isinstance(node.func, ast.Attribute):
+                full_name = ""
+                current = node.func
+                while isinstance(current, ast.Attribute):
+                    full_name = "." + current.attr + full_name
+                    current = current.value
+                if isinstance(current, ast.Name):
+                    full_name = current.id + full_name
+                if full_name.lower() in [name.lower() for name in target_names]:
+                    raise ValueError(
+                        f"found violation in file {file_path} line: {node.lineno}. Please use `get_async_httpx_client` instead. {warning_msg}"
+                    )
     return violations
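The new elif branch catches dotted constructor calls such as httpx.AsyncClient(); bare AsyncClient() names are presumably handled by the existing branch above it. A self-contained sketch of the same attribute-walking idea (the target_names list here is illustrative, not copied from the test):

import ast

target_names = ["AsyncClient", "httpx.AsyncClient"]  # illustrative list

source = """
import httpx
client = httpx.AsyncClient()
"""

for node in ast.walk(ast.parse(source)):
    if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
        # Rebuild the dotted name, e.g. "httpx.AsyncClient", from the AST.
        full_name = ""
        current = node.func
        while isinstance(current, ast.Attribute):
            full_name = "." + current.attr + full_name
            current = current.value
        if isinstance(current, ast.Name):
            full_name = current.id + full_name
        if full_name.lower() in [name.lower() for name in target_names]:
            print(f"violation at line {node.lineno}: {full_name}")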