forked from phoenix/litellm-mirror

commit 203089e6c7
Merge branch 'main' into explicit-args-acomplete

10 changed files with 271 additions and 477 deletions
@@ -1,49 +0,0 @@
# Model Config

Model-specific changes can make our code complicated, making it harder to debug errors. Use model configs to simplify this.

### Usage

Handling prompt logic: different models have different context windows. Use `adapt_to_prompt_size` to select the right model for the prompt (in case the current model's window is too small).

```python
from litellm import completion_with_config
import os

config = {
    "available_models": ["gpt-3.5-turbo", "claude-instant-1", "gpt-3.5-turbo-16k"],
    "adapt_to_prompt_size": True, # 👈 key change
}

# set env vars
os.environ["OPENAI_API_KEY"] = "your-api-key"
os.environ["ANTHROPIC_API_KEY"] = "your-api-key"

sample_text = "how does a court case get to the Supreme Court?" * 1000
messages = [{"content": sample_text, "role": "user"}]
response = completion_with_config(model="gpt-3.5-turbo", messages=messages, config=config)
```

[**See Code**](https://github.com/BerriAI/litellm/blob/30724d9e51cdc2c3e0eb063271b4f171bc01b382/litellm/utils.py#L2783)

### Complete Config Structure

```python
config = {
    "default_fallback_models": # [Optional] List of model names to try if a call fails
    "available_models": # [Optional] List of all possible models you could call
    "adapt_to_prompt_size": # [Optional] True/False - if you want to select model based on prompt size (will pick from available_models)
    "model": {
        "model-name": {
            "needs_moderation": # [Optional] True/False - if you want to call openai moderations endpoint before making completion call. Will raise exception, if flagged.
            "error_handling": {
                "error-type": { # One of the errors listed here - https://docs.litellm.ai/docs/exception_mapping#custom-mapping-list
                    "fallback_model": "" # str, name of the model it should try instead, when that error occurs
                }
            }
        }
    }
}
```
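For reference, the schema above maps onto a concrete config like the one below; the model names are illustrative and mirror the config used in the test file removed later in this diff.

```python
from litellm import completion_with_config

# Illustrative config following the structure above.
config = {
    "default_fallback_models": ["gpt-3.5-turbo", "claude-instant-1"],
    "available_models": ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "claude-instant-1"],
    "adapt_to_prompt_size": True,
    "model": {
        "claude-instant-1": {"needs_moderation": True},
        "gpt-3.5-turbo": {
            "error_handling": {
                # if the prompt overflows gpt-3.5-turbo, retry on the 16k variant
                "ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"}
            }
        },
    },
}

messages = [{"role": "user", "content": "summarize this very long transcript ..."}]
response = completion_with_config(model="gpt-3.5-turbo", messages=messages, config=config)
```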
@@ -12,6 +12,7 @@ LiteLLM maps exceptions across all providers to their OpenAI counterparts.
 | 429 | RateLimitError |
 | >=500 | InternalServerError |
 | N/A | ContextWindowExceededError |
+| 400 | ContentPolicyViolationError |
 | N/A | APIConnectionError |
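Because every provider error is normalized to one of these OpenAI-style classes, callers can branch on them uniformly. A minimal sketch (model name illustrative), including the newly mapped 400 error:

```python
import litellm

try:
    response = litellm.completion(
        model="gpt-3.5-turbo",  # illustrative
        messages=[{"role": "user", "content": "hello"}],
    )
except litellm.ContentPolicyViolationError:
    # 400 - the mapping added in this diff
    print("request was flagged by the provider's safety system")
except litellm.ContextWindowExceededError:
    print("prompt is too large for this model")
except litellm.RateLimitError:
    print("429 - back off and retry")
```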
@@ -129,26 +129,6 @@ const sidebars = {
     "secret",
     "completion/token_usage",
     "load_test",
-    {
-      type: 'category',
-      label: 'Tutorials',
-      items: [
-        'tutorials/azure_openai',
-        "tutorials/lm_evaluation_harness",
-        "tutorials/eval_suites",
-        'tutorials/oobabooga',
-        "tutorials/gradio_integration",
-        'tutorials/huggingface_codellama',
-        'tutorials/huggingface_tutorial',
-        'tutorials/TogetherAI_liteLLM',
-        'tutorials/finetuned_chat_gpt',
-        'tutorials/sagemaker_llms',
-        'tutorials/text_completion',
-        "tutorials/first_playground",
-        'tutorials/compare_llms',
-        "tutorials/model_fallbacks",
-      ],
-    },
     {
       type: "category",
       label: "Logging & Observability",
@@ -170,6 +150,23 @@ const sidebars = {
       ],
     },
     "caching/redis_cache",
+    {
+      type: 'category',
+      label: 'Tutorials',
+      items: [
+        'tutorials/azure_openai',
+        'tutorials/oobabooga',
+        "tutorials/gradio_integration",
+        'tutorials/huggingface_codellama',
+        'tutorials/huggingface_tutorial',
+        'tutorials/TogetherAI_liteLLM',
+        'tutorials/finetuned_chat_gpt',
+        'tutorials/sagemaker_llms',
+        'tutorials/text_completion',
+        "tutorials/first_playground",
+        "tutorials/model_fallbacks",
+      ],
+    },
     {
       type: "category",
       label: "LangChain, LlamaIndex Integration",
@@ -500,7 +500,6 @@ from .utils import (
     validate_environment,
     check_valid_key,
     get_llm_provider,
-    completion_with_config,
     register_model,
     encode,
     decode,
@@ -544,6 +543,7 @@ from .exceptions import (
     ServiceUnavailableError,
     OpenAIError,
     ContextWindowExceededError,
+    ContentPolicyViolationError,
     BudgetExceededError,
     APIError,
     Timeout,
@@ -108,6 +108,21 @@ class ContextWindowExceededError(BadRequestError):  # type: ignore
         )  # Call the base class constructor with the parameters it needs


+class ContentPolicyViolationError(BadRequestError):  # type: ignore
+    # Error code: 400 - {'error': {'code': 'content_policy_violation', 'message': 'Your request was rejected as a result of our safety system. Image descriptions generated from your prompt may contain text that is not allowed by our safety system. If you believe this was done in error, your request may succeed if retried, or by adjusting your prompt.', 'param': None, 'type': 'invalid_request_error'}}
+    def __init__(self, message, model, llm_provider, response: httpx.Response):
+        self.status_code = 400
+        self.message = message
+        self.model = model
+        self.llm_provider = llm_provider
+        super().__init__(
+            message=self.message,
+            model=self.model,  # type: ignore
+            llm_provider=self.llm_provider,  # type: ignore
+            response=response,
+        )  # Call the base class constructor with the parameters it needs
+
+
 class ServiceUnavailableError(APIStatusError):  # type: ignore
     def __init__(self, message, llm_provider, model, response: httpx.Response):
         self.status_code = 503
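The new class subclasses `BadRequestError` and carries the offending model and provider alongside a 400 status code. A minimal sketch of catching and inspecting it, assuming an image request that the provider's safety system rejects (prompt is illustrative):

```python
import litellm

try:
    litellm.image_generation(
        prompt="<a prompt the provider's safety system rejects>",  # illustrative
        model="dall-e-3",
    )
except litellm.ContentPolicyViolationError as e:
    # attributes set in the constructor above
    print(e.status_code)   # 400
    print(e.model)         # "dall-e-3"
    print(e.llm_provider)  # "openai"
    print(e.message)
```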
litellm/main.py: 331 changed lines
@@ -1173,7 +1173,7 @@ def completion(
             acompletion=acompletion,
             logging_obj=logging,
             custom_prompt_dict=custom_prompt_dict,
-            timeout=timeout
+            timeout=timeout,
         )
         if (
             "stream" in optional_params
@@ -2894,158 +2894,167 @@ def image_generation(

     Currently supports just Azure + OpenAI.
     """
+    try:
         aimg_generation = kwargs.get("aimg_generation", False)
         litellm_call_id = kwargs.get("litellm_call_id", None)
         logger_fn = kwargs.get("logger_fn", None)
         proxy_server_request = kwargs.get("proxy_server_request", None)
         model_info = kwargs.get("model_info", None)
         metadata = kwargs.get("metadata", {})

         model_response = litellm.utils.ImageResponse()
         if model is not None or custom_llm_provider is not None:
             model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
         else:
             model = "dall-e-2"
             custom_llm_provider = "openai"  # default to dall-e-2 on openai
         openai_params = [
             "user",
             "request_timeout",
             "api_base",
             "api_version",
             "api_key",
             "deployment_id",
             "organization",
             "base_url",
             "default_headers",
             "timeout",
             "max_retries",
             "n",
             "quality",
             "size",
             "style",
         ]
         litellm_params = [
             "metadata",
             "aimg_generation",
             "caching",
             "mock_response",
             "api_key",
             "api_version",
             "api_base",
             "force_timeout",
             "logger_fn",
             "verbose",
             "custom_llm_provider",
             "litellm_logging_obj",
             "litellm_call_id",
             "use_client",
             "id",
             "fallbacks",
             "azure",
             "headers",
             "model_list",
             "num_retries",
             "context_window_fallback_dict",
             "roles",
             "final_prompt_value",
             "bos_token",
             "eos_token",
             "request_timeout",
             "complete_response",
             "self",
             "client",
             "rpm",
             "tpm",
             "input_cost_per_token",
             "output_cost_per_token",
             "hf_model_name",
             "proxy_server_request",
             "model_info",
             "preset_cache_key",
             "caching_groups",
             "ttl",
             "cache",
         ]
         default_params = openai_params + litellm_params
         non_default_params = {
             k: v for k, v in kwargs.items() if k not in default_params
         }  # model-specific params - pass them straight to the model/provider
         optional_params = get_optional_params_image_gen(
             n=n,
             quality=quality,
             response_format=response_format,
             size=size,
             style=style,
             user=user,
             custom_llm_provider=custom_llm_provider,
             **non_default_params,
         )
         logging = litellm_logging_obj
         logging.update_environment_variables(
             model=model,
             user=user,
             optional_params=optional_params,
             litellm_params={
                 "timeout": timeout,
                 "azure": False,
                 "litellm_call_id": litellm_call_id,
                 "logger_fn": logger_fn,
                 "proxy_server_request": proxy_server_request,
                 "model_info": model_info,
                 "metadata": metadata,
                 "preset_cache_key": None,
                 "stream_response": {},
             },
         )

         if custom_llm_provider == "azure":
             # azure configs
             api_type = get_secret("AZURE_API_TYPE") or "azure"

             api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")

             api_version = (
                 api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
             )

             api_key = (
                 api_key
                 or litellm.api_key
                 or litellm.azure_key
                 or get_secret("AZURE_OPENAI_API_KEY")
                 or get_secret("AZURE_API_KEY")
             )

             azure_ad_token = optional_params.pop("azure_ad_token", None) or get_secret(
                 "AZURE_AD_TOKEN"
             )

             model_response = azure_chat_completions.image_generation(
                 model=model,
                 prompt=prompt,
                 timeout=timeout,
                 api_key=api_key,
                 api_base=api_base,
                 logging_obj=litellm_logging_obj,
                 optional_params=optional_params,
                 model_response=model_response,
                 api_version=api_version,
                 aimg_generation=aimg_generation,
             )
         elif custom_llm_provider == "openai":
             model_response = openai_chat_completions.image_generation(
                 model=model,
                 prompt=prompt,
                 timeout=timeout,
                 api_key=api_key,
                 api_base=api_base,
                 logging_obj=litellm_logging_obj,
                 optional_params=optional_params,
                 model_response=model_response,
                 aimg_generation=aimg_generation,
             )

         return model_response
+    except Exception as e:
+        ## Map to OpenAI Exception
+        raise exception_type(
+            model=model,
+            custom_llm_provider=custom_llm_provider,
+            original_exception=e,
+            completion_kwargs=locals(),
+        )


 ##### Health Endpoints #######################
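With the body wrapped in try/except, provider failures from `image_generation` now surface as the mapped litellm exceptions; the async wrapper reaches the same code path via the `aimg_generation` flag read at the top of the function. A usage sketch (model name illustrative):

```python
import asyncio
import litellm

async def main():
    # async wrapper; the aimg_generation kwarg read in image_generation above backs this path
    image = await litellm.aimage_generation(
        prompt="A cute baby sea otter", model="dall-e-3"  # illustrative
    )
    print(image.data[0])

asyncio.run(main())
```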
@@ -3170,7 +3179,8 @@ def config_completion(**kwargs):
             "No config path set, please set a config path using `litellm.config_path = 'path/to/config.json'`"
         )

-def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]=None):
+
+def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List] = None):
     id = chunks[0]["id"]
     object = chunks[0]["object"]
     created = chunks[0]["created"]
@@ -3187,23 +3197,27 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]=
         "system_fingerprint": system_fingerprint,
         "choices": [
             {
                 "text": None,
                 "index": 0,
                 "logprobs": logprobs,
-                "finish_reason": finish_reason
+                "finish_reason": finish_reason,
             }
         ],
         "usage": {
             "prompt_tokens": None,
             "completion_tokens": None,
-            "total_tokens": None
-        }
+            "total_tokens": None,
+        },
     }
     content_list = []
     for chunk in chunks:
         choices = chunk["choices"]
         for choice in choices:
-            if choice is not None and hasattr(choice, "text") and choice.get("text") is not None:
+            if (
+                choice is not None
+                and hasattr(choice, "text")
+                and choice.get("text") is not None
+            ):
                 _choice = choice.get("text")
                 content_list.append(_choice)

@@ -3235,13 +3249,18 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]=
     )
     return response


 def stream_chunk_builder(chunks: list, messages: Optional[list] = None):
     id = chunks[0]["id"]
     object = chunks[0]["object"]
     created = chunks[0]["created"]
     model = chunks[0]["model"]
     system_fingerprint = chunks[0].get("system_fingerprint", None)
-    if isinstance(chunks[0]["choices"][0], litellm.utils.TextChoices): # route to the text completion logic
+
+    if isinstance(
+        chunks[0]["choices"][0], litellm.utils.TextChoices
+    ):  # route to the text completion logic
+
         return stream_chunk_builder_text_completion(chunks=chunks, messages=messages)
     role = chunks[0]["choices"][0]["delta"]["role"]
     finish_reason = chunks[-1]["choices"][0]["finish_reason"]
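`stream_chunk_builder` (and the text-completion variant being reformatted here) reassembles a full response object from collected stream chunks. A usage sketch (model name illustrative):

```python
import litellm

messages = [{"role": "user", "content": "Tell me a joke."}]  # illustrative

chunks = []
for chunk in litellm.completion(model="gpt-3.5-turbo", messages=messages, stream=True):
    chunks.append(chunk)

# reassemble the streamed deltas into a single ModelResponse
full_response = litellm.stream_chunk_builder(chunks, messages=messages)
print(full_response.choices[0].message.content)
```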
@@ -1,118 +0,0 @@
```python
import sys, os
import traceback
from dotenv import load_dotenv

load_dotenv()
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest
import litellm
from litellm import completion_with_config

config = {
    "default_fallback_models": ["gpt-3.5-turbo", "claude-instant-1", "j2-ultra"],
    "model": {
        "claude-instant-1": {"needs_moderation": True},
        "gpt-3.5-turbo": {
            "error_handling": {
                "ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"}
            }
        },
    },
}


def test_config_context_window_exceeded():
    try:
        sample_text = "how does a court case get to the Supreme Court?" * 1000
        messages = [{"content": sample_text, "role": "user"}]
        response = completion_with_config(
            model="gpt-3.5-turbo", messages=messages, config=config
        )
        print(response)
    except Exception as e:
        print(f"Exception: {e}")
        pytest.fail(f"An exception occurred: {e}")


# test_config_context_window_exceeded()


def test_config_context_moderation():
    try:
        messages = [{"role": "user", "content": "I want to kill them."}]
        response = completion_with_config(
            model="claude-instant-1", messages=messages, config=config
        )
        print(response)
    except Exception as e:
        print(f"Exception: {e}")
        pytest.fail(f"An exception occurred: {e}")


# test_config_context_moderation()


def test_config_context_default_fallback():
    try:
        messages = [{"role": "user", "content": "Hey, how's it going?"}]
        response = completion_with_config(
            model="claude-instant-1",
            messages=messages,
            config=config,
            api_key="bad-key",
        )
        print(response)
    except Exception as e:
        print(f"Exception: {e}")
        pytest.fail(f"An exception occurred: {e}")


# test_config_context_default_fallback()


config = {
    "default_fallback_models": ["gpt-3.5-turbo", "claude-instant-1", "j2-ultra"],
    "available_models": [
        "gpt-3.5-turbo",
        "gpt-3.5-turbo-0301",
        "gpt-3.5-turbo-0613",
        "gpt-4",
        "gpt-4-0314",
        "gpt-4-0613",
        "j2-ultra",
        "command-nightly",
        "togethercomputer/llama-2-70b-chat",
        "chat-bison",
        "chat-bison@001",
        "claude-2",
    ],
    "adapt_to_prompt_size": True,  # type: ignore
    "model": {
        "claude-instant-1": {"needs_moderation": True},
        "gpt-3.5-turbo": {
            "error_handling": {
                "ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"}
            }
        },
    },
}


def test_config_context_adapt_to_prompt():
    try:
        sample_text = "how does a court case get to the Supreme Court?" * 1000
        messages = [{"content": sample_text, "role": "user"}]
        response = completion_with_config(
            model="gpt-3.5-turbo", messages=messages, config=config
        )
        print(response)
    except Exception as e:
        print(f"Exception: {e}")
        pytest.fail(f"An exception occurred: {e}")


test_config_context_adapt_to_prompt()
```
@@ -352,6 +352,25 @@ def test_completion_mistral_exception():
 # test_completion_mistral_exception()


+def test_content_policy_exceptionimage_generation_openai():
+    try:
+        # this is only a test - we needed some way to invoke the exception :(
+        litellm.set_verbose = True
+        response = litellm.image_generation(
+            prompt="where do i buy lethal drugs from", model="dall-e-3"
+        )
+        print(f"response: {response}")
+        assert len(response.data) > 0
+    except litellm.ContentPolicyViolationError as e:
+        print("caught a content policy violation error! Passed")
+        pass
+    except Exception as e:
+        pytest.fail(f"An exception occurred - {str(e)}")
+
+
+# test_content_policy_exceptionimage_generation_openai()
+
+
 # # test_invalid_request_error(model="command-nightly")
 # # Test 3: Rate Limit Errors
 # def test_model_call(model):
@@ -19,7 +19,7 @@ import litellm


 def test_image_generation_openai():
     try:
         litellm.set_verbose = True
         response = litellm.image_generation(
             prompt="A cute baby sea otter", model="dall-e-3"
@@ -28,6 +28,8 @@ def test_image_generation_openai():
         assert len(response.data) > 0
     except litellm.RateLimitError as e:
         pass
+    except litellm.ContentPolicyViolationError:
+        pass  # OpenAI randomly raises these errors - skip when they occur
     except Exception as e:
         pytest.fail(f"An exception occurred - {str(e)}")

@@ -36,22 +38,27 @@ def test_image_generation_openai():


 def test_image_generation_azure():
     try:
         response = litellm.image_generation(
-            prompt="A cute baby sea otter", model="azure/", api_version="2023-06-01-preview"
+            prompt="A cute baby sea otter",
+            model="azure/",
+            api_version="2023-06-01-preview",
         )
         print(f"response: {response}")
         assert len(response.data) > 0
     except litellm.RateLimitError as e:
         pass
+    except litellm.ContentPolicyViolationError:
+        pass  # Azure randomly raises these errors - skip when they occur
     except Exception as e:
         pytest.fail(f"An exception occurred - {str(e)}")


 # test_image_generation_azure()


 def test_image_generation_azure_dall_e_3():
     try:
         litellm.set_verbose = True
         response = litellm.image_generation(
             prompt="A cute baby sea otter",
@@ -64,6 +71,8 @@ def test_image_generation_azure_dall_e_3():
         assert len(response.data) > 0
     except litellm.RateLimitError as e:
         pass
+    except litellm.ContentPolicyViolationError:
+        pass  # OpenAI randomly raises these errors - skip when they occur
     except Exception as e:
         pytest.fail(f"An exception occurred - {str(e)}")

@@ -71,7 +80,7 @@ def test_image_generation_azure_dall_e_3():
 # test_image_generation_azure_dall_e_3()
 @pytest.mark.asyncio
 async def test_async_image_generation_openai():
     try:
         response = litellm.image_generation(
             prompt="A cute baby sea otter", model="dall-e-3"
         )
@@ -79,20 +88,25 @@ async def test_async_image_generation_openai():
         assert len(response.data) > 0
     except litellm.RateLimitError as e:
         pass
+    except litellm.ContentPolicyViolationError:
+        pass  # openai randomly raises these errors - skip when they occur
     except Exception as e:
         pytest.fail(f"An exception occurred - {str(e)}")


 # asyncio.run(test_async_image_generation_openai())


 @pytest.mark.asyncio
 async def test_async_image_generation_azure():
     try:
         response = await litellm.aimage_generation(
             prompt="A cute baby sea otter", model="azure/dall-e-3-test"
         )
         print(f"response: {response}")
     except litellm.RateLimitError as e:
         pass
+    except litellm.ContentPolicyViolationError:
+        pass  # Azure randomly raises these errors - skip when they occur
     except Exception as e:
         pytest.fail(f"An exception occurred - {str(e)}")
litellm/utils.py: 150 changed lines
@@ -60,6 +60,7 @@ from .exceptions import (
    ServiceUnavailableError,
    OpenAIError,
    ContextWindowExceededError,
+   ContentPolicyViolationError,
    Timeout,
    APIConnectionError,
    APIError,
@@ -5551,6 +5552,17 @@ def exception_type(
                         model=model,
                         response=original_exception.response,
                     )
+                elif (
+                    "invalid_request_error" in error_str
+                    and "content_policy_violation" in error_str
+                ):
+                    exception_mapping_worked = True
+                    raise ContentPolicyViolationError(
+                        message=f"OpenAIException - {original_exception.message}",
+                        llm_provider="openai",
+                        model=model,
+                        response=original_exception.response,
+                    )
                 elif (
                     "invalid_request_error" in error_str
                     and "Incorrect API key provided" not in error_str
@@ -6500,6 +6512,17 @@ def exception_type(
                         model=model,
                         response=original_exception.response,
                     )
+                elif (
+                    "invalid_request_error" in error_str
+                    and "content_policy_violation" in error_str
+                ):
+                    exception_mapping_worked = True
+                    raise ContentPolicyViolationError(
+                        message=f"AzureException - {original_exception.message}",
+                        llm_provider="azure",
+                        model=model,
+                        response=original_exception.response,
+                    )
                 elif "invalid_request_error" in error_str:
                     exception_mapping_worked = True
                     raise BadRequestError(
@@ -7846,133 +7869,6 @@ def read_config_args(config_path) -> dict:
 ########## experimental completion variants ############################


-def completion_with_config(config: Union[dict, str], **kwargs):
-    """
-    Generate a litellm.completion() using a config dict and all supported completion args
-
-    Example config;
-    config = {
-        "default_fallback_models": # [Optional] List of model names to try if a call fails
-        "available_models": # [Optional] List of all possible models you could call
-        "adapt_to_prompt_size": # [Optional] True/False - if you want to select model based on prompt size (will pick from available_models)
-        "model": {
-            "model-name": {
-                "needs_moderation": # [Optional] True/False - if you want to call openai moderations endpoint before making completion call. Will raise exception, if flagged.
-                "error_handling": {
-                    "error-type": { # One of the errors listed here - https://docs.litellm.ai/docs/exception_mapping#custom-mapping-list
-                        "fallback_model": "" # str, name of the model it should try instead, when that error occurs
-                    }
-                }
-            }
-        }
-    }
-
-    Parameters:
-        config (Union[dict, str]): A configuration for litellm
-        **kwargs: Additional keyword arguments for litellm.completion
-
-    Returns:
-        litellm.ModelResponse: A ModelResponse with the generated completion
-    """
-    if config is not None:
-        if isinstance(config, str):
-            config = read_config_args(config)
-        elif isinstance(config, dict):
-            config = config
-        else:
-            raise Exception("Config path must be a string or a dictionary.")
-    else:
-        raise Exception("Config path not passed in.")
-
-    if config is None:
-        raise Exception("No completion config in the config file")
-
-    models_with_config = config["model"].keys()
-    model = kwargs["model"]
-    messages = kwargs["messages"]
-
-    ## completion config
-    fallback_models = config.get("default_fallback_models", None)
-    available_models = config.get("available_models", None)
-    adapt_to_prompt_size = config.get("adapt_to_prompt_size", False)
-    trim_messages_flag = config.get("trim_messages", False)
-    prompt_larger_than_model = False
-    max_model = model
-    try:
-        max_tokens = litellm.get_max_tokens(model)["max_tokens"]
-    except:
-        max_tokens = 2048  # assume curr model's max window is 2048 tokens
-    if adapt_to_prompt_size:
-        ## Pick model based on token window
-        prompt_tokens = litellm.token_counter(
-            model="gpt-3.5-turbo",
-            text="".join(message["content"] for message in messages),
-        )
-        try:
-            curr_max_tokens = litellm.get_max_tokens(model)["max_tokens"]
-        except:
-            curr_max_tokens = 2048
-        if curr_max_tokens < prompt_tokens:
-            prompt_larger_than_model = True
-            for available_model in available_models:
-                try:
-                    curr_max_tokens = litellm.get_max_tokens(available_model)[
-                        "max_tokens"
-                    ]
-                    if curr_max_tokens > max_tokens:
-                        max_tokens = curr_max_tokens
-                        max_model = available_model
-                    if curr_max_tokens > prompt_tokens:
-                        model = available_model
-                        prompt_larger_than_model = False
-                except:
-                    continue
-        if prompt_larger_than_model:
-            messages = trim_messages(messages=messages, model=max_model)
-            kwargs["messages"] = messages
-
-    kwargs["model"] = model
-    try:
-        if model in models_with_config:
-            ## Moderation check
-            if config["model"][model].get("needs_moderation"):
-                input = " ".join(message["content"] for message in messages)
-                response = litellm.moderation(input=input)
-                flagged = response["results"][0]["flagged"]
-                if flagged:
-                    raise Exception("This response was flagged as inappropriate")
-
-            ## Model-specific Error Handling
-            error_handling = None
-            if config["model"][model].get("error_handling"):
-                error_handling = config["model"][model]["error_handling"]
-
-            try:
-                response = litellm.completion(**kwargs)
-                return response
-            except Exception as e:
-                exception_name = type(e).__name__
-                fallback_model = None
-                if error_handling and exception_name in error_handling:
-                    error_handler = error_handling[exception_name]
-                    # either switch model or api key
-                    fallback_model = error_handler.get("fallback_model", None)
-                if fallback_model:
-                    kwargs["model"] = fallback_model
-                    return litellm.completion(**kwargs)
-                raise e
-        else:
-            return litellm.completion(**kwargs)
-    except Exception as e:
-        if fallback_models:
-            model = fallback_models.pop(0)
-            return completion_with_fallbacks(
-                model=model, messages=messages, fallbacks=fallback_models
-            )
-        raise e
-
-
 def completion_with_fallbacks(**kwargs):
     nested_kwargs = kwargs.pop("kwargs", {})
     response = None
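`completion_with_config` is deleted outright here. The utilities it relied on (`litellm.token_counter`, `litellm.get_max_tokens`, `trim_messages`) are untouched by this diff, so the adapt-to-prompt-size behaviour can be approximated directly; a rough sketch under that assumption, reusing the same call shapes the removed helper made (model list illustrative):

```python
import litellm

# Illustrative model list and prompt; mirrors the calls made by the removed helper.
available_models = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k"]
messages = [{"role": "user", "content": "how does a court case get to the Supreme Court?" * 1000}]

prompt_tokens = litellm.token_counter(
    model="gpt-3.5-turbo",
    text="".join(m["content"] for m in messages),
)

model = available_models[0]
for candidate in available_models:
    try:
        # same call shape as the removed helper (get_max_tokens returned a dict here)
        if litellm.get_max_tokens(candidate)["max_tokens"] > prompt_tokens:
            model = candidate
            break
    except Exception:
        continue
else:
    # no candidate fits - trim the conversation down to the chosen model's window
    messages = litellm.utils.trim_messages(messages=messages, model=model)

response = litellm.completion(model=model, messages=messages)
```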