feat(router.py): support async image generation on router

This commit is contained in:
Krrish Dholakia 2023-12-20 17:24:20 +05:30
parent f355e03515
commit 4040f60feb
2 changed files with 87 additions and 5 deletions

View file

@ -7,7 +7,7 @@
#
# Thank you ! We ❤️ you! - Krrish & Ishaan
import copy
import copy, httpx
from datetime import datetime
from typing import Dict, List, Optional, Union, Literal, Any
import random, threading, time, traceback, uuid
@ -18,6 +18,7 @@ import inspect, concurrent
from openai import AsyncOpenAI
from collections import defaultdict
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
from litellm.llms.custom_httpx.azure_dall_e_2 import CustomHTTPTransport
import copy
class Router:
"""
@ -166,7 +167,7 @@ class Router:
self.print_verbose(f"Intialized router with Routing strategy: {self.routing_strategy}\n")
### COMPLETION + EMBEDDING FUNCTIONS
### COMPLETION, EMBEDDING, IMG GENERATION FUNCTIONS
def completion(self,
model: str,
@ -260,6 +261,50 @@ class Router:
self.fail_calls[model_name] +=1
raise e
async def aimage_generation(self,
prompt: str,
model: str,
**kwargs):
try:
kwargs["model"] = model
kwargs["prompt"] = prompt
kwargs["original_function"] = self._image_generation
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs)
return response
except Exception as e:
raise e
async def _image_generation(self,
prompt: str,
model: str,
**kwargs):
try:
self.print_verbose(f"Inside _image_generation()- model: {model}; kwargs: {kwargs}")
deployment = self.get_available_deployment(model=model, messages=[{"role": "user", "content": "prompt"}], specific_deployment=kwargs.pop("specific_deployment", None))
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy()
model_name = data["model"]
for k, v in self.default_litellm_params.items():
if k not in kwargs: # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
self.total_calls[model_name] +=1
response = await litellm.aimage_generation(**{**data, "prompt": prompt, "caching": self.cache_responses, "client": model_client, **kwargs})
self.success_calls[model_name] +=1
return response
except Exception as e:
if model_name is not None:
self.fail_calls[model_name] +=1
raise e
def text_completion(self,
model: str,
prompt: str,
@ -1009,14 +1054,16 @@ class Router:
azure_endpoint=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries
max_retries=max_retries,
http_client=httpx.Client(transport=CustomHTTPTransport(),) # type: ignore
)
model["client"] = openai.AzureOpenAI(
api_key=api_key,
azure_endpoint=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries
max_retries=max_retries,
http_client=httpx.Client(transport=CustomHTTPTransport(),) # type: ignore
)
# streaming clients should have diff timeouts
model["stream_async_client"] = openai.AsyncAzureOpenAI(
@ -1024,7 +1071,7 @@ class Router:
azure_endpoint=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries
max_retries=max_retries,
)
model["stream_client"] = openai.AzureOpenAI(