Mirror of https://github.com/BerriAI/litellm.git
commit 4040f60feb (parent f355e03515)
feat(router.py): support async image generation on router

2 changed files with 87 additions and 5 deletions
litellm/router.py

@@ -7,7 +7,7 @@
 #
 # Thank you ! We ❤️ you! - Krrish & Ishaan
 
-import copy
+import copy, httpx
 from datetime import datetime
 from typing import Dict, List, Optional, Union, Literal, Any
 import random, threading, time, traceback, uuid
@@ -18,6 +18,7 @@ import inspect, concurrent
 from openai import AsyncOpenAI
 from collections import defaultdict
 from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
+from litellm.llms.custom_httpx.azure_dall_e_2 import CustomHTTPTransport
 import copy
 class Router:
     """
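Note: judging by its import path, CustomHTTPTransport is an httpx transport dedicated to Azure's DALL-E 2 image API, which at the time used a submit-then-poll flow that the stock OpenAI SDK did not handle; the final two hunks below wire it into the router's cached Azure clients.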
@@ -166,7 +167,7 @@ class Router:
         self.print_verbose(f"Intialized router with Routing strategy: {self.routing_strategy}\n")
 
 
-    ### COMPLETION + EMBEDDING FUNCTIONS
+    ### COMPLETION, EMBEDDING, IMG GENERATION FUNCTIONS
 
     def completion(self,
                    model: str,
@@ -260,6 +261,50 @@ class Router:
                 self.fail_calls[model_name] +=1
             raise e
 
+    async def aimage_generation(self,
+                                prompt: str,
+                                model: str,
+                                **kwargs):
+        try:
+            kwargs["model"] = model
+            kwargs["prompt"] = prompt
+            kwargs["original_function"] = self._image_generation
+            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+            timeout = kwargs.get("request_timeout", self.timeout)
+            kwargs.setdefault("metadata", {}).update({"model_group": model})
+            response = await self.async_function_with_fallbacks(**kwargs)
+
+            return response
+        except Exception as e:
+            raise e
+
+    async def _image_generation(self,
+                                prompt: str,
+                                model: str,
+                                **kwargs):
+        try:
+            self.print_verbose(f"Inside _image_generation()- model: {model}; kwargs: {kwargs}")
+            deployment = self.get_available_deployment(model=model, messages=[{"role": "user", "content": "prompt"}], specific_deployment=kwargs.pop("specific_deployment", None))
+            kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
+            kwargs["model_info"] = deployment.get("model_info", {})
+            data = deployment["litellm_params"].copy()
+            model_name = data["model"]
+            for k, v in self.default_litellm_params.items():
+                if k not in kwargs: # prioritize model-specific params > default router params
+                    kwargs[k] = v
+                elif k == "metadata":
+                    kwargs[k].update(v)
+
+            model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
+            self.total_calls[model_name] +=1
+            response = await litellm.aimage_generation(**{**data, "prompt": prompt, "caching": self.cache_responses, "client": model_client, **kwargs})
+            self.success_calls[model_name] +=1
+            return response
+        except Exception as e:
+            if model_name is not None:
+                self.fail_calls[model_name] +=1
+            raise e
+
     def text_completion(self,
                         model: str,
                         prompt: str,
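As with the existing completion and embedding wrappers, the public aimage_generation() stamps routing metadata onto kwargs and hands off to async_function_with_fallbacks, which applies the router's retry and fallback logic before _image_generation picks a deployment and awaits litellm.aimage_generation. A minimal usage sketch against the new entry point; the deployment values (endpoint, key, api_version) are hypothetical placeholders, not from this commit:

import asyncio

from litellm import Router

# hypothetical deployment config for a "dall-e-2" model group
model_list = [{
    "model_name": "dall-e-2",          # the group name routed on
    "litellm_params": {
        "model": "azure/dall-e-2",
        "api_key": "sk-...",           # placeholder credentials
        "api_base": "https://my-endpoint.openai.azure.com/",
        "api_version": "2023-06-01-preview",
    },
}]

router = Router(model_list=model_list)

async def main():
    # picks an available deployment in the "dall-e-2" group and runs the
    # call through async_function_with_fallbacks (retries + fallbacks)
    response = await router.aimage_generation(
        model="dall-e-2",
        prompt="a sunset over a mountain lake",
    )
    print(response)

asyncio.run(main())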
@@ -1009,14 +1054,16 @@ class Router:
                     azure_endpoint=api_base,
                     api_version=api_version,
                     timeout=timeout,
-                    max_retries=max_retries
+                    max_retries=max_retries,
+                    http_client=httpx.Client(transport=CustomHTTPTransport(),) # type: ignore
                 )
                 model["client"] = openai.AzureOpenAI(
                     api_key=api_key,
                     azure_endpoint=api_base,
                     api_version=api_version,
                     timeout=timeout,
-                    max_retries=max_retries
+                    max_retries=max_retries,
+                    http_client=httpx.Client(transport=CustomHTTPTransport(),) # type: ignore
                 )
                 # streaming clients should have diff timeouts
                 model["stream_async_client"] = openai.AsyncAzureOpenAI(
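Passing http_client=httpx.Client(transport=CustomHTTPTransport()) into the cached Azure clients means every request those clients issue flows through the custom transport, so litellm can intercept or rewrite Azure image-generation calls without patching the OpenAI SDK itself. A generic sketch of the custom-transport pattern, with a hypothetical LoggingTransport standing in for CustomHTTPTransport:

import httpx

class LoggingTransport(httpx.HTTPTransport):
    """Hypothetical transport: observe each request before it goes out."""

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        print(f"outgoing: {request.method} {request.url}")
        # a real transport (e.g. CustomHTTPTransport) could rewrite the
        # request or synthesize an httpx.Response here instead
        return super().handle_request(request)

client = httpx.Client(transport=LoggingTransport())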
@@ -1024,7 +1071,7 @@ class Router:
                     azure_endpoint=api_base,
                     api_version=api_version,
                     timeout=stream_timeout,
-                    max_retries=max_retries
+                    max_retries=max_retries,
                 )
 
                 model["stream_client"] = openai.AzureOpenAI(