refactor: remove custom transport logic

Not needed after azure dall-e-2 refactor
This commit is contained in:
Krrish Dholakia 2024-07-02 17:35:27 -07:00
parent 637369d2ac
commit cd51f292b6
4 changed files with 71 additions and 241 deletions

View file

@ -55,7 +55,6 @@ from ..types.llms.openai import (
Thread, Thread,
) )
from .base import BaseLLM from .base import BaseLLM
from .custom_httpx.azure_dall_e_2 import AsyncCustomHTTPTransport, CustomHTTPTransport
azure_ad_cache = DualCache() azure_ad_cache = DualCache()
@ -1706,9 +1705,7 @@ class AzureChatCompletion(BaseLLM):
input: Optional[list] = None, input: Optional[list] = None,
prompt: Optional[str] = None, prompt: Optional[str] = None,
) -> dict: ) -> dict:
client_session = litellm.client_session or httpx.Client( client_session = litellm.client_session or httpx.Client()
transport=CustomHTTPTransport(), # handle dall-e-2 calls
)
if "gateway.ai.cloudflare.com" in api_base: if "gateway.ai.cloudflare.com" in api_base:
## build base url - assume api base includes resource name ## build base url - assume api base includes resource name
if not api_base.endswith("/"): if not api_base.endswith("/"):
@ -1781,9 +1778,10 @@ class AzureChatCompletion(BaseLLM):
input: Optional[list] = None, input: Optional[list] = None,
prompt: Optional[str] = None, prompt: Optional[str] = None,
) -> dict: ) -> dict:
client_session = litellm.aclient_session or httpx.AsyncClient( client_session = (
transport=AsyncCustomHTTPTransport(), # handle dall-e-2 calls litellm.aclient_session or httpx.AsyncClient()
) ) # handle dall-e-2 calls
if "gateway.ai.cloudflare.com" in api_base: if "gateway.ai.cloudflare.com" in api_base:
## build base url - assume api base includes resource name ## build base url - assume api base includes resource name
if not api_base.endswith("/"): if not api_base.endswith("/"):

View file

@ -1,24 +1,27 @@
from typing import Optional, Union, Any import json
import types, requests # type: ignore import types # type: ignore
from .base import BaseLLM
from litellm.utils import (
ModelResponse,
Choices,
Message,
CustomStreamWrapper,
convert_to_model_response_object,
TranscriptionResponse,
TextCompletionResponse,
)
from typing import Callable, Optional, BinaryIO
from litellm import OpenAIConfig
import litellm, json
import httpx
from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
from openai import AzureOpenAI, AsyncAzureOpenAI
from ..llms.openai import OpenAITextCompletion, OpenAITextCompletionConfig
import uuid import uuid
from .prompt_templates.factory import prompt_factory, custom_prompt from typing import Any, BinaryIO, Callable, Optional, Union
import httpx
import requests
from openai import AsyncAzureOpenAI, AzureOpenAI
import litellm
from litellm import OpenAIConfig
from litellm.utils import (
Choices,
CustomStreamWrapper,
Message,
ModelResponse,
TextCompletionResponse,
TranscriptionResponse,
convert_to_model_response_object,
)
from ..llms.openai import OpenAITextCompletion, OpenAITextCompletionConfig
from .base import BaseLLM
from .prompt_templates.factory import custom_prompt, prompt_factory
openai_text_completion_config = OpenAITextCompletionConfig() openai_text_completion_config = OpenAITextCompletionConfig()

View file

