[BETA] Add OpenAI /images/variations + Topaz API support (#7700)

* feat(main.py): initial commit for `/image/variations` endpoint support (usage sketch after this commit list)

* refactor(base_llm/): introduce new base config for image variation endpoints

* refactor(openai/image_variations/transformation.py): implement openai image variation transformation handler

* fix: test

* feat(openai/): working OpenAI `/image/variations` endpoint calls via SDK

* feat(topaz/): topaz sync image variation call support

Addresses https://github.com/BerriAI/litellm/issues/7593

* fix(topaz/transformation.py): fix linting errors

* fix(openai/image_variations/handler.py): fix passing json data

* fix(main.py): support async image variation route - `aimage_variation`

* fix(test_get_model_info.py): fix test

* fix: cleanup unused imports

* feat(openai/): add async `/image/variations` endpoint support

* feat(topaz/): support async `/image/variations` calls

* fix: test

* fix(utils.py): fix get_model_info_helper for no model info w/ provider config

Handles the case where model info is not known but a provider config exists.

* test(test_router_fallbacks.py): mark flaky test

* fix: fix unused imports

* test: bump otel load test perf threshold - accounts for current load tests hitting the same server
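
Taken together, these commits add a sync `image_variation` route and an async `aimage_variation` twin to `main.py`, mirroring OpenAI's `/images/variations` API. A minimal sketch of a sync call, assuming the route is exported at the package root like litellm's other `main.py` routes and that `OPENAI_API_KEY` is set:

```python
import litellm

# Minimal usage sketch based on the image_variation signature in the diff
# below; assumes OPENAI_API_KEY is set in the environment.
with open("sample.png", "rb") as f:
    resp = litellm.image_variation(
        image=f,
        model="dall-e-2",       # documented default
        n=2,                    # number of variations to generate
        response_format="url",  # or "b64_json"
    )

for img in resp.data:
    print(img.url)
```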
Krish Dholakia, 2025-01-11 23:27:46 -08:00, committed via GitHub
parent d21e4dedbd · commit 8ee79dd5d9
25 changed files with 1254 additions and 20 deletions

main.py

@@ -73,6 +73,7 @@ from litellm.secret_managers.main import get_secret_str
from litellm.types.router import GenericLiteLLMParams
from litellm.utils import (
    CustomStreamWrapper,
    ProviderConfigManager,
    Usage,
    async_completion_with_fallbacks,
    async_mock_completion_streaming_obj,
@@ -131,6 +132,7 @@ from .llms.nlp_cloud.chat.handler import completion as nlp_cloud_chat_completion
from .llms.ollama.completion import handler as ollama
from .llms.oobabooga.chat import oobabooga
from .llms.openai.completion.handler import OpenAITextCompletion
from .llms.openai.image_variations.handler import OpenAIImageVariationsHandler
from .llms.openai.openai import OpenAIChatCompletion
from .llms.openai.transcriptions.handler import OpenAIAudioTranscription
from .llms.openai_like.chat.handler import OpenAILikeChatHandler
@@ -167,11 +169,13 @@ from .types.llms.openai import (
    HttpxBinaryResponseContent,
)
from .types.utils import (
    LITELLM_IMAGE_VARIATION_PROVIDERS,
    AdapterCompletionStreamWrapper,
    ChatCompletionMessageToolCall,
    CompletionTokensDetails,
    FileTypes,
    HiddenParams,
    LlmProviders,
    PromptTokensDetails,
    all_litellm_params,
)
@@ -193,6 +197,7 @@ from litellm.utils import (
openai_chat_completions = OpenAIChatCompletion()
openai_text_completions = OpenAITextCompletion()
openai_audio_transcriptions = OpenAIAudioTranscription()
openai_image_variations = OpenAIImageVariationsHandler()
databricks_chat_completions = DatabricksChatCompletion()
groq_chat_completions = GroqChatCompletion()
azure_ai_embedding = AzureAIEmbedding()
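
The routing code in the next hunk keys off `LITELLM_IMAGE_VARIATION_PROVIDERS`, imported above from `.types.utils`. Its definition is not part of this hunk; judging from its usage below (constructed from an `LlmProviders` member, compared against `.OPENAI` and `.TOPAZ`, raising `ValueError` otherwise), it is presumably an enum along these lines:

```python
from enum import Enum

from litellm.types.utils import LlmProviders

# Hypothetical reconstruction from usage in the diff below - the real
# definition lives in litellm/types/utils.py. Reusing LlmProviders member
# values means LITELLM_IMAGE_VARIATION_PROVIDERS(llm_provider) resolves a
# member for supported providers and raises ValueError for everything else.
class LITELLM_IMAGE_VARIATION_PROVIDERS(Enum):
    OPENAI = LlmProviders.OPENAI.value
    TOPAZ = LlmProviders.TOPAZ.value
```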
@@ -4595,6 +4600,156 @@ def image_generation(  # noqa: PLR0915
)
@client
async def aimage_variation(*args, **kwargs) -> ImageResponse:
    """
    Asynchronously calls the `image_variation` function with the given arguments and keyword arguments.

    Parameters:
    - `args` (tuple): Positional arguments to be passed to the `image_variation` function.
    - `kwargs` (dict): Keyword arguments to be passed to the `image_variation` function.

    Returns:
    - `response` (Any): The response returned by the `image_variation` function.
    """
    loop = asyncio.get_event_loop()
    model = kwargs.get("model", None)
    custom_llm_provider = kwargs.get("custom_llm_provider", None)

    ### PASS ARGS TO Image Generation ###
    kwargs["async_call"] = True
    try:
        # Use a partial function to pass your keyword arguments
        func = partial(image_variation, *args, **kwargs)

        # Add the context to the function
        ctx = contextvars.copy_context()
        func_with_context = partial(ctx.run, func)

        if custom_llm_provider is None and model is not None:
            _, custom_llm_provider, _, _ = get_llm_provider(
                model=model, api_base=kwargs.get("api_base", None)
            )

        # Await normally
        init_response = await loop.run_in_executor(None, func_with_context)
        if isinstance(init_response, dict) or isinstance(
            init_response, ImageResponse
        ):  ## CACHING SCENARIO
            if isinstance(init_response, dict):
                init_response = ImageResponse(**init_response)
            response = init_response
        elif asyncio.iscoroutine(init_response):
            response = await init_response  # type: ignore
        else:
            # Call the synchronous function using run_in_executor
            response = await loop.run_in_executor(None, func_with_context)
        return response
    except Exception as e:
        custom_llm_provider = custom_llm_provider or "openai"
        raise exception_type(
            model=model,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs=args,
            extra_kwargs=kwargs,
        )
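
Like litellm's other `a`-prefixed routes, `aimage_variation` wraps the sync function in `run_in_executor` with a copied context, so the same parameters apply. A sketch under the same assumptions as the sync example above:

```python
import asyncio

import litellm

async def main() -> None:
    # Same hypothetical setup as the sync sketch: OPENAI_API_KEY in the env.
    with open("sample.png", "rb") as f:
        resp = await litellm.aimage_variation(
            image=f,
            model="dall-e-2",
            response_format="b64_json",
        )
    print(len(resp.data), "variation(s) returned")

asyncio.run(main())
```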
@client
def image_variation(
    image: FileTypes,
    model: str = "dall-e-2",  # set to dall-e-2 by default - like OpenAI.
    n: int = 1,
    response_format: Literal["url", "b64_json"] = "url",
    size: Optional[str] = None,
    user: Optional[str] = None,
    **kwargs,
) -> ImageResponse:
    # get non-default params

    # get logging object
    litellm_logging_obj = cast(LiteLLMLoggingObj, kwargs.get("litellm_logging_obj"))

    # get the litellm params
    litellm_params = get_litellm_params(**kwargs)

    # get the custom llm provider
    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
        model=model,
        custom_llm_provider=litellm_params.get("custom_llm_provider", None),
        api_base=litellm_params.get("api_base", None),
        api_key=litellm_params.get("api_key", None),
    )

    # route to the correct provider w/ the params
    try:
        llm_provider = LlmProviders(custom_llm_provider)
        image_variation_provider = LITELLM_IMAGE_VARIATION_PROVIDERS(llm_provider)
    except ValueError:
        raise ValueError(
            f"Invalid image variation provider: {custom_llm_provider}. Supported providers are: {LITELLM_IMAGE_VARIATION_PROVIDERS}"
        )

    model_response = ImageResponse()
    response: Optional[ImageResponse] = None

    provider_config = ProviderConfigManager.get_provider_model_info(
        model=model or "",  # openai defaults to dall-e-2
        provider=llm_provider,
    )

    if provider_config is None:
        raise ValueError(
            f"image variation provider has no known model info config - required for getting api keys, etc.: {custom_llm_provider}. Supported providers are: {LITELLM_IMAGE_VARIATION_PROVIDERS}"
        )

    api_key = provider_config.get_api_key(litellm_params.get("api_key", None))
    api_base = provider_config.get_api_base(litellm_params.get("api_base", None))

    if image_variation_provider == LITELLM_IMAGE_VARIATION_PROVIDERS.OPENAI:
        if api_key is None:
            raise ValueError("API key is required for OpenAI image variations")
        if api_base is None:
            raise ValueError("API base is required for OpenAI image variations")

        response = openai_image_variations.image_variations(
            model_response=model_response,
            api_key=api_key,
            api_base=api_base,
            model=model,
            image=image,
            timeout=litellm_params.get("timeout", None),
            custom_llm_provider=custom_llm_provider,
            logging_obj=litellm_logging_obj,
            optional_params={},
            litellm_params=litellm_params,
        )
    elif image_variation_provider == LITELLM_IMAGE_VARIATION_PROVIDERS.TOPAZ:
        if api_key is None:
            raise ValueError("API key is required for Topaz image variations")
        if api_base is None:
            raise ValueError("API base is required for Topaz image variations")

        response = base_llm_aiohttp_handler.image_variations(
            model_response=model_response,
            api_key=api_key,
            api_base=api_base,
            model=model,
            image=image,
            timeout=litellm_params.get("timeout", None),
            custom_llm_provider=custom_llm_provider,
            logging_obj=litellm_logging_obj,
            optional_params={},
            litellm_params=litellm_params,
        )

    # return the response
    if response is None:
        raise ValueError(
            f"Invalid image variation provider: {custom_llm_provider}. Supported providers are: {LITELLM_IMAGE_VARIATION_PROVIDERS}"
        )
    return response
##### Transcription #######################
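
One detail worth noting in the routing above: explicit `api_key`/`api_base` kwargs flow through `get_litellm_params` into `provider_config.get_api_key` / `get_api_base`, presumably with provider environment variables as the fallback. A sketch of an explicit override (the base URL shown is OpenAI's public default, not taken from this diff):

```python
import litellm

# Sketch: dynamic credentials take precedence over env-configured ones,
# assuming provider_config.get_api_key/get_api_base prefer the passed value.
resp = litellm.image_variation(
    image=open("sample.png", "rb"),
    model="dall-e-2",
    api_key="sk-...",                      # forwarded via litellm_params
    api_base="https://api.openai.com/v1",  # OpenAI's public base URL
)
```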