forked from phoenix/litellm-mirror
fix(utils.py): adding support for rules + mythomax/alpaca prompt template
parent 4f46ac4ab5
commit 855964ed45
7 changed files with 186 additions and 8 deletions
@@ -1055,10 +1055,50 @@ def exception_logging(
        pass

####### RULES ###################

class Rules:
    """
    Fail calls based on the input or the LLM API output.

    Example usage:
    import litellm
    def my_custom_rule(input): # receives the model response
        if "i don't think i can answer" in input: # trigger fallback if the model refuses to answer
            return False
        return True

    litellm.post_call_rules = [my_custom_rule] # have these be functions that can be called to fail a call

    response = litellm.completion(model="gpt-3.5-turbo", messages=[{"role": "user",
        "content": "Hey, how's it going?"}], fallbacks=["openrouter/mythomax"])
    """
    def __init__(self) -> None:
        pass

    def pre_call_rules(self, input: str, model: str):
        model, custom_llm_provider, _, _ = get_llm_provider(model=model)
        for rule in litellm.pre_call_rules:
            if callable(rule):
                decision = rule(input)
                if decision is False:
                    raise litellm.APIResponseValidationError("LLM Response failed pre-call-rule check", llm_provider=custom_llm_provider, model=model)
        return True

    def post_call_rules(self, input: str, model: str):
        model, custom_llm_provider, _, _ = get_llm_provider(model=model)
        for rule in litellm.post_call_rules:
            if callable(rule):
                decision = rule(input)
                if decision is False:
                    raise litellm.APIResponseValidationError("LLM Response failed post-call-rule check", llm_provider=custom_llm_provider, model=model)
        return True

####### CLIENT ###################
# make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
def client(original_function):
    global liteDebuggerClient, get_all_keys
    rules_obj = Rules()
    def function_setup(
        start_time, *args, **kwargs
    ): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
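The Rules class above only evaluates whatever callables the user has placed on litellm.pre_call_rules and litellm.post_call_rules; each rule receives a single string and returns True to allow the call or False to fail it. A minimal sketch of wiring that up from user code (the rule bodies and the 10,000-character budget are illustrative, not part of this commit):

```python
import litellm

def limit_prompt_length(input: str) -> bool:
    # pre-call rule: reject prompts over an arbitrary 10,000-character budget
    return len(input) <= 10_000

def reject_refusals(input: str) -> bool:
    # post-call rule: fail the call if the model output looks like a refusal
    return "i don't think i can answer" not in input.lower()

litellm.pre_call_rules = [limit_prompt_length]
litellm.post_call_rules = [reject_refusals]

# a rule returning False surfaces as litellm.APIResponseValidationError
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
```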
@@ -1115,18 +1155,28 @@ def client(original_function):
                    messages = args[1]
                elif kwargs.get("messages", None):
                    messages = kwargs["messages"]
                elif kwargs.get("prompt", None):
                    messages = kwargs["prompt"]
                ### PRE-CALL RULES ###
                rules_obj.pre_call_rules(input="".join(m["content"] for m in messages), model=model)
            elif call_type == CallTypes.embedding.value:
                messages = args[1] if len(args) > 1 else kwargs["input"]
            stream = True if "stream" in kwargs and kwargs["stream"] == True else False
            logging_obj = Logging(model=model, messages=messages, stream=stream, litellm_call_id=kwargs["litellm_call_id"], function_id=function_id, call_type=call_type, start_time=start_time)
            return logging_obj
        except Exception as e: # DO NOT BLOCK running the function because of this
        except Exception as e:
            import logging
            logging.debug(f"[Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}")
            raise e

    def post_call_processing(original_response, model):
        try:
            call_type = original_function.__name__
            if call_type == CallTypes.completion.value or call_type == CallTypes.acompletion.value:
                model_response = original_response['choices'][0]['message']['content']
                ### POST-CALL RULES ###
                rules_obj.post_call_rules(input=model_response, model=model)
        except Exception as e:
            raise e

    def crash_reporting(*args, **kwargs):
        if litellm.telemetry:
            try:
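Two details from the hunk above are easy to miss: pre-call rules receive the concatenation of every message's content field (one flat string, no role markers), while post-call rules receive only the first choice's message content. A small illustration under those assumptions (the sample values are made up):

```python
# What the rule callables actually see (illustrative values only).
messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Summarize this document."},
]
pre_call_input = "".join(m["content"] for m in messages)
# -> "You are terse.Summarize this document."

response = {"choices": [{"message": {"content": "Sure - here is the summary..."}}]}
post_call_input = response["choices"][0]["message"]["content"]

def no_empty_answers(output: str) -> bool:
    # post-call rule: returning False makes the wrapper raise
    return len(output.strip()) > 0
```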
@@ -1203,6 +1253,9 @@ def client(original_function):
                return result
            elif "acompletion" in kwargs and kwargs["acompletion"] == True:
                return result

            ### POST-CALL RULES ###
            post_call_processing(original_response=result, model=model)

            # [OPTIONAL] ADD TO CACHE
            if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
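Since post_call_processing re-raises, a response rejected by a post-call rule should reach the caller as litellm.APIResponseValidationError (assuming no fallback logic absorbs it first). A hedged sketch of handling that on the caller side:

```python
import litellm

litellm.post_call_rules = [lambda output: "i can't help with that" not in output.lower()]

try:
    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )
except litellm.APIResponseValidationError as err:
    # the rule returned False; retry, switch models, or surface the failure
    print(f"response rejected by post-call rule: {err}")
```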
@@ -1308,6 +1361,10 @@ def client(original_function):
                    return litellm.stream_chunk_builder(chunks)
                else:
                    return result

            ### POST-CALL RULES ###
            post_call_processing(original_response=result, model=model)

            # [OPTIONAL] ADD TO CACHE
            if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
                litellm.cache.add_cache(result, *args, **kwargs)
@@ -3272,7 +3329,7 @@ def register_prompt_template(model: str, roles: dict, initial_prompt_value: str
    }
    return litellm.custom_prompt_dict

####### [BETA] HOSTED PRODUCT ################ - https://docs.litellm.ai/docs/debugging/hosted_debugging
####### DEPRECATED ################

def get_all_keys(llm_provider=None):
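The commit title also covers a mythomax/alpaca prompt template; register_prompt_template, touched in the hunk above, is the public hook for that kind of mapping. A sketch of registering an Alpaca-style template, where the model name is borrowed from the Rules docstring example and the pre/post strings are assumptions rather than what litellm ships for mythomax:

```python
import litellm

# Alpaca-style instruction layout; the exact strings are an assumption.
litellm.register_prompt_template(
    model="openrouter/mythomax",  # model name taken from the docstring example above
    initial_prompt_value="Below is an instruction that describes a task.\n\n",
    roles={
        "system": {"pre_message": "", "post_message": "\n\n"},
        "user": {"pre_message": "### Instruction:\n", "post_message": "\n\n"},
        "assistant": {"pre_message": "### Response:\n", "post_message": "\n\n"},
    },
    final_prompt_value="### Response:\n",
)

response = litellm.completion(
    model="openrouter/mythomax",
    messages=[{"role": "user", "content": "Write a haiku about rules."}],
)
```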