feat(router.py): support async image generation on router

This commit is contained in:
Krrish Dholakia 2023-12-20 17:24:20 +05:30
parent f355e03515
commit 4040f60feb
2 changed files with 87 additions and 5 deletions

View file

@@ -7,7 +7,7 @@
# #
# Thank you ! We ❤️ you! - Krrish & Ishaan # Thank you ! We ❤️ you! - Krrish & Ishaan
import copy import copy, httpx
from datetime import datetime from datetime import datetime
from typing import Dict, List, Optional, Union, Literal, Any from typing import Dict, List, Optional, Union, Literal, Any
import random, threading, time, traceback, uuid import random, threading, time, traceback, uuid
@@ -18,6 +18,7 @@ import inspect, concurrent
from openai import AsyncOpenAI from openai import AsyncOpenAI
from collections import defaultdict from collections import defaultdict
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
from litellm.llms.custom_httpx.azure_dall_e_2 import CustomHTTPTransport
import copy import copy
class Router: class Router:
""" """
@@ -166,7 +167,7 @@ class Router:
self.print_verbose(f"Intialized router with Routing strategy: {self.routing_strategy}\n") self.print_verbose(f"Intialized router with Routing strategy: {self.routing_strategy}\n")
### COMPLETION + EMBEDDING FUNCTIONS ### COMPLETION, EMBEDDING, IMG GENERATION FUNCTIONS
def completion(self, def completion(self,
model: str, model: str,
@@ -260,6 +261,50 @@ class Router:
self.fail_calls[model_name] +=1 self.fail_calls[model_name] +=1
raise e raise e
async def aimage_generation(self,
                            prompt: str,
                            model: str,
                            **kwargs):
    """
    Async image generation routed through the router's fallback/retry layer.

    Records the call metadata and delegates to
    ``async_function_with_fallbacks``, which applies retries/fallbacks and
    ultimately invokes ``self._image_generation``.

    Args:
        prompt: Text prompt for the image model.
        model: Model-group name to route the request to.
        **kwargs: Extra params forwarded to the underlying call.

    Returns:
        The image-generation response from the selected deployment.
    """
    kwargs["model"] = model
    kwargs["prompt"] = prompt
    kwargs["original_function"] = self._image_generation
    # Fall back to the router-level retry count when the caller doesn't set one.
    kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
    kwargs.setdefault("metadata", {}).update({"model_group": model})
    return await self.async_function_with_fallbacks(**kwargs)
async def _image_generation(self,
                            prompt: str,
                            model: str,
                            **kwargs):
    """
    Pick a deployment for ``model`` and perform the image-generation call.

    Invoked by ``aimage_generation`` via ``async_function_with_fallbacks``;
    updates the per-deployment ``total_calls``/``success_calls``/``fail_calls``
    counters used for routing statistics.

    Args:
        prompt: Text prompt for the image model.
        model: Model-group name to select a deployment from.
        **kwargs: Extra params forwarded to ``litellm.aimage_generation``.

    Returns:
        The response from ``litellm.aimage_generation``.

    Raises:
        Exception: re-raises any failure after recording it in ``fail_calls``.
    """
    # Initialize up front: the original left this unbound, so a failure in
    # get_available_deployment() made the except-branch raise NameError and
    # mask the real error.
    model_name = None
    try:
        self.print_verbose(f"Inside _image_generation()- model: {model}; kwargs: {kwargs}")
        # NOTE(review): content is the literal string "prompt", not the prompt
        # variable — presumably a placeholder since image generation has no
        # chat messages; confirm this is intentional for the routing strategy.
        deployment = self.get_available_deployment(model=model, messages=[{"role": "user", "content": "prompt"}], specific_deployment=kwargs.pop("specific_deployment", None))
        kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
        kwargs["model_info"] = deployment.get("model_info", {})
        data = deployment["litellm_params"].copy()
        model_name = data["model"]
        for k, v in self.default_litellm_params.items():
            if k not in kwargs:  # prioritize model-specific params > default router params
                kwargs[k] = v
            elif k == "metadata":
                kwargs[k].update(v)

        model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
        self.total_calls[model_name] += 1
        response = await litellm.aimage_generation(**{**data, "prompt": prompt, "caching": self.cache_responses, "client": model_client, **kwargs})
        self.success_calls[model_name] += 1
        return response
    except Exception as e:
        if model_name is not None:
            self.fail_calls[model_name] += 1
        raise e
def text_completion(self, def text_completion(self,
model: str, model: str,
prompt: str, prompt: str,
@@ -1009,14 +1054,16 @@ class Router:
azure_endpoint=api_base, azure_endpoint=api_base,
api_version=api_version, api_version=api_version,
timeout=timeout, timeout=timeout,
max_retries=max_retries max_retries=max_retries,
http_client=httpx.Client(transport=CustomHTTPTransport(),) # type: ignore
) )
model["client"] = openai.AzureOpenAI( model["client"] = openai.AzureOpenAI(
api_key=api_key, api_key=api_key,
azure_endpoint=api_base, azure_endpoint=api_base,
api_version=api_version, api_version=api_version,
timeout=timeout, timeout=timeout,
max_retries=max_retries max_retries=max_retries,
http_client=httpx.Client(transport=CustomHTTPTransport(),) # type: ignore
) )
# streaming clients should have diff timeouts # streaming clients should have diff timeouts
model["stream_async_client"] = openai.AsyncAzureOpenAI( model["stream_async_client"] = openai.AsyncAzureOpenAI(
@@ -1024,7 +1071,7 @@ class Router:
azure_endpoint=api_base, azure_endpoint=api_base,
api_version=api_version, api_version=api_version,
timeout=stream_timeout, timeout=stream_timeout,
max_retries=max_retries max_retries=max_retries,
) )
model["stream_client"] = openai.AzureOpenAI( model["stream_client"] = openai.AzureOpenAI(

View file

@@ -423,6 +423,41 @@ def test_function_calling_on_router():
# test_function_calling_on_router() # test_function_calling_on_router()
### IMAGE GENERATION
async def test_aimg_gen_on_router():
    """
    E2E check: route an image-generation request across an OpenAI and an
    Azure dall-e-3 deployment via ``Router.aimage_generation``.

    Requires network access plus AZURE_SWEDEN_API_BASE / AZURE_SWEDEN_API_KEY
    (and OpenAI credentials) in the environment.
    """
    litellm.set_verbose = True
    try:
        model_list = [
            {
                "model_name": "dall-e-3",
                "litellm_params": {
                    "model": "dall-e-3",
                },
            },
            {
                "model_name": "dall-e-3",
                "litellm_params": {
                    "model": "azure/dall-e-3-test",
                    "api_version": "2023-12-01-preview",
                    "api_base": os.getenv("AZURE_SWEDEN_API_BASE"),
                    "api_key": os.getenv("AZURE_SWEDEN_API_KEY"),
                },
            },
        ]
        router = Router(model_list=model_list)
        response = await router.aimage_generation(
            model="dall-e-3",
            prompt="A cute baby sea otter"
        )
        print(response)
        router.reset()
    except Exception as e:
        traceback.print_exc()
        pytest.fail(f"Error occurred: {e}")

# Keep the direct invocation commented out, matching the other tests in this
# file, so importing the module doesn't fire a live network call at collection
# time. Uncomment to run standalone.
# asyncio.run(test_aimg_gen_on_router())
###
def test_aembedding_on_router(): def test_aembedding_on_router():
litellm.set_verbose = True litellm.set_verbose = True
try: try: