diff --git a/litellm/__init__.py b/litellm/__init__.py
index 68de303b6..2d5618bcd 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -281,7 +281,8 @@ from .utils import (
     register_prompt_template,
     validate_environment,
     check_valid_key,
-    get_llm_provider
+    get_llm_provider,
+    completion_with_config
 )
 from .main import * # type: ignore
 from .integrations import *
diff --git a/litellm/__pycache__/__init__.cpython-311.pyc b/litellm/__pycache__/__init__.cpython-311.pyc
index 77b9f46da..0bee58842 100644
Binary files a/litellm/__pycache__/__init__.cpython-311.pyc and b/litellm/__pycache__/__init__.cpython-311.pyc differ
diff --git a/litellm/__pycache__/main.cpython-311.pyc b/litellm/__pycache__/main.cpython-311.pyc
index 9a91ccf3b..9ad726ac5 100644
Binary files a/litellm/__pycache__/main.cpython-311.pyc and b/litellm/__pycache__/main.cpython-311.pyc differ
diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc
index de1a3929e..7c768446d 100644
Binary files a/litellm/__pycache__/utils.cpython-311.pyc and b/litellm/__pycache__/utils.cpython-311.pyc differ
diff --git a/litellm/main.py b/litellm/main.py
index 538c60b8d..150fa31f1 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1321,11 +1321,8 @@ def text_completion(*args, **kwargs):
     return completion(*args, **kwargs)
 
 ##### Moderation #######################
-def moderation(*args, **kwargs):
+def moderation(input: str, api_key: Optional[str]=None):
     # only supports open ai for now
-    api_key = None
-    if "api_key" in kwargs:
-        api_key = kwargs["api_key"]
     api_key = (
         api_key or
         litellm.api_key or
@@ -1336,7 +1333,7 @@ def moderation(*args, **kwargs):
     openai.api_type = "open_ai"
     openai.api_version = None
     openai.api_base = "https://api.openai.com/v1"
-    response = openai.Moderation.create(*args, **kwargs)
+    response = openai.Moderation.create(input)
     return response
 
 ####### HELPER FUNCTIONS ################
diff --git a/litellm/tests/test_config.py b/litellm/tests/test_config.py
new file mode 100644
index 000000000..0df5928f9
--- /dev/null
+++ b/litellm/tests/test_config.py
@@ -0,0 +1,42 @@
+import sys, os
+import traceback
+from dotenv import load_dotenv
+
+load_dotenv()
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+) # Adds the parent directory to the system path
+import pytest
+import litellm
+from litellm import completion_with_config
+
+config = {
+    "function": "completion",
+    "model": {
+        "claude-instant-1": {
+            "needs_moderation": True
+        },
+        "gpt-3.5-turbo": {
+            "error_handling": {
+                "ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"}
+            }
+        }
+    }
+}
+
+def test_config():
+    try:
+        sample_text = "how does a court case get to the Supreme Court?" * 1000
+        messages = [{"content": sample_text, "role": "user"}]
+        response = completion_with_config(model="gpt-3.5-turbo", messages=messages, config=config)
+        print(response)
+        messages=[{"role": "user", "content": "I want to kill them."}]
+        response = completion_with_config(model="claude-instant-1", messages=messages, config=config)
+        print(response)
+    except Exception as e:
+        print(f"Exception: {e}")
+        pytest.fail(f"An exception occurred: {e}")
+
+# test_config()
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index 7e168d1ba..bb88101c1 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2772,6 +2772,62 @@ def read_config_args(config_path):
 
 ########## experimental completion variants ############################
 
+def completion_with_config(*args, config: Union[dict, str], **kwargs):
+    if config is not None:
+        if isinstance(config, str):
+            config = read_config_args(config)
+        elif isinstance(config, dict):
+            config = config
+        else:
+            raise Exception("Config path must be a string or a dictionary.")
+    else:
+        raise Exception("Config path not passed in.")
+
+    ## load the completion config
+    completion_config = None
+
+    if config["function"] == "completion":
+        completion_config = config
+
+    if completion_config is None:
+        raise Exception("No completion config in the config file")
+
+    models_with_config = completion_config["model"].keys()
+    model = args[0] if len(args) > 0 else kwargs["model"]
+    messages = args[1] if len(args) > 1 else kwargs["messages"]
+    if model in models_with_config:
+        ## Moderation check
+        if completion_config["model"][model].get("needs_moderation"):
+            input = " ".join(message["content"] for message in messages)
+            response = litellm.moderation(input=input)
+            flagged = response["results"][0]["flagged"]
+            if flagged:
+                raise Exception("This response was flagged as inappropriate")
+
+        ## Load Error Handling Logic
+        error_handling = None
+        if completion_config["model"][model].get("error_handling"):
+            error_handling = completion_config["model"][model]["error_handling"]
+
+        try:
+            response = litellm.completion(*args, **kwargs)
+            return response
+        except Exception as e:
+            exception_name = type(e).__name__
+            fallback_model = None
+            if exception_name in error_handling:
+                error_handler = error_handling[exception_name]
+                # either switch model or api key
+                fallback_model = error_handler.get("fallback_model", None)
+            if fallback_model:
+                kwargs["model"] = fallback_model
+                return litellm.completion(*args, **kwargs)
+            raise e
+    else:
+        return litellm.completion(*args, **kwargs)
+
+
+
 def get_model_split_test(models, completion_call_id):
     global last_fetched_at
     try:
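
Usage sketch (illustrative, not part of the patch): completion_with_config takes the same arguments as litellm.completion plus a config dict or a path to a config file. The config below mirrors the one in litellm/tests/test_config.py; the call pattern is a minimal sketch and assumes the relevant provider API keys are set in the environment.

from litellm import completion_with_config

# Same config shape exercised by litellm/tests/test_config.py:
# claude-instant-1 requests are run through moderation first, and
# gpt-3.5-turbo falls back to gpt-3.5-turbo-16k on ContextWindowExceededError.
config = {
    "function": "completion",
    "model": {
        "claude-instant-1": {"needs_moderation": True},
        "gpt-3.5-turbo": {
            "error_handling": {
                "ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"}
            }
        },
    },
}

messages = [{"role": "user", "content": "how does a court case get to the Supreme Court?"}]
response = completion_with_config(model="gpt-3.5-turbo", messages=messages, config=config)
print(response)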