Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 18:54:30 +00:00
Commit 203089e6c7: Merge branch 'main' into explicit-args-acomplete
10 changed files with 271 additions and 477 deletions
litellm/main.py (331 changed lines)
@@ -1173,7 +1173,7 @@ def completion(
                 acompletion=acompletion,
                 logging_obj=logging,
                 custom_prompt_dict=custom_prompt_dict,
-                timeout=timeout
+                timeout=timeout,
             )
             if (
                 "stream" in optional_params
@@ -2894,158 +2894,167 @@ def image_generation(

    Currently supports just Azure + OpenAI.
    """
    try:
        aimg_generation = kwargs.get("aimg_generation", False)
        litellm_call_id = kwargs.get("litellm_call_id", None)
        logger_fn = kwargs.get("logger_fn", None)
        proxy_server_request = kwargs.get("proxy_server_request", None)
        model_info = kwargs.get("model_info", None)
        metadata = kwargs.get("metadata", {})

        model_response = litellm.utils.ImageResponse()
        if model is not None or custom_llm_provider is not None:
            model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
        else:
            model = "dall-e-2"
            custom_llm_provider = "openai"  # default to dall-e-2 on openai
        openai_params = [
            "user",
            "request_timeout",
            "api_base",
            "api_version",
            "api_key",
            "deployment_id",
            "organization",
            "base_url",
            "default_headers",
            "timeout",
            "max_retries",
            "n",
            "quality",
            "size",
            "style",
        ]
        litellm_params = [
            "metadata",
            "aimg_generation",
            "caching",
            "mock_response",
            "api_key",
            "api_version",
            "api_base",
            "force_timeout",
            "logger_fn",
            "verbose",
            "custom_llm_provider",
            "litellm_logging_obj",
            "litellm_call_id",
            "use_client",
            "id",
            "fallbacks",
            "azure",
            "headers",
            "model_list",
            "num_retries",
            "context_window_fallback_dict",
            "roles",
            "final_prompt_value",
            "bos_token",
            "eos_token",
            "request_timeout",
            "complete_response",
            "self",
            "client",
            "rpm",
            "tpm",
            "input_cost_per_token",
            "output_cost_per_token",
            "hf_model_name",
            "proxy_server_request",
            "model_info",
            "preset_cache_key",
            "caching_groups",
            "ttl",
            "cache",
        ]
        default_params = openai_params + litellm_params
        non_default_params = {
            k: v for k, v in kwargs.items() if k not in default_params
        }  # model-specific params - pass them straight to the model/provider
        optional_params = get_optional_params_image_gen(
            n=n,
            quality=quality,
            response_format=response_format,
            size=size,
            style=style,
            user=user,
            custom_llm_provider=custom_llm_provider,
            **non_default_params,
        )
        logging = litellm_logging_obj
        logging.update_environment_variables(
            model=model,
            user=user,
            optional_params=optional_params,
            litellm_params={
                "timeout": timeout,
                "azure": False,
                "litellm_call_id": litellm_call_id,
                "logger_fn": logger_fn,
                "proxy_server_request": proxy_server_request,
                "model_info": model_info,
                "metadata": metadata,
                "preset_cache_key": None,
                "stream_response": {},
            },
        )

        if custom_llm_provider == "azure":
            # azure configs
            api_type = get_secret("AZURE_API_TYPE") or "azure"

            api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")

            api_version = (
                api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
            )

            api_key = (
                api_key
                or litellm.api_key
                or litellm.azure_key
                or get_secret("AZURE_OPENAI_API_KEY")
                or get_secret("AZURE_API_KEY")
            )

            azure_ad_token = optional_params.pop("azure_ad_token", None) or get_secret(
                "AZURE_AD_TOKEN"
            )

            model_response = azure_chat_completions.image_generation(
                model=model,
                prompt=prompt,
                timeout=timeout,
                api_key=api_key,
                api_base=api_base,
                logging_obj=litellm_logging_obj,
                optional_params=optional_params,
                model_response=model_response,
                api_version=api_version,
                aimg_generation=aimg_generation,
            )
        elif custom_llm_provider == "openai":
            model_response = openai_chat_completions.image_generation(
                model=model,
                prompt=prompt,
                timeout=timeout,
                api_key=api_key,
                api_base=api_base,
                logging_obj=litellm_logging_obj,
                optional_params=optional_params,
                model_response=model_response,
                aimg_generation=aimg_generation,
            )

        return model_response
    except Exception as e:
        ## Map to OpenAI Exception
        raise exception_type(
            model=model,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs=locals(),
        )


##### Health Endpoints #######################
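For orientation, a minimal usage sketch of the image_generation path touched above: with the body wrapped in try/except, provider failures are re-raised through exception_type() in OpenAI-compatible form, so callers can catch them uniformly. The prompt, model choice, and the specific exception class below are illustrative assumptions, not taken from this commit.

# Hedged sketch: calling litellm.image_generation and catching a mapped provider error.
import litellm

try:
    response = litellm.image_generation(
        prompt="a watercolor lighthouse at dusk",  # illustrative prompt
        model="dall-e-2",  # the function above defaults to dall-e-2 on OpenAI when no model is given
        n=1,
        size="1024x1024",
    )
    print(response.data)  # list of generated image entries
except litellm.exceptions.AuthenticationError as err:
    # provider auth errors surface here after being mapped by exception_type()
    print(f"image generation failed: {err}")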
@@ -3170,7 +3179,8 @@ def config_completion(**kwargs):
             "No config path set, please set a config path using `litellm.config_path = 'path/to/config.json'`"
         )
 
-def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]=None):
+
+def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List] = None):
     id = chunks[0]["id"]
     object = chunks[0]["object"]
     created = chunks[0]["created"]
@@ -3187,23 +3197,27 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]=
         "system_fingerprint": system_fingerprint,
         "choices": [
             {
                 "text": None,
                 "index": 0,
                 "logprobs": logprobs,
-                "finish_reason": finish_reason
+                "finish_reason": finish_reason,
             }
         ],
         "usage": {
             "prompt_tokens": None,
             "completion_tokens": None,
-            "total_tokens": None
-        }
+            "total_tokens": None,
+        },
     }
     content_list = []
     for chunk in chunks:
         choices = chunk["choices"]
         for choice in choices:
-            if choice is not None and hasattr(choice, "text") and choice.get("text") is not None:
+            if (
+                choice is not None
+                and hasattr(choice, "text")
+                and choice.get("text") is not None
+            ):
                 _choice = choice.get("text")
                 content_list.append(_choice)
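The loop above is what stitches streamed text fragments back into one completion string. A standalone sketch of that idea, using plain dicts that only mimic the fields the loop reads (real chunks are litellm response objects, so the shape here is an assumption for illustration):

# Simplified illustration of the chunk-joining step in stream_chunk_builder_text_completion.
chunks = [
    {"choices": [{"text": "Hello", "index": 0}]},
    {"choices": [{"text": ", world", "index": 0}]},
    {"choices": [{"text": "!", "index": 0}]},
]

content_list = []
for chunk in chunks:
    for choice in chunk["choices"]:
        text = choice.get("text")
        if text is not None:
            content_list.append(text)

print("".join(content_list))  # -> Hello, world!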
@@ -3235,13 +3249,18 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]=
     )
     return response
 
 
 def stream_chunk_builder(chunks: list, messages: Optional[list] = None):
     id = chunks[0]["id"]
     object = chunks[0]["object"]
     created = chunks[0]["created"]
     model = chunks[0]["model"]
     system_fingerprint = chunks[0].get("system_fingerprint", None)
-    if isinstance(chunks[0]["choices"][0], litellm.utils.TextChoices):  # route to the text completion logic
+    if isinstance(
+        chunks[0]["choices"][0], litellm.utils.TextChoices
+    ):  # route to the text completion logic
         return stream_chunk_builder_text_completion(chunks=chunks, messages=messages)
     role = chunks[0]["choices"][0]["delta"]["role"]
     finish_reason = chunks[-1]["choices"][0]["finish_reason"]
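stream_chunk_builder (and the text-completion variant above) exists so that callers who stream can still assemble one complete response object afterwards. A hedged usage sketch, assuming an OpenAI key is configured; the model name and prompt are illustrative, not part of this commit.

# Collect streamed chunks, then rebuild one complete response from them.
import litellm

messages = [{"role": "user", "content": "Write a one-line haiku about rain."}]
stream = litellm.completion(model="gpt-3.5-turbo", messages=messages, stream=True)

chunks = []
for chunk in stream:
    chunks.append(chunk)

full_response = litellm.stream_chunk_builder(chunks, messages=messages)
print(full_response.choices[0].message.content)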