@ -1,143 +0,0 @@
import asyncio
import json
import time
import httpx
class AsyncCustomHTTPTransport(httpx.AsyncHTTPTransport):
    """Async httpx transport shimming Azure's legacy dall-e-2 image flow.

    Older Azure OpenAI API versions expose image generation as a
    submit-then-poll operation. This transport rewrites matching requests
    to the submit endpoint and polls the operation until it completes, so
    the caller sees one ordinary response.
    Async counterpart of ``CustomHTTPTransport``; see
    https://github.com/openai/openai-python/issues/692
    """

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        # API versions still on the submit/poll flow; dall-e-3 starts from
        # `2023-12-01-preview`, so these cannot conflict with it.
        legacy_versions = (
            "2023-06-01-preview",
            "2023-07-01-preview",
            "2023-08-01-preview",
            "2023-09-01-preview",
            "2023-10-01-preview",
        )
        api_version = request.url.params.get("api-version", "")
        needs_shim = (
            "images/generations" in request.url.path
            and api_version in legacy_versions
        )
        if not needs_shim:
            # Everything else passes straight through.
            return await super().handle_async_request(request)

        # Redirect the call to the async submit endpoint.
        request.url = request.url.copy_with(
            path="/openai/images/generations:submit"
        )
        resp = await super().handle_async_request(request)

        # The submit response names the polling URL; reuse the same request
        # object as a GET against it.
        request.url = httpx.URL(resp.headers["operation-location"])
        request.method = "GET"
        resp = await super().handle_async_request(request)
        await resp.aread()

        poll_limit_secs: int = 120
        began = time.time()
        while resp.json()["status"] not in ["succeeded", "failed"]:
            if time.time() - began > poll_limit_secs:
                timeout_body = {
                    "error": {
                        "code": "Timeout",
                        "message": "Operation polling timed out.",
                    }
                }
                return httpx.Response(
                    status_code=400,
                    headers=resp.headers,
                    content=json.dumps(timeout_body).encode("utf-8"),
                    request=request,
                )
            # Honor the server's retry-after hint, defaulting to 10s.
            await asyncio.sleep(int(resp.headers.get("retry-after") or 10))
            resp = await super().handle_async_request(request)
            await resp.aread()

        if resp.json()["status"] == "failed":
            # Surface the failure payload as a 400 response.
            return httpx.Response(
                status_code=400,
                headers=resp.headers,
                content=json.dumps(resp.json()).encode("utf-8"),
                request=request,
            )

        # Success: return only the `result` payload as the body.
        return httpx.Response(
            status_code=200,
            headers=resp.headers,
            content=json.dumps(resp.json()["result"]).encode("utf-8"),
            request=request,
        )
class CustomHTTPTransport(httpx.HTTPTransport):
    """Sync httpx transport shimming Azure's legacy dall-e-2 image flow.

    Workaround to support dall-e-2 on openai > v1.x: older Azure API
    versions expose image generation as a submit-then-poll operation, so
    matching requests are rewritten to the submit endpoint and polled to
    completion before a single response is returned.
    Refer to this issue for more:
    https://github.com/openai/openai-python/issues/692
    """

    def handle_request(
        self,
        request: httpx.Request,
    ) -> httpx.Response:
        # API versions still on the submit/poll flow; dall-e-3 starts from
        # `2023-12-01-preview`, so these cannot conflict with it.
        legacy_versions = (
            "2023-06-01-preview",
            "2023-07-01-preview",
            "2023-08-01-preview",
            "2023-09-01-preview",
            "2023-10-01-preview",
        )
        api_version = request.url.params.get("api-version", "")
        needs_shim = (
            "images/generations" in request.url.path
            and api_version in legacy_versions
        )
        if not needs_shim:
            # Everything else passes straight through.
            return super().handle_request(request)

        # Redirect the call to the async submit endpoint.
        request.url = request.url.copy_with(
            path="/openai/images/generations:submit"
        )
        resp = super().handle_request(request)

        # The submit response names the polling URL; reuse the same request
        # object as a GET against it.
        request.url = httpx.URL(resp.headers["operation-location"])
        request.method = "GET"
        resp = super().handle_request(request)
        resp.read()

        poll_limit_secs: int = 120
        began = time.time()
        while resp.json()["status"] not in ["succeeded", "failed"]:
            if time.time() - began > poll_limit_secs:
                timeout_body = {
                    "error": {
                        "code": "Timeout",
                        "message": "Operation polling timed out.",
                    }
                }
                return httpx.Response(
                    status_code=400,
                    headers=resp.headers,
                    content=json.dumps(timeout_body).encode("utf-8"),
                    request=request,
                )
            # Honor the server's retry-after hint, defaulting to 10s.
            time.sleep(int(resp.headers.get("retry-after") or 10))
            resp = super().handle_request(request)
            resp.read()

        if resp.json()["status"] == "failed":
            # Surface the failure payload as a 400 response.
            return httpx.Response(
                status_code=400,
                headers=resp.headers,
                content=json.dumps(resp.json()).encode("utf-8"),
                request=request,
            )

        # Success: return only the `result` payload as the body.
        return httpx.Response(
            status_code=200,
            headers=resp.headers,
            content=json.dumps(resp.json()["result"]).encode("utf-8"),
            request=request,
        )

View file

@ -46,10 +46,6 @@ from litellm._logging import verbose_router_logger
from litellm.caching import DualCache, InMemoryCache, RedisCache from litellm.caching import DualCache, InMemoryCache, RedisCache
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc from litellm.llms.azure import get_azure_ad_token_from_oidc
from litellm.llms.custom_httpx.azure_dall_e_2 import (
AsyncCustomHTTPTransport,
CustomHTTPTransport,
)
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler from litellm.router_strategy.lowest_cost import LowestCostLoggingHandler
from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler from litellm.router_strategy.lowest_latency import LowestLatencyLoggingHandler
@ -3452,12 +3448,10 @@ class Router:
timeout=timeout, timeout=timeout,
max_retries=max_retries, max_retries=max_retries,
http_client=httpx.AsyncClient( http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits( limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100 max_connections=1000, max_keepalive_connections=100
), ),
verify=litellm.ssl_verify, verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts, mounts=async_proxy_mounts,
), # type: ignore ), # type: ignore
) )
@ -3477,13 +3471,11 @@ class Router:
timeout=timeout, timeout=timeout,
max_retries=max_retries, max_retries=max_retries,
http_client=httpx.Client( http_client=httpx.Client(
transport=CustomHTTPTransport( mounts=sync_proxy_mounts,
limits=httpx.Limits( limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100 max_connections=1000, max_keepalive_connections=100
), ),
verify=litellm.ssl_verify, verify=litellm.ssl_verify,
),
mounts=sync_proxy_mounts,
), # type: ignore ), # type: ignore
) )
self.cache.set_cache( self.cache.set_cache(
@ -3502,12 +3494,10 @@ class Router:
timeout=stream_timeout, timeout=stream_timeout,
max_retries=max_retries, max_retries=max_retries,
http_client=httpx.AsyncClient( http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits( limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100 max_connections=1000, max_keepalive_connections=100
), ),
verify=litellm.ssl_verify, verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts, mounts=async_proxy_mounts,
), # type: ignore ), # type: ignore
) )
@ -3527,13 +3517,11 @@ class Router:
timeout=stream_timeout, timeout=stream_timeout,
max_retries=max_retries, max_retries=max_retries,
http_client=httpx.Client( http_client=httpx.Client(
transport=CustomHTTPTransport( mounts=sync_proxy_mounts,
limits=httpx.Limits( limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100 max_connections=1000, max_keepalive_connections=100
), ),
verify=litellm.ssl_verify, verify=litellm.ssl_verify,
),
mounts=sync_proxy_mounts,
), # type: ignore ), # type: ignore
) )
self.cache.set_cache( self.cache.set_cache(
@ -3570,12 +3558,10 @@ class Router:
timeout=timeout, timeout=timeout,
max_retries=max_retries, max_retries=max_retries,
http_client=httpx.AsyncClient( http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits( limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100 max_connections=1000, max_keepalive_connections=100
), ),
verify=litellm.ssl_verify, verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts, mounts=async_proxy_mounts,
), # type: ignore ), # type: ignore
) )
@ -3592,13 +3578,11 @@ class Router:
timeout=timeout, timeout=timeout,
max_retries=max_retries, max_retries=max_retries,
http_client=httpx.Client( http_client=httpx.Client(
transport=CustomHTTPTransport( mounts=sync_proxy_mounts,
verify=litellm.ssl_verify,
limits=httpx.Limits( limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100 max_connections=1000, max_keepalive_connections=100
), ),
), verify=litellm.ssl_verify,
mounts=sync_proxy_mounts,
), # type: ignore ), # type: ignore
) )
self.cache.set_cache( self.cache.set_cache(
@ -3615,14 +3599,12 @@ class Router:
timeout=stream_timeout, timeout=stream_timeout,
max_retries=max_retries, max_retries=max_retries,
http_client=httpx.AsyncClient( http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits( limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100 max_connections=1000, max_keepalive_connections=100
), ),
verify=litellm.ssl_verify, verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts, mounts=async_proxy_mounts,
), ), # type: ignore
) )
self.cache.set_cache( self.cache.set_cache(
key=cache_key, key=cache_key,
@ -3637,14 +3619,12 @@ class Router:
timeout=stream_timeout, timeout=stream_timeout,
max_retries=max_retries, max_retries=max_retries,
http_client=httpx.Client( http_client=httpx.Client(
transport=CustomHTTPTransport( mounts=sync_proxy_mounts,
limits=httpx.Limits( limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100 max_connections=1000, max_keepalive_connections=100
), ),
verify=litellm.ssl_verify, verify=litellm.ssl_verify,
), ),
mounts=sync_proxy_mounts,
),
) )
self.cache.set_cache( self.cache.set_cache(
key=cache_key, key=cache_key,
@ -3669,12 +3649,10 @@ class Router:
max_retries=max_retries, max_retries=max_retries,
organization=organization, organization=organization,
http_client=httpx.AsyncClient( http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits( limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100 max_connections=1000, max_keepalive_connections=100
), ),
verify=litellm.ssl_verify, verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts, mounts=async_proxy_mounts,
), # type: ignore ), # type: ignore
) )
@ -3693,13 +3671,11 @@ class Router:
max_retries=max_retries, max_retries=max_retries,
organization=organization, organization=organization,
http_client=httpx.Client( http_client=httpx.Client(
transport=CustomHTTPTransport( mounts=sync_proxy_mounts,
limits=httpx.Limits( limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100 max_connections=1000, max_keepalive_connections=100
), ),
verify=litellm.ssl_verify, verify=litellm.ssl_verify,
),
mounts=sync_proxy_mounts,
), # type: ignore ), # type: ignore
) )
self.cache.set_cache( self.cache.set_cache(
@ -3718,12 +3694,10 @@ class Router:
max_retries=max_retries, max_retries=max_retries,
organization=organization, organization=organization,
http_client=httpx.AsyncClient( http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(
limits=httpx.Limits( limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100 max_connections=1000, max_keepalive_connections=100
), ),
verify=litellm.ssl_verify, verify=litellm.ssl_verify,
),
mounts=async_proxy_mounts, mounts=async_proxy_mounts,
), # type: ignore ), # type: ignore
) )
@ -3743,13 +3717,11 @@ class Router:
max_retries=max_retries, max_retries=max_retries,
organization=organization, organization=organization,
http_client=httpx.Client( http_client=httpx.Client(
transport=CustomHTTPTransport( mounts=sync_proxy_mounts,
limits=httpx.Limits( limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100 max_connections=1000, max_keepalive_connections=100
), ),
verify=litellm.ssl_verify, verify=litellm.ssl_verify,
),
mounts=sync_proxy_mounts,
), # type: ignore ), # type: ignore
) )
self.cache.set_cache( self.cache.set_cache(