fix(router.py): only do sync image gen fallbacks for now

The CustomHTTPTransport we use for dall-e-2 only works for sync httpx calls, not async. We will need to spend some time writing the async version.
This commit is contained in:
Krrish Dholakia 2023-12-20 19:10:59 +05:30
parent 350389f501
commit 04bbd0649f
4 changed files with 82 additions and 9 deletions

View file

@@ -1,6 +1,4 @@
import time import time, json, httpx, asyncio
import json
import httpx
class CustomHTTPTransport(httpx.HTTPTransport): class CustomHTTPTransport(httpx.HTTPTransport):
""" """

View file

@@ -261,7 +261,7 @@ class Router:
self.fail_calls[model_name] +=1 self.fail_calls[model_name] +=1
raise e raise e
async def aimage_generation(self, def image_generation(self,
prompt: str, prompt: str,
model: str, model: str,
**kwargs): **kwargs):
@@ -272,13 +272,57 @@
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout) timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model}) kwargs.setdefault("metadata", {}).update({"model_group": model})
response = self.function_with_fallbacks(**kwargs)
return response
except Exception as e:
raise e
def _image_generation(self,
                      prompt: str,
                      model: str,
                      **kwargs):
    """
    Synchronous image-generation worker called via function_with_fallbacks.

    Picks an available deployment for `model`, merges router defaults into
    kwargs, and forwards the call to litellm.image_generation. Tracks
    per-deployment call/success/failure counts.

    Raises: whatever litellm.image_generation (or deployment selection) raises.
    """
    # Initialize before the try so the except-handler check is safe even when
    # the failure happens before a deployment is chosen (was a NameError).
    model_name = None
    try:
        self.print_verbose(f"Inside _image_generation()- model: {model}; kwargs: {kwargs}")
        # Dummy messages payload: deployment routing expects a messages list.
        deployment = self.get_available_deployment(model=model, messages=[{"role": "user", "content": "prompt"}], specific_deployment=kwargs.pop("specific_deployment", None))
        kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
        kwargs["model_info"] = deployment.get("model_info", {})
        data = deployment["litellm_params"].copy()
        model_name = data["model"]
        for k, v in self.default_litellm_params.items():
            if k not in kwargs:  # prioritize model-specific params > default router params
                kwargs[k] = v
            elif k == "metadata":
                kwargs[k].update(v)
        # NOTE(review): client_type="async" in a sync code path looks wrong —
        # litellm.image_generation below is the sync entrypoint; confirm whether
        # this should request the sync client.
        model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
        self.total_calls[model_name] += 1
        response = litellm.image_generation(**{**data, "prompt": prompt, "caching": self.cache_responses, "client": model_client, **kwargs})
        self.success_calls[model_name] += 1
        return response
    except Exception:
        # Only count the failure against a deployment if one was selected.
        if model_name is not None:
            self.fail_calls[model_name] += 1
        raise
async def aimage_generation(self,
                            prompt: str,
                            model: str,
                            **kwargs):
    """
    Async image-generation entrypoint for the router.

    Packs the call into kwargs and delegates to async_function_with_fallbacks,
    which retries/falls back across deployments using _aimage_generation as
    the underlying worker.

    Raises: whatever the fallback machinery raises after exhausting retries.
    """
    try:
        kwargs["model"] = model
        kwargs["prompt"] = prompt
        kwargs["original_function"] = self._aimage_generation
        kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
        # NOTE(review): `timeout` is computed but never used here — presumably
        # consumed downstream via kwargs; confirm or remove.
        timeout = kwargs.get("request_timeout", self.timeout)
        kwargs.setdefault("metadata", {}).update({"model_group": model})
        response = await self.async_function_with_fallbacks(**kwargs)
        return response
    except Exception:
        # Bare re-raise preserves the original traceback.
        raise
async def _image_generation(self, async def _aimage_generation(self,
prompt: str, prompt: str,
model: str, model: str,
**kwargs): **kwargs):
@@ -1055,7 +1099,6 @@ class Router:
api_version=api_version, api_version=api_version,
timeout=timeout, timeout=timeout,
max_retries=max_retries, max_retries=max_retries,
http_client=httpx.Client(transport=CustomHTTPTransport(),) # type: ignore
) )
model["client"] = openai.AzureOpenAI( model["client"] = openai.AzureOpenAI(
api_key=api_key, api_key=api_key,

View file

@@ -99,8 +99,6 @@ def test_embedding(client):
def test_chat_completion(client): def test_chat_completion(client):
try: try:
# Your test data # Your test data
print("initialized proxy")
litellm.set_verbose=False litellm.set_verbose=False
from litellm.proxy.utils import get_instance_fn from litellm.proxy.utils import get_instance_fn
my_custom_logger = get_instance_fn( my_custom_logger = get_instance_fn(

View file

@@ -455,7 +455,41 @@ async def test_aimg_gen_on_router():
traceback.print_exc() traceback.print_exc()
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
asyncio.run(test_aimg_gen_on_router()) # asyncio.run(test_aimg_gen_on_router())
def test_img_gen_on_router():
    """Sync image generation through the Router: OpenAI dall-e-3 with an
    Azure dall-e-3 deployment available under the same model group."""
    litellm.set_verbose = True
    try:
        openai_deployment = {
            "model_name": "dall-e-3",
            "litellm_params": {
                "model": "dall-e-3",
            },
        }
        azure_deployment = {
            "model_name": "dall-e-3",
            "litellm_params": {
                "model": "azure/dall-e-3-test",
                "api_version": "2023-12-01-preview",
                "api_base": os.getenv("AZURE_SWEDEN_API_BASE"),
                "api_key": os.getenv("AZURE_SWEDEN_API_KEY")
            }
        }
        router = Router(model_list=[openai_deployment, azure_deployment])
        img_response = router.image_generation(
            model="dall-e-3",
            prompt="A cute baby sea otter"
        )
        print(img_response)
        # A successful generation returns at least one image entry.
        assert len(img_response.data) > 0
        router.reset()
    except Exception as e:
        traceback.print_exc()
        pytest.fail(f"Error occurred: {e}")
test_img_gen_on_router()
### ###
def test_aembedding_on_router(): def test_aembedding_on_router():