diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc
index 5286ad6f3..f60178967 100644
Binary files a/litellm/__pycache__/utils.cpython-311.pyc and b/litellm/__pycache__/utils.cpython-311.pyc differ
diff --git a/litellm/tests/test_config.py b/litellm/tests/test_config.py
index 6a2cbc238..50de81ca9 100644
--- a/litellm/tests/test_config.py
+++ b/litellm/tests/test_config.py
@@ -59,4 +59,35 @@ def test_config_context_default_fallback():
         print(f"Exception: {e}")
         pytest.fail(f"An exception occurred: {e}")
 
-test_config_context_default_fallback()
\ No newline at end of file
+# test_config_context_default_fallback()
+
+
+config = {
+    "function": "completion",
+    "default_fallback_models": ["gpt-3.5-turbo", "claude-instant-1", "j2-ultra"],
+    "available_models": ["gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-4", "gpt-4-0314", "gpt-4-0613",
+                         "j2-ultra", "command-nightly", "togethercomputer/llama-2-70b-chat", "chat-bison", "chat-bison@001", "claude-2"],
+    "adapt_to_prompt_size": True,
+    "model": {
+        "claude-instant-1": {
+            "needs_moderation": True
+        },
+        "gpt-3.5-turbo": {
+            "error_handling": {
+                "ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"}
+            }
+        }
+    }
+}
+
+def test_config_context_adapt_to_prompt():
+    try:
+        sample_text = "how does a court case get to the Supreme Court?" * 1000
+        messages = [{"content": sample_text, "role": "user"}]
+        response = completion_with_config(model="gpt-3.5-turbo", messages=messages, config=config)
+        print(response)
+    except Exception as e:
+        print(f"Exception: {e}")
+        pytest.fail(f"An exception occurred: {e}")
+
+test_config_context_adapt_to_prompt()
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index 51e0ebdb0..0bc3f6a42 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2772,7 +2772,7 @@ def read_config_args(config_path) -> dict:
 
 ########## experimental completion variants ############################
 
-def completion_with_config(*args, config: Union[dict, str], **kwargs):
+def completion_with_config(*, config: Union[dict, str], **kwargs):
     if config is not None:
         if isinstance(config, str):
             config = read_config_args(config)
@@ -2793,11 +2793,31 @@ def completion_with_config(*args, config: Union[dict, str], **kwargs):
         raise Exception("No completion config in the config file")
 
     models_with_config = completion_config["model"].keys()
-    model = args[0] if len(args) > 0 else kwargs["model"]
-    messages = args[1] if len(args) > 1 else kwargs["messages"]
+    model = kwargs["model"]
+    messages = kwargs["messages"]
 
-    ## Default fallback models
-    fallback_models = completion_config.get("default_fallback_models")
+    ## completion config
+    fallback_models = completion_config.get("default_fallback_models", None)
+    available_models = completion_config.get("available_models", None)
+    adapt_to_prompt_size = completion_config.get("adapt_to_prompt_size", False)
+    start_time = time.time()
+    if adapt_to_prompt_size:
+        ## Pick model based on token window
+        prompt_tokens = litellm.token_counter(model="gpt-3.5-turbo", text="".join(message["content"] for message in messages))
+        try:
+            curr_max_tokens = litellm.get_max_tokens(model)["max_tokens"]
+        except:
+            curr_max_tokens = 2048
+        if curr_max_tokens < prompt_tokens:
+            for available_model in available_models:
+                try:
+                    curr_max_tokens = litellm.get_max_tokens(available_model)["max_tokens"]
+                    if curr_max_tokens > prompt_tokens:
+                        model = available_model
+                except:
+                    continue
+        end_time = time.time()
+    kwargs["model"] = model
     try:
         if model in models_with_config:
             ## Moderation check
@@ -2814,7 +2834,7 @@ def completion_with_config(*args, config: Union[dict, str], **kwargs):
                 error_handling = completion_config["model"][model]["error_handling"]
 
             try:
-                response = litellm.completion(*args, **kwargs)
+                response = litellm.completion(**kwargs)
                 return response
             except Exception as e:
                 exception_name = type(e).__name__
@@ -2825,10 +2845,10 @@ def completion_with_config(*args, config: Union[dict, str], **kwargs):
                     fallback_model = error_handler.get("fallback_model", None)
                     if fallback_model:
                         kwargs["model"] = fallback_model
-                        return litellm.completion(*args, **kwargs)
+                        return litellm.completion(**kwargs)
                 raise e
         else:
-            return litellm.completion(*args, **kwargs)
+            return litellm.completion(**kwargs)
     except Exception as e:
         if fallback_models:
             model = fallback_models.pop(0)
@@ -2933,7 +2953,7 @@ def completion_with_fallbacks(**kwargs):
                 # delete model from kwargs if it exists
                 if kwargs.get("model"):
                     del kwargs["model"]
-                    
+
                 response = litellm.completion(**kwargs, model=model)
 
                 if response != None:
diff --git a/pyproject.toml b/pyproject.toml
index 56d588af0..62a4cf1a8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.726"
+version = "0.1.727"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
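
Reviewer note: a minimal usage sketch of what this diff enables, assuming the config schema exercised by the new test (model names here are illustrative). `completion_with_config` is now keyword-only, so a positional call like `completion_with_config("gpt-3.5-turbo", messages, config=config)` raises a `TypeError`; `model` and `messages` must be passed by name. With `adapt_to_prompt_size` set, the function counts prompt tokens and, when the requested model's window is too small, scans `available_models` for one whose `max_tokens` exceeds the prompt (note the inner loop never breaks, so the last fitting model checked is the one used).

```python
import litellm
from litellm import completion_with_config

# Config mirroring the new test; model names are illustrative.
config = {
    "function": "completion",
    "adapt_to_prompt_size": True,
    # Scanned when the requested model's context window is too small.
    "available_models": ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "claude-2"],
    "model": {
        "gpt-3.5-turbo": {
            "error_handling": {
                "ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"}
            }
        }
    },
}

# A prompt far larger than gpt-3.5-turbo's default window.
messages = [{"role": "user", "content": "how does a court case get to the Supreme Court?" * 1000}]

# model/messages are keyword-only after this change.
response = completion_with_config(model="gpt-3.5-turbo", messages=messages, config=config)
print(response)
```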
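
For completeness, a hedged sketch (not the library code) of the per-model `error_handling` path this diff rewires to forward only `**kwargs`: when a model listed under `config["model"]` raises an exception whose class name appears in its `error_handling` map, the call is retried once on the configured `fallback_model`. The helper name below is hypothetical.

```python
import litellm

def complete_with_error_handling(model, messages, model_config):
    # Sketch of completion_with_config's per-model error handling:
    # match the raised exception's class name against the config,
    # then retry once on the configured fallback model.
    error_handling = model_config.get("error_handling", {})
    try:
        return litellm.completion(model=model, messages=messages)
    except Exception as e:
        handler = error_handling.get(type(e).__name__)  # e.g. "ContextWindowExceededError"
        if handler and handler.get("fallback_model"):
            return litellm.completion(model=handler["fallback_model"], messages=messages)
        raise
```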