fix(router.py): only do sync image gen fallbacks for now

The CustomHTTPTransport we use for dall-e-2 only works for sync httpx calls, not async. We will need to spend some time writing the async version.

n
This commit is contained in:
Krrish Dholakia 2023-12-20 19:10:59 +05:30
parent 350389f501
commit 04bbd0649f
4 changed files with 82 additions and 9 deletions

View file

@ -261,6 +261,50 @@ class Router:
self.fail_calls[model_name] +=1
raise e
def image_generation(self,
                     prompt: str,
                     model: str,
                     **kwargs):
    """
    Synchronous image-generation entry point that runs through the
    router's fallback wrapper.

    Packs ``prompt`` and ``model`` into ``kwargs``, registers
    ``self._image_generation`` as the function to execute, and hands
    everything to ``self.function_with_fallbacks``.

    Args:
        prompt: Text prompt forwarded as ``kwargs["prompt"]``.
        model: Forwarded as ``kwargs["model"]`` and recorded under
            ``metadata["model_group"]``.
        **kwargs: Extra call parameters. ``num_retries`` falls back to
            ``self.num_retries`` when not supplied; a ``metadata`` dict is
            created when absent and updated in place otherwise.

    Returns:
        Whatever ``self.function_with_fallbacks`` returns.

    Raises:
        Any exception raised by ``self.function_with_fallbacks`` propagates
        unchanged.
    """
    kwargs["model"] = model
    kwargs["prompt"] = prompt
    kwargs["original_function"] = self._image_generation
    kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
    # NOTE: a "request_timeout" entry in kwargs is forwarded untouched; this
    # method does not enforce a timeout itself (the previously computed local
    # `timeout` was never used).
    kwargs.setdefault("metadata", {}).update({"model_group": model})
    return self.function_with_fallbacks(**kwargs)
def _image_generation(self,
                      prompt: str,
                      model: str,
                      **kwargs):
    """
    Pick a deployment for ``model`` and run a synchronous
    ``litellm.image_generation`` call against it.

    Args:
        prompt: Text prompt forwarded to ``litellm.image_generation``.
        model: Name passed to ``self.get_available_deployment``.
        **kwargs: Extra call parameters. ``specific_deployment`` is popped
            and passed to deployment selection; ``metadata`` gains a
            ``"deployment"`` entry; ``model_info`` is set from the chosen
            deployment; entries of ``self.default_litellm_params`` are
            merged in.

    Returns:
        The response returned by ``litellm.image_generation``.

    Raises:
        Any exception raised during deployment selection or the litellm
        call is re-raised. ``self.fail_calls`` is only incremented when the
        failure happens after a deployment has been chosen.
    """
    # Must be bound before the try block: the except handler reads it, and a
    # failure in deployment selection would otherwise surface as an
    # UnboundLocalError that hides the real exception.
    model_name = None
    try:
        self.print_verbose(f"Inside _image_generation()- model: {model}; kwargs: {kwargs}")
        # NOTE(review): the message content is the literal string "prompt",
        # not the `prompt` argument — presumably a placeholder for deployment
        # selection; confirm against get_available_deployment.
        deployment = self.get_available_deployment(
            model=model,
            messages=[{"role": "user", "content": "prompt"}],
            specific_deployment=kwargs.pop("specific_deployment", None),
        )
        kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
        kwargs["model_info"] = deployment.get("model_info", {})
        # Copy so the deployment's own litellm_params dict is not the object
        # being unpacked into the call below.
        data = deployment["litellm_params"].copy()
        model_name = data["model"]
        for k, v in self.default_litellm_params.items():
            if k not in kwargs:  # prioritize model-specific params > default router params
                kwargs[k] = v
            elif k == "metadata":
                kwargs[k].update(v)

        # NOTE(review): client_type="async" is requested from a synchronous
        # code path — confirm _get_client returns a client usable by the
        # sync litellm.image_generation call.
        model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
        self.total_calls[model_name] += 1
        response = litellm.image_generation(
            **{
                **data,
                "prompt": prompt,
                "caching": self.cache_responses,
                "client": model_client,
                **kwargs,
            }
        )
        self.success_calls[model_name] += 1
        return response
    except Exception as e:
        if model_name is not None:
            self.fail_calls[model_name] += 1
        raise e
async def aimage_generation(self,
prompt: str,
model: str,
@ -268,7 +312,7 @@ class Router:
try:
kwargs["model"] = model
kwargs["prompt"] = prompt
kwargs["original_function"] = self._image_generation
kwargs["original_function"] = self._aimage_generation
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
@ -278,7 +322,7 @@ class Router:
except Exception as e:
raise e
async def _image_generation(self,
async def _aimage_generation(self,
prompt: str,
model: str,
**kwargs):
@ -1055,7 +1099,6 @@ class Router:
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
http_client=httpx.Client(transport=CustomHTTPTransport(),) # type: ignore
)
model["client"] = openai.AzureOpenAI(
api_key=api_key,