Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 02:34:29 +00:00)
add completion configs

Commit 2f44191642 (parent 371e0428d3)
7 changed files with 102 additions and 6 deletions
litellm/__init__.py

@@ -281,7 +281,8 @@ from .utils import (
     register_prompt_template,
     validate_environment,
     check_valid_key,
-    get_llm_provider
+    get_llm_provider,
+    completion_with_config
 )
 from .main import *  # type: ignore
 from .integrations import *
3 binary files not shown.
litellm/main.py

@@ -1321,11 +1321,8 @@ def text_completion(*args, **kwargs):
     return completion(*args, **kwargs)

 ##### Moderation #######################
-def moderation(*args, **kwargs):
+def moderation(input: str, api_key: Optional[str]=None):
     # only supports open ai for now
-    api_key = None
-    if "api_key" in kwargs:
-        api_key = kwargs["api_key"]
     api_key = (
         api_key or
         litellm.api_key or
@@ -1336,7 +1333,7 @@ def moderation(*args, **kwargs):
     openai.api_type = "open_ai"
     openai.api_version = None
     openai.api_base = "https://api.openai.com/v1"
-    response = openai.Moderation.create(*args, **kwargs)
+    response = openai.Moderation.create(input)
     return response

 ####### HELPER FUNCTIONS ################
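With the new signature, callers pass the text directly instead of going through *args/**kwargs. A minimal usage sketch (assumes an OpenAI key is available via litellm.api_key or the OPENAI_API_KEY environment variable, since the function only supports OpenAI for now; the sample string comes from the test added in this commit):

    import litellm

    # moderation() forwards the text to openai.Moderation.create(input)
    response = litellm.moderation(input="I want to kill them.")
    print(response["results"][0]["flagged"])  # True if OpenAI's endpoint flags the text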
litellm/tests/test_config.py (new file, 42 lines)
@@ -0,0 +1,42 @@
+import sys, os
+import traceback
+from dotenv import load_dotenv
+
+load_dotenv()
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import pytest
+import litellm
+from litellm import completion_with_config
+
+config = {
+    "function": "completion",
+    "model": {
+        "claude-instant-1": {
+            "needs_moderation": True
+        },
+        "gpt-3.5-turbo": {
+            "error_handling": {
+                "ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"}
+            }
+        }
+    }
+}
+
+def test_config():
+    try:
+        sample_text = "how does a court case get to the Supreme Court?" * 1000
+        messages = [{"content": sample_text, "role": "user"}]
+        response = completion_with_config(model="gpt-3.5-turbo", messages=messages, config=config)
+        print(response)
+        messages = [{"role": "user", "content": "I want to kill them."}]
+        response = completion_with_config(model="claude-instant-1", messages=messages, config=config)
+        print(response)
+    except Exception as e:
+        print(f"Exception: {e}")
+        pytest.fail(f"An exception occurred: {e}")
+
+# test_config()
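The test exercises both config branches: the oversized gpt-3.5-turbo prompt is meant to trigger the ContextWindowExceededError fallback to gpt-3.5-turbo-16k, while the claude-instant-1 call is routed through the needs_moderation gate before any completion is attempted.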
litellm/utils.py

@@ -2772,6 +2772,62 @@ def read_config_args(config_path):

 ########## experimental completion variants ############################

+def completion_with_config(*args, config: Union[dict, str], **kwargs):
+    if config is not None:
+        if isinstance(config, str):
+            config = read_config_args(config)
+        elif isinstance(config, dict):
+            config = config
+        else:
+            raise Exception("Config must be a string (a path to a config file) or a dictionary.")
+    else:
+        raise Exception("Config path not passed in.")
+
+    ## load the completion config
+    completion_config = None
+
+    if config["function"] == "completion":
+        completion_config = config
+
+    if completion_config is None:
+        raise Exception("No completion config in the config file")
+
+    models_with_config = completion_config["model"].keys()
+    model = args[0] if len(args) > 0 else kwargs["model"]
+    messages = args[1] if len(args) > 1 else kwargs["messages"]
+    if model in models_with_config:
+        ## Moderation check
+        if completion_config["model"][model].get("needs_moderation"):
+            input = " ".join(message["content"] for message in messages)
+            response = litellm.moderation(input=input)
+            flagged = response["results"][0]["flagged"]
+            if flagged:
+                raise Exception("This response was flagged as inappropriate")
+
+        ## Load Error Handling Logic
+        error_handling = None
+        if completion_config["model"][model].get("error_handling"):
+            error_handling = completion_config["model"][model]["error_handling"]
+
+        try:
+            response = litellm.completion(*args, **kwargs)
+            return response
+        except Exception as e:
+            exception_name = type(e).__name__
+            fallback_model = None
+            if error_handling and exception_name in error_handling:
+                error_handler = error_handling[exception_name]
+                # either switch model or api key
+                fallback_model = error_handler.get("fallback_model", None)
+            if fallback_model:
+                kwargs["model"] = fallback_model
+                return litellm.completion(*args, **kwargs)
+            raise e
+    else:
+        return litellm.completion(*args, **kwargs)
+
+
 def get_model_split_test(models, completion_call_id):
     global last_fetched_at
     try:
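For reference, a minimal call sketch for the new helper (mirrors the test config above; assumes OPENAI_API_KEY is set, and note that a string path is also accepted for config, in which case it is loaded via read_config_args):

    import litellm
    from litellm import completion_with_config

    config = {
        "function": "completion",
        "model": {
            "gpt-3.5-turbo": {
                "error_handling": {
                    "ContextWindowExceededError": {"fallback_model": "gpt-3.5-turbo-16k"}
                }
            }
        },
    }

    # If the gpt-3.5-turbo call raises ContextWindowExceededError, it is retried
    # once with gpt-3.5-turbo-16k; models not listed in the config fall through
    # to a plain litellm.completion call.
    response = completion_with_config(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "how does a court case get to the Supreme Court?"}],
        config=config,
    )
    print(response)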