mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
fix(router.py): only do sync image gen fallbacks for now
The customhttptransport we use for dall-e-2 only works for sync httpx calls, not async. Will need to spend some time writing the async version n
This commit is contained in:
parent
350389f501
commit
04bbd0649f
4 changed files with 82 additions and 9 deletions
|
@ -1,6 +1,4 @@
|
||||||
import time
|
import time, json, httpx, asyncio
|
||||||
import json
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
class CustomHTTPTransport(httpx.HTTPTransport):
|
class CustomHTTPTransport(httpx.HTTPTransport):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -261,7 +261,7 @@ class Router:
|
||||||
self.fail_calls[model_name] +=1
|
self.fail_calls[model_name] +=1
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def aimage_generation(self,
|
def image_generation(self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
model: str,
|
model: str,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
@ -272,13 +272,57 @@ class Router:
|
||||||
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||||
timeout = kwargs.get("request_timeout", self.timeout)
|
timeout = kwargs.get("request_timeout", self.timeout)
|
||||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||||
|
response = self.function_with_fallbacks(**kwargs)
|
||||||
|
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def _image_generation(self,
|
||||||
|
prompt: str,
|
||||||
|
model: str,
|
||||||
|
**kwargs):
|
||||||
|
try:
|
||||||
|
self.print_verbose(f"Inside _image_generation()- model: {model}; kwargs: {kwargs}")
|
||||||
|
deployment = self.get_available_deployment(model=model, messages=[{"role": "user", "content": "prompt"}], specific_deployment=kwargs.pop("specific_deployment", None))
|
||||||
|
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
|
||||||
|
kwargs["model_info"] = deployment.get("model_info", {})
|
||||||
|
data = deployment["litellm_params"].copy()
|
||||||
|
model_name = data["model"]
|
||||||
|
for k, v in self.default_litellm_params.items():
|
||||||
|
if k not in kwargs: # prioritize model-specific params > default router params
|
||||||
|
kwargs[k] = v
|
||||||
|
elif k == "metadata":
|
||||||
|
kwargs[k].update(v)
|
||||||
|
|
||||||
|
model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
|
||||||
|
self.total_calls[model_name] +=1
|
||||||
|
response = litellm.image_generation(**{**data, "prompt": prompt, "caching": self.cache_responses, "client": model_client, **kwargs})
|
||||||
|
self.success_calls[model_name] +=1
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
if model_name is not None:
|
||||||
|
self.fail_calls[model_name] +=1
|
||||||
|
raise e
|
||||||
|
|
||||||
|
async def aimage_generation(self,
|
||||||
|
prompt: str,
|
||||||
|
model: str,
|
||||||
|
**kwargs):
|
||||||
|
try:
|
||||||
|
kwargs["model"] = model
|
||||||
|
kwargs["prompt"] = prompt
|
||||||
|
kwargs["original_function"] = self._aimage_generation
|
||||||
|
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
|
||||||
|
timeout = kwargs.get("request_timeout", self.timeout)
|
||||||
|
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||||
response = await self.async_function_with_fallbacks(**kwargs)
|
response = await self.async_function_with_fallbacks(**kwargs)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def _image_generation(self,
|
async def _aimage_generation(self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
model: str,
|
model: str,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
@ -1055,7 +1099,6 @@ class Router:
|
||||||
api_version=api_version,
|
api_version=api_version,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
max_retries=max_retries,
|
max_retries=max_retries,
|
||||||
http_client=httpx.Client(transport=CustomHTTPTransport(),) # type: ignore
|
|
||||||
)
|
)
|
||||||
model["client"] = openai.AzureOpenAI(
|
model["client"] = openai.AzureOpenAI(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
|
|
@ -99,8 +99,6 @@ def test_embedding(client):
|
||||||
def test_chat_completion(client):
|
def test_chat_completion(client):
|
||||||
try:
|
try:
|
||||||
# Your test data
|
# Your test data
|
||||||
|
|
||||||
print("initialized proxy")
|
|
||||||
litellm.set_verbose=False
|
litellm.set_verbose=False
|
||||||
from litellm.proxy.utils import get_instance_fn
|
from litellm.proxy.utils import get_instance_fn
|
||||||
my_custom_logger = get_instance_fn(
|
my_custom_logger = get_instance_fn(
|
||||||
|
|
|
@ -455,7 +455,41 @@ async def test_aimg_gen_on_router():
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
asyncio.run(test_aimg_gen_on_router())
|
# asyncio.run(test_aimg_gen_on_router())
|
||||||
|
|
||||||
|
def test_img_gen_on_router():
|
||||||
|
litellm.set_verbose = True
|
||||||
|
try:
|
||||||
|
model_list = [
|
||||||
|
{
|
||||||
|
"model_name": "dall-e-3",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "dall-e-3",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "dall-e-3",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "azure/dall-e-3-test",
|
||||||
|
"api_version": "2023-12-01-preview",
|
||||||
|
"api_base": os.getenv("AZURE_SWEDEN_API_BASE"),
|
||||||
|
"api_key": os.getenv("AZURE_SWEDEN_API_KEY")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
router = Router(model_list=model_list)
|
||||||
|
response = router.image_generation(
|
||||||
|
model="dall-e-3",
|
||||||
|
prompt="A cute baby sea otter"
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
assert len(response.data) > 0
|
||||||
|
router.reset()
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
test_img_gen_on_router()
|
||||||
###
|
###
|
||||||
|
|
||||||
def test_aembedding_on_router():
|
def test_aembedding_on_router():
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue