[BETA] Add OpenAI /images/variations + Topaz API support (#7700)

* feat(main.py): initial commit for `/image/variations` endpoint support

* refactor(base_llm/): introduce new base llm base config for image variation endpoints

* refactor(openai/image_variations/transformation.py): implement openai image variation transformation handler

* fix: test

* feat(openai/): working openai `/image/variation` endpoint calls via sdk

* feat(topaz/): topaz sync image variation call support

Addresses https://github.com/BerriAI/litellm/issues/7593

'

* fix(topaz/transformation.py): fix linting errors

* fix(openai/image_variations/handler.py): fix passing json data

* fix(main.py): image_variation/

support async image variation route - `aimage_variation`

* fix(test_get_model_info.py): fix test

* fix: cleanup unused imports

* feat(openai/): add async `/image/variations` endpoint support

* feat(topaz/): support async `/image/variations` calls

* fix: test

* fix(utils.py): fix get_model_info_helper for no model info w/ provider config

handles situation where model info is not known but provider config exists

* test(test_router_fallbacks.py): mark flaky test

* fix: fix unused imports

* test: bump otel load test perf threshold - accounts for current load tests hitting same server
This commit is contained in:
Krish Dholakia 2025-01-11 23:27:46 -08:00 committed by GitHub
parent a7c803edc5
commit ad2f66b3e3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 1254 additions and 20 deletions

View file

@ -181,6 +181,9 @@ from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.base_llm.completion.transformation import BaseTextCompletionConfig
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
from litellm.llms.base_llm.image_variations.transformation import (
BaseImageVariationConfig,
)
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from ._logging import _is_debugging_on, verbose_logger
@ -924,6 +927,8 @@ def client(original_function): # noqa: PLR0915
return result
elif "aspeech" in kwargs and kwargs["aspeech"] is True:
return result
elif asyncio.iscoroutine(result): # bubble up to relevant async function
return result
### POST-CALL RULES ###
post_call_processing(
@ -1954,7 +1959,6 @@ def register_model(model_cost: Union[str, dict]): # noqa: PLR0915
## override / add new keys to the existing model cost dictionary
updated_dictionary = _update_dictionary(existing_model, value)
litellm.model_cost.setdefault(model_cost_key, {}).update(updated_dictionary)
verbose_logger.debug(f"{model_cost_key} added to model cost map")
# add new model names to provider lists
if value.get("litellm_provider") == "openai":
if key not in litellm.open_ai_chat_completion_models:
@ -2036,7 +2040,9 @@ def get_litellm_params(
drop_params: Optional[bool] = None,
prompt_id: Optional[str] = None,
prompt_variables: Optional[dict] = None,
):
async_call: Optional[bool] = None,
**kwargs,
) -> dict:
litellm_params = {
"acompletion": acompletion,
"api_key": api_key,
@ -2072,6 +2078,7 @@ def get_litellm_params(
"drop_params": drop_params,
"prompt_id": prompt_id,
"prompt_variables": prompt_variables,
"async_call": async_call,
}
return litellm_params
@ -4123,8 +4130,7 @@ def _get_model_info_helper( # noqa: PLR0915
model=model, existing_model_info=_model_info
),
)
if key is None:
key = "provider_specific_model_info"
key = "provider_specific_model_info"
if _model_info is None or key is None:
raise ValueError(
"This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
@ -4230,6 +4236,7 @@ def _get_model_info_helper( # noqa: PLR0915
rpm=_model_info.get("rpm", None),
)
except Exception as e:
verbose_logger.debug(f"Error getting model info: {e}")
if "OllamaError" in str(e):
raise e
raise Exception(
@ -6165,11 +6172,26 @@ class ProviderConfigManager:
) -> Optional[BaseLLMModelInfo]:
if LlmProviders.FIREWORKS_AI == provider:
return litellm.FireworksAIConfig()
elif LlmProviders.OPENAI == provider:
return litellm.OpenAIGPTConfig()
elif LlmProviders.LITELLM_PROXY == provider:
return litellm.LiteLLMProxyChatConfig()
elif LlmProviders.TOPAZ == provider:
return litellm.TopazModelInfo()
return None
@staticmethod
def get_provider_image_variation_config(
model: str,
provider: LlmProviders,
) -> Optional[BaseImageVariationConfig]:
if LlmProviders.OPENAI == provider:
return litellm.OpenAIImageVariationConfig()
elif LlmProviders.TOPAZ == provider:
return litellm.TopazImageVariationConfig()
return None
def get_end_user_id_for_cost_tracking(
litellm_params: dict,