feat(router.py): support async image generation on router

This commit is contained in:
Krrish Dholakia 2023-12-20 17:24:20 +05:30
parent f355e03515
commit 4040f60feb
2 changed files with 87 additions and 5 deletions

View file

@ -7,7 +7,7 @@
#
# Thank you ! We ❤️ you! - Krrish & Ishaan
import copy
import copy, httpx
from datetime import datetime
from typing import Dict, List, Optional, Union, Literal, Any
import random, threading, time, traceback, uuid
@ -18,6 +18,7 @@ import inspect, concurrent
from openai import AsyncOpenAI
from collections import defaultdict
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
from litellm.llms.custom_httpx.azure_dall_e_2 import CustomHTTPTransport
import copy
class Router:
"""
@ -166,7 +167,7 @@ class Router:
self.print_verbose(f"Intialized router with Routing strategy: {self.routing_strategy}\n")
### COMPLETION + EMBEDDING FUNCTIONS
### COMPLETION, EMBEDDING, IMG GENERATION FUNCTIONS
def completion(self,
model: str,
@ -260,6 +261,50 @@ class Router:
self.fail_calls[model_name] +=1
raise e
async def aimage_generation(self,
prompt: str,
model: str,
**kwargs):
try:
kwargs["model"] = model
kwargs["prompt"] = prompt
kwargs["original_function"] = self._image_generation
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs)
return response
except Exception as e:
raise e
async def _image_generation(self,
prompt: str,
model: str,
**kwargs):
try:
self.print_verbose(f"Inside _image_generation()- model: {model}; kwargs: {kwargs}")
deployment = self.get_available_deployment(model=model, messages=[{"role": "user", "content": "prompt"}], specific_deployment=kwargs.pop("specific_deployment", None))
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy()
model_name = data["model"]
for k, v in self.default_litellm_params.items():
if k not in kwargs: # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
self.total_calls[model_name] +=1
response = await litellm.aimage_generation(**{**data, "prompt": prompt, "caching": self.cache_responses, "client": model_client, **kwargs})
self.success_calls[model_name] +=1
return response
except Exception as e:
if model_name is not None:
self.fail_calls[model_name] +=1
raise e
def text_completion(self,
model: str,
prompt: str,
@ -1009,14 +1054,16 @@ class Router:
azure_endpoint=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries
max_retries=max_retries,
http_client=httpx.Client(transport=CustomHTTPTransport(),) # type: ignore
)
model["client"] = openai.AzureOpenAI(
api_key=api_key,
azure_endpoint=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries
max_retries=max_retries,
http_client=httpx.Client(transport=CustomHTTPTransport(),) # type: ignore
)
# streaming clients should have diff timeouts
model["stream_async_client"] = openai.AsyncAzureOpenAI(
@ -1024,7 +1071,7 @@ class Router:
azure_endpoint=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries
max_retries=max_retries,
)
model["stream_client"] = openai.AzureOpenAI(

View file

@ -423,6 +423,41 @@ def test_function_calling_on_router():
# test_function_calling_on_router()
### IMAGE GENERATION
async def test_aimg_gen_on_router():
litellm.set_verbose = True
try:
model_list = [
{
"model_name": "dall-e-3",
"litellm_params": {
"model": "dall-e-3",
},
},
{
"model_name": "dall-e-3",
"litellm_params": {
"model": "azure/dall-e-3-test",
"api_version": "2023-12-01-preview",
"api_base": os.getenv("AZURE_SWEDEN_API_BASE"),
"api_key": os.getenv("AZURE_SWEDEN_API_KEY")
}
}
]
router = Router(model_list=model_list)
response = await router.aimage_generation(
model="dall-e-3",
prompt="A cute baby sea otter"
)
print(response)
router.reset()
except Exception as e:
traceback.print_exc()
pytest.fail(f"Error occurred: {e}")
asyncio.run(test_aimg_gen_on_router())
###
def test_aembedding_on_router():
litellm.set_verbose = True
try: