fix(utils.py): adding support for rules + mythomax/alpaca prompt template

Krrish Dholakia 2023-11-20 18:57:58 -08:00
parent 4f46ac4ab5
commit 855964ed45
7 changed files with 186 additions and 8 deletions


@@ -1055,10 +1055,50 @@ def exception_logging(
        pass

####### RULES ###################

class Rules:
    """
    Fail calls based on the input or llm api output

    Example usage:
    import litellm
    def my_custom_rule(input): # receives the model response
        if "i don't think i can answer" in input: # trigger fallback if the model refuses to answer
            return False
        return True

    litellm.post_call_rules = [my_custom_rule] # have these be functions that can be called to fail a call

    response = litellm.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}], fallbacks=["openrouter/mythomax"])
    """
    def __init__(self) -> None:
        pass

    def pre_call_rules(self, input: str, model: str):
        model, custom_llm_provider, _, _ = get_llm_provider(model=model)
        for rule in litellm.pre_call_rules:
            if callable(rule):
                decision = rule(input)
                if decision is False:
                    raise litellm.APIResponseValidationError("LLM Request failed pre-call-rule check", llm_provider=custom_llm_provider, model=model)
        return True
    def post_call_rules(self, input: str, model: str):
        model, custom_llm_provider, _, _ = get_llm_provider(model=model)
        for rule in litellm.post_call_rules:
            if callable(rule):
                decision = rule(input)
                if decision is False:
                    raise litellm.APIResponseValidationError("LLM Response failed post-call-rule check", llm_provider=custom_llm_provider, model=model)
        return True

####### CLIENT ###################
# make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
def client(original_function):
    global liteDebuggerClient, get_all_keys
    rules_obj = Rules()

    def function_setup(
        start_time, *args, **kwargs
    ):  # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
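With the Rules helper above, a pre-call rule is just a callable that receives the concatenated message content and returns False to reject the request before it is sent. A minimal sketch, assuming a hypothetical block_long_prompts rule with an illustrative length cutoff:

import litellm

def block_long_prompts(input):  # hypothetical pre-call rule; receives the joined message content
    if len(input) > 10000:  # illustrative cutoff, not part of this commit
        return False  # a False return makes Rules.pre_call_rules raise APIResponseValidationError
    return True

litellm.pre_call_rules = [block_long_prompts]
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)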
@@ -1115,18 +1155,28 @@ def client(original_function):
                    messages = args[1]
                elif kwargs.get("messages", None):
                    messages = kwargs["messages"]
                elif kwargs.get("prompt", None):
                    messages = kwargs["prompt"]
                ### PRE-CALL RULES ###
                rules_obj.pre_call_rules(input="".join(m["content"] for m in messages), model=model)
            elif call_type == CallTypes.embedding.value:
                messages = args[1] if len(args) > 1 else kwargs["input"]
            stream = True if "stream" in kwargs and kwargs["stream"] == True else False
            logging_obj = Logging(model=model, messages=messages, stream=stream, litellm_call_id=kwargs["litellm_call_id"], function_id=function_id, call_type=call_type, start_time=start_time)
            return logging_obj
        except Exception as e:  # DO NOT BLOCK running the function because of this
        except Exception as e:
            import logging
            logging.debug(f"[Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}")
            raise e

    def post_call_processing(original_response, model):
        try:
            call_type = original_function.__name__
            if call_type == CallTypes.completion.value or call_type == CallTypes.acompletion.value:
                model_response = original_response['choices'][0]['message']['content']
                ### POST-CALL RULES ###
                rules_obj.post_call_rules(input=model_response, model=model)
        except Exception as e:
            raise e

    def crash_reporting(*args, **kwargs):
        if litellm.telemetry:
            try:
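post_call_processing hands post-call rules the response text from choices[0]['message']['content'], so a rule can fail the call after the model has answered. A rough sketch in the spirit of the class docstring (the refusal check is illustrative); the docstring pairs this with fallbacks so a rejected response can be retried on another model:

import litellm

def flag_refusals(output):  # hypothetical post-call rule; receives the response message content
    if "i don't think i can answer" in output.lower():
        return False  # Rules.post_call_rules raises APIResponseValidationError on False
    return True

litellm.post_call_rules = [flag_refusals]
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    fallbacks=["openrouter/mythomax"],
)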
@@ -1203,6 +1253,9 @@ def client(original_function):
                return result
            elif "acompletion" in kwargs and kwargs["acompletion"] == True:
                return result

            ### POST-CALL RULES ###
            post_call_processing(original_response=result, model=model)

            # [OPTIONAL] ADD TO CACHE
            if litellm.caching or litellm.caching_with_models or litellm.cache != None:  # user init a cache object
@@ -1308,6 +1361,10 @@ def client(original_function):
                    return litellm.stream_chunk_builder(chunks)
                else:
                    return result

            ### POST-CALL RULES ###
            post_call_processing(original_response=result, model=model)

            # [OPTIONAL] ADD TO CACHE
            if litellm.caching or litellm.caching_with_models or litellm.cache != None:  # user init a cache object
                litellm.cache.add_cache(result, *args, **kwargs)
@@ -3272,7 +3329,7 @@ def register_prompt_template(model: str, roles: dict, initial_prompt_value: str
    }
    return litellm.custom_prompt_dict

####### [BETA] HOSTED PRODUCT ################ - https://docs.litellm.ai/docs/debugging/hosted_debugging
####### DEPRECATED ################

def get_all_keys(llm_provider=None):
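The mythomax/alpaca prompt template named in the commit title lives in the prompt-template factory rather than in this file; the hunk above only touches the comment following register_prompt_template. For reference, a custom Alpaca-style template could be registered roughly like this; the model name and the instruction/response markers are assumptions for illustration, not copied from this commit:

import litellm

# Hypothetical Alpaca-style template; exact markers are assumed, not taken from this diff
litellm.register_prompt_template(
    model="openrouter/mythomax",
    initial_prompt_value="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n",
    roles={
        "system": {"pre_message": "### Instruction:\n", "post_message": "\n\n"},
        "user": {"pre_message": "### Instruction:\n", "post_message": "\n\n"},
        "assistant": {"pre_message": "### Response:\n", "post_message": "\n\n"},
    },
    final_prompt_value="### Response:\n",
)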