forked from phoenix/litellm-mirror

feat(main.py): add support for image generation endpoint

parent 7847ae1e23
commit 13d088b72e

7 changed files with 366 additions and 9 deletions
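A minimal sketch of the new entrypoint, mirroring the new test file further down in this diff (assumes an OpenAI API key is configured in the environment; the response fields come from the ImageResponse class added in litellm/utils.py):

    import litellm

    # hedged sketch: mirrors litellm/tests/test_image_generation.py
    response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
    print(response.created)  # creation timestamp, per ImageResponse
    print(response.data)     # list of generated images, as returned by the provider
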
@@ -132,8 +132,7 @@ for key, value in model_cost.items():
     elif value.get('litellm_provider') == 'anthropic':
         anthropic_models.append(key)
     elif value.get('litellm_provider') == 'openrouter':
-        split_string = key.split('/', 1)
-        openrouter_models.append(split_string[1])
+        openrouter_models.append(key)
     elif value.get('litellm_provider') == 'vertex_ai-text-models':
         vertex_text_models.append(key)
     elif value.get('litellm_provider') == 'vertex_ai-code-text-models':
@@ -366,6 +365,13 @@ bedrock_embedding_models: List = ["amazon.titan-embed-text-v1", "cohere.embed-en

 all_embedding_models = open_ai_embedding_models + cohere_embedding_models + bedrock_embedding_models

+####### IMAGE GENERATION MODELS ###################
+openai_image_generation_models = [
+    "dall-e-2",
+    "dall-e-3"
+]
+
+
 from .timeout import timeout
 from .utils import (
     client,
@@ -456,6 +456,67 @@ class AzureChatCompletion(BaseLLM):
         except AzureOpenAIError as e:
             exception_mapping_worked = True
             raise e
+        except Exception as e:
+            if exception_mapping_worked:
+                raise e
+            else:
+                import traceback
+                raise AzureOpenAIError(status_code=500, message=traceback.format_exc())
+
+    def image_generation(self,
+                         prompt: list,
+                         timeout: float,
+                         model: Optional[str]=None,
+                         api_key: Optional[str] = None,
+                         api_base: Optional[str] = None,
+                         model_response: Optional[litellm.utils.ImageResponse] = None,
+                         logging_obj=None,
+                         optional_params=None,
+                         client=None,
+                         aimg_generation=None,
+                         ):
+        exception_mapping_worked = False
+        try:
+            model = model
+            data = {
+                # "model": model,
+                "prompt": prompt,
+                **optional_params
+            }
+            max_retries = data.pop("max_retries", 2)
+            if not isinstance(max_retries, int):
+                raise AzureOpenAIError(status_code=422, message="max retries must be an int")
+
+            # if aembedding == True:
+            #     response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)  # type: ignore
+            #     return response
+
+            if client is None:
+                azure_client = AzureOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)  # type: ignore
+            else:
+                azure_client = client
+
+            ## LOGGING
+            logging_obj.pre_call(
+                input=prompt,
+                api_key=azure_client.api_key,
+                additional_args={"headers": {"Authorization": f"Bearer {azure_client.api_key}"}, "api_base": azure_client._base_url._uri_reference, "acompletion": False, "complete_input_dict": data},
+            )
+
+            ## COMPLETION CALL
+            response = azure_client.images.generate(**data)  # type: ignore
+            ## LOGGING
+            logging_obj.post_call(
+                input=input,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=response,
+            )
+            # return response
+            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="image_generation")  # type: ignore
+        except AzureOpenAIError as e:
+            exception_mapping_worked = True
+            raise e
         except Exception as e:
             if exception_mapping_worked:
                 raise e
@@ -445,6 +445,66 @@ class OpenAIChatCompletion(BaseLLM):
                 import traceback
                 raise OpenAIError(status_code=500, message=traceback.format_exc())

+    def image_generation(self,
+                         model: Optional[str],
+                         prompt: str,
+                         timeout: float,
+                         api_key: Optional[str] = None,
+                         api_base: Optional[str] = None,
+                         model_response: Optional[litellm.utils.ImageResponse] = None,
+                         logging_obj=None,
+                         optional_params=None,
+                         client=None,
+                         aimg_generation=None,
+                         ):
+        exception_mapping_worked = False
+        try:
+            model = model
+            data = {
+                "model": model,
+                "prompt": prompt,
+                **optional_params
+            }
+            max_retries = data.pop("max_retries", 2)
+            if not isinstance(max_retries, int):
+                raise OpenAIError(status_code=422, message="max retries must be an int")
+
+            # if aembedding == True:
+            #     response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)  # type: ignore
+            #     return response
+
+            if client is None:
+                openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
+            else:
+                openai_client = client
+
+            ## LOGGING
+            logging_obj.pre_call(
+                input=prompt,
+                api_key=openai_client.api_key,
+                additional_args={"headers": {"Authorization": f"Bearer {openai_client.api_key}"}, "api_base": openai_client._base_url._uri_reference, "acompletion": True, "complete_input_dict": data},
+            )
+
+            ## COMPLETION CALL
+            response = openai_client.images.generate(**data)  # type: ignore
+            ## LOGGING
+            logging_obj.post_call(
+                input=input,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=response,
+            )
+            # return response
+            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="image_generation")  # type: ignore
+        except OpenAIError as e:
+            exception_mapping_worked = True
+            raise e
+        except Exception as e:
+            if exception_mapping_worked:
+                raise e
+            else:
+                import traceback
+                raise OpenAIError(status_code=500, message=traceback.format_exc())
+
 class OpenAITextCompletion(BaseLLM):
     _client_session: httpx.Client
@@ -33,7 +33,8 @@ from litellm.utils import (
     convert_to_model_response_object,
     token_counter,
     Usage,
-    get_optional_params_embeddings
+    get_optional_params_embeddings,
+    get_optional_params_image_gen
 )
 from .llms import (
     anthropic,
@@ -2237,6 +2238,91 @@ def moderation(input: str, api_key: Optional[str]=None):
     response = openai.moderations.create(input=input)
     return response

+##### Image Generation #######################
+@client
+def image_generation(prompt: str,
+                     model: Optional[str]=None,
+                     n: Optional[int]=None,
+                     quality: Optional[str]=None,
+                     response_format: Optional[str]=None,
+                     size: Optional[str]=None,
+                     style: Optional[str]=None,
+                     user: Optional[str]=None,
+                     timeout=600,  # default to 10 minutes
+                     api_key: Optional[str]=None,
+                     api_base: Optional[str]=None,
+                     api_version: Optional[str] = None,
+                     litellm_logging_obj=None,
+                     custom_llm_provider=None,
+                     **kwargs):
+    """
+    Maps the https://api.openai.com/v1/images/generations endpoint.
+
+    Currently supports just Azure + OpenAI.
+    """
+    litellm_call_id = kwargs.get("litellm_call_id", None)
+    logger_fn = kwargs.get("logger_fn", None)
+    proxy_server_request = kwargs.get('proxy_server_request', None)
+    model_info = kwargs.get("model_info", None)
+    metadata = kwargs.get("metadata", {})
+
+    model_response = litellm.utils.ImageResponse()
+    if model is not None or custom_llm_provider is not None:
+        model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
+    else:
+        model = "dall-e-2"
+        custom_llm_provider = "openai"  # default to dall-e-2 on openai
+    openai_params = ["user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "max_retries", "n", "quality", "size", "style"]
+    litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key", "caching_groups"]
+    default_params = openai_params + litellm_params
+    non_default_params = {k: v for k, v in kwargs.items() if k not in default_params}  # model-specific params - pass them straight to the model/provider
+    optional_params = get_optional_params_image_gen(n=n,
+                                                    quality=quality,
+                                                    response_format=response_format,
+                                                    size=size,
+                                                    style=style,
+                                                    user=user,
+                                                    custom_llm_provider=custom_llm_provider,
+                                                    **non_default_params)
+    logging = litellm_logging_obj
+    logging.update_environment_variables(model=model, user=user, optional_params=optional_params, litellm_params={"timeout": timeout, "azure": False, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info, "metadata": metadata, "preset_cache_key": None, "stream_response": {}})
+
+    if custom_llm_provider == "azure":
+        # azure configs
+        api_type = get_secret("AZURE_API_TYPE") or "azure"
+
+        api_base = (
+            api_base
+            or litellm.api_base
+            or get_secret("AZURE_API_BASE")
+        )
+
+        api_version = (
+            api_version or
+            litellm.api_version or
+            get_secret("AZURE_API_VERSION")
+        )
+
+        api_key = (
+            api_key or
+            litellm.api_key or
+            litellm.azure_key or
+            get_secret("AZURE_OPENAI_API_KEY") or
+            get_secret("AZURE_API_KEY")
+        )
+
+        azure_ad_token = (
+            optional_params.pop("azure_ad_token", None) or
+            get_secret("AZURE_AD_TOKEN")
+        )
+
+        # model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response)
+        pass
+    elif custom_llm_provider == "openai":
+        model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response)
+
+    return model_response
+
 ####### HELPER FUNCTIONS ################
 ## Set verbose to true -> ```litellm.set_verbose = True```
 def print_verbose(print_statement):
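Note on the routing above: when neither model nor custom_llm_provider is supplied, the call defaults to dall-e-2 on OpenAI, and the Azure branch resolves credentials but leaves the provider call commented out (it ends in pass). A hedged sketch of the default path, again assuming an OpenAI key in the environment and an illustrative size value:

    import litellm

    # no model given -> model defaults to "dall-e-2", provider to "openai"
    img = litellm.image_generation(prompt="A watercolor fox", n=1, size="1024x1024")
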
@@ -650,7 +650,7 @@ def test_completion_azure_key_completion_arg():
     except Exception as e:
         os.environ["AZURE_API_KEY"] = old_key
         pytest.fail(f"Error occurred: {e}")
-# test_completion_azure_key_completion_arg()
+test_completion_azure_key_completion_arg()


 async def test_re_use_azure_async_client():
litellm/tests/test_image_generation.py  (new file, 37 lines)

@@ -0,0 +1,37 @@
+# What this tests?
+## This tests the litellm support for the openai /generations endpoint
+
+import sys, os
+import traceback
+from dotenv import load_dotenv
+
+load_dotenv()
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import pytest
+import litellm
+
+def test_image_generation_openai():
+    litellm.set_verbose = True
+    response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
+    print(f"response: {response}")
+
+# test_image_generation_openai()
+
+# def test_image_generation_azure():
+#     response = litellm.image_generation(prompt="A cute baby sea otter", api_version="2023-06-01-preview", custom_llm_provider="azure")
+#     print(f"response: {response}")
+# test_image_generation_azure()
+
+# @pytest.mark.asyncio
+# async def test_async_image_generation_openai():
+#     response = litellm.image_generation(prompt="A cute baby sea otter", model="dall-e-3")
+#     print(f"response: {response}")
+
+# @pytest.mark.asyncio
+# async def test_async_image_generation_azure():
+#     response = litellm.image_generation(prompt="A cute baby sea otter", model="azure/dall-e-3")
+#     print(f"response: {response}")
litellm/utils.py  (117 changes)
@@ -545,6 +545,52 @@ class TextCompletionResponse(OpenAIObject):
         # Allow dictionary-style assignment of attributes
         setattr(self, key, value)

+
+class ImageResponse(OpenAIObject):
+    created: Optional[int] = None
+
+    data: Optional[list] = None
+
+    def __init__(self, created=None, data=None, response_ms=None):
+        if response_ms:
+            _response_ms = response_ms
+        else:
+            _response_ms = None
+        if data:
+            data = data
+        else:
+            data = None
+
+        if created:
+            created = created
+        else:
+            created = None
+
+        super().__init__(data=data, created=created)
+
+    def __contains__(self, key):
+        # Define custom behavior for the 'in' operator
+        return hasattr(self, key)
+
+    def get(self, key, default=None):
+        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
+        return getattr(self, key, default)
+
+    def __getitem__(self, key):
+        # Allow dictionary-style access to attributes
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        # Allow dictionary-style assignment of attributes
+        setattr(self, key, value)
+
+    def json(self, **kwargs):
+        try:
+            return self.model_dump()  # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
 ############################################################
 def print_verbose(print_statement):
     try:
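ImageResponse keeps the OpenAIObject conventions used elsewhere in utils.py, so results can be read attribute-style or dict-style. An illustrative sketch (the timestamp and URL are made-up values; only the access patterns come from the class above):

    from litellm.utils import ImageResponse

    resp = ImageResponse(created=1702000000, data=[{"url": "https://example.com/otter.png"}])
    assert "data" in resp                    # via __contains__
    assert resp["created"] == 1702000000     # via __getitem__
    assert resp.get("data")[0]["url"].endswith(".png")  # via the custom .get()
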
@@ -561,6 +607,8 @@ class CallTypes(Enum):
     completion = 'completion'
     acompletion = 'acompletion'
     aembedding = 'aembedding'
+    image_generation = 'image_generation'
+    aimage_generation = 'aimage_generation'

 # Logging function -> log the exact model details + what's being sent | Non-Blocking
 class Logging:
@@ -1499,7 +1547,7 @@ def client(original_function):
         # CRASH REPORTING TELEMETRY
         crash_reporting(*args, **kwargs)
         # INIT LOGGER - for user-specified integrations
-        model = args[0] if len(args) > 0 else kwargs["model"]
+        model = args[0] if len(args) > 0 else kwargs.get("model", None)
         call_type = original_function.__name__
         if call_type == CallTypes.completion.value or call_type == CallTypes.acompletion.value:
             messages = None
@@ -1512,6 +1560,8 @@ def client(original_function):
             rules_obj.pre_call_rules(input="".join(m["content"] for m in messages if isinstance(m["content"], str)), model=model)
         elif call_type == CallTypes.embedding.value or call_type == CallTypes.aembedding.value:
             messages = args[1] if len(args) > 1 else kwargs["input"]
+        elif call_type == CallTypes.image_generation.value or call_type == CallTypes.aimage_generation.value:
+            messages = args[0] if len(args) > 0 else kwargs["prompt"]
         stream = True if "stream" in kwargs and kwargs["stream"] == True else False
         logging_obj = Logging(model=model, messages=messages, stream=stream, litellm_call_id=kwargs["litellm_call_id"], function_id=function_id, call_type=call_type, start_time=start_time)
         return logging_obj
@@ -1560,7 +1610,9 @@ def client(original_function):
         try:
             model = args[0] if len(args) > 0 else kwargs["model"]
         except:
-            raise ValueError("model param not passed in.")
+            call_type = original_function.__name__
+            if call_type != CallTypes.image_generation.value:
+                raise ValueError("model param not passed in.")

         try:
             if logging_obj is None:
@@ -1614,7 +1666,7 @@ def client(original_function):
             return result

         ### POST-CALL RULES ###
-        post_call_processing(original_response=result, model=model)
+        post_call_processing(original_response=result, model=model or None)

         # [OPTIONAL] ADD TO CACHE
         if litellm.cache is not None and str(original_function.__name__) in litellm.cache.supported_call_types:
@@ -2207,6 +2259,47 @@ def get_litellm_params(

     return litellm_params

+def get_optional_params_image_gen(
+    n: Optional[int]=None,
+    quality: Optional[str]=None,
+    response_format: Optional[str]=None,
+    size: Optional[str]=None,
+    style: Optional[str]=None,
+    user: Optional[str]=None,
+    custom_llm_provider: Optional[str]=None,
+    **kwargs
+):
+    # retrieve all parameters passed to the function
+    passed_params = locals()
+    custom_llm_provider = passed_params.pop("custom_llm_provider")
+    special_params = passed_params.pop("kwargs")
+    for k, v in special_params.items():
+        passed_params[k] = v
+
+    default_params = {
+        "n": None,
+        "quality": None,
+        "response_format": None,
+        "size": None,
+        "style": None,
+        "user": None,
+    }
+
+    non_default_params = {k: v for k, v in passed_params.items() if (k in default_params and v != default_params[k])}
+    ## raise exception if non-default value passed for non-openai/azure embedding calls
+    if custom_llm_provider != "openai" and custom_llm_provider != "azure":
+        if len(non_default_params.keys()) > 0:
+            if litellm.drop_params is True:  # drop the unsupported non-default values
+                keys = list(non_default_params.keys())
+                for k in keys:
+                    non_default_params.pop(k, None)
+                return non_default_params
+            raise UnsupportedParamsError(status_code=500, message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.")

+    final_params = {**non_default_params, **kwargs}
+    return final_params
+
+
 def get_optional_params_embeddings(
     # 2 optional params
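The helper forwards only image parameters that were explicitly set, merges any extra kwargs straight through, and (unless litellm.drop_params is enabled) rejects non-default values for providers other than OpenAI and Azure. An illustrative call, with hypothetical values, whose behaviour follows directly from the code above:

    from litellm.utils import get_optional_params_image_gen

    get_optional_params_image_gen(size="1024x1024", custom_llm_provider="openai")
    # -> {"size": "1024x1024"}

    get_optional_params_image_gen(size="512x512", custom_llm_provider="openai", foo="bar")
    # -> {"size": "512x512", "foo": "bar"}  (unknown kwargs are passed through)
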
@@ -2854,7 +2947,7 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_

     # check if model in known model provider list -> for huggingface models, raise exception as they don't have a fixed provider (can be togetherai, anyscale, baseten, runpod, et.)
     ## openai - chatcompletion + text completion
-    if model in litellm.open_ai_chat_completion_models or "ft:gpt-3.5-turbo" in model:
+    if model in litellm.open_ai_chat_completion_models or "ft:gpt-3.5-turbo" in model or model in litellm.openai_image_generation_models:
         custom_llm_provider = "openai"
     elif model in litellm.open_ai_text_completion_models:
         custom_llm_provider = "text-completion-openai"
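With the dall-e models registered in litellm.openai_image_generation_models, provider resolution for a bare image model name now lands on OpenAI. A small sketch (the four-tuple return shape is the one unpacked by the new image_generation() in main.py):

    from litellm.utils import get_llm_provider

    model, provider, dynamic_api_key, api_base = get_llm_provider(model="dall-e-3")
    # provider == "openai", because "dall-e-3" is in litellm.openai_image_generation_models
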
@@ -3801,7 +3894,7 @@ def convert_to_streaming_response(response_object: Optional[dict]=None):
         yield model_response_object


-def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[Union[ModelResponse, EmbeddingResponse]]=None, response_type: Literal["completion", "embedding"] = "completion", stream = False):
+def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[Union[ModelResponse, EmbeddingResponse, ImageResponse]]=None, response_type: Literal["completion", "embedding", "image_generation"] = "completion", stream = False):
     try:
         if response_type == "completion" and (model_response_object is None or isinstance(model_response_object, ModelResponse)):
             if response_object is None or model_response_object is None:
@@ -3863,6 +3956,20 @@ def convert_to_model_response_object(response_object: Optional[dict]=None, model
             model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0)  # type: ignore


+            return model_response_object
+        elif response_type == "image_generation" and (model_response_object is None or isinstance(model_response_object, ImageResponse)):
+            if response_object is None:
+                raise Exception("Error in response object format")
+
+            if model_response_object is None:
+                model_response_object = EmbeddingResponse()
+
+            if "created" in response_object:
+                model_response_object.created = response_object["created"]
+
+            if "data" in response_object:
+                model_response_object.data = response_object["data"]
+
             return model_response_object
     except Exception as e:
         raise Exception(f"Invalid response object {e}")