fix(utils.py): adding support for rules + mythomax/alpaca prompt template

Krrish Dholakia 2023-11-20 18:57:58 -08:00
parent 4f46ac4ab5
commit 855964ed45
7 changed files with 186 additions and 8 deletions
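Summary: callers can now register plain Python functions as pre-call and post-call rules; if any rule returns False, litellm fails the call with the newly exported APIResponseValidationError. A minimal usage sketch, assuming the pattern shown in the Rules docstring and the new test file below (the rule logic, limits, and model name are illustrative):

import litellm
from litellm import completion

def block_long_prompts(input: str):   # pre-call rule: receives the concatenated message contents
    return len(input) <= 1000         # illustrative limit

def block_refusals(input: str):       # post-call rule: receives the model's response text
    return "sorry" not in input.lower()

litellm.pre_call_rules = [block_long_prompts]
litellm.post_call_rules = [block_refusals]

try:
    response = completion(model="gpt-3.5-turbo",
                          messages=[{"role": "user", "content": "say sorry"}])
except litellm.APIResponseValidationError as e:
    print(f"call rejected by a rule: {e.message}")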

View file

@@ -8,6 +8,8 @@ input_callback: List[Union[str, Callable]] = []
success_callback: List[Union[str, Callable]] = []
failure_callback: List[Union[str, Callable]] = []
callbacks: List[Callable] = []
pre_call_rules: List[Callable] = []
post_call_rules: List[Callable] = []
set_verbose = False
email: Optional[
    str
@@ -386,7 +388,8 @@ from .exceptions import (
    BudgetExceededError,
    APIError,
    Timeout,
    APIConnectionError,
    APIResponseValidationError
)
from .budget_manager import BudgetManager
from .proxy.proxy_cli import run_server

View file

@@ -18,6 +18,7 @@ from openai import (
    APIError,
    APITimeoutError,
    APIConnectionError,
    APIResponseValidationError
)
import httpx
@@ -119,6 +120,22 @@ class APIConnectionError(APIConnectionError): # type: ignore
            request=request
        )

# raised if an invalid request (not get, delete, put, post) is made
class APIResponseValidationError(APIResponseValidationError): # type: ignore
    def __init__(self, message, llm_provider, model):
        self.message = message
        self.llm_provider = llm_provider
        self.model = model
        request = httpx.Request(method="POST", url="https://api.openai.com/v1")
        response = httpx.Response(status_code=500, request=request)
        super().__init__(
            response=response,
            body=None,
            message=message
        )

class OpenAIError(OpenAIError): # type: ignore
    def __init__(self, original_exception):
        self.status_code = original_exception.http_status
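For reference, a quick sketch of how the new exception is constructed; the Rules helpers in utils.py (below) raise it this way when a rule returns False (the argument values here are illustrative):

import litellm

err = litellm.APIResponseValidationError(
    "LLM Response failed post-call-rule check",  # message is positional
    llm_provider="openai",
    model="gpt-3.5-turbo",
)
print(err.message, err.llm_provider, err.model)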

View file

@@ -6,6 +6,7 @@ from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper,
from typing import Callable, Optional
import aiohttp, requests
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from openai import OpenAI, AsyncOpenAI

class OpenAIError(Exception):
@@ -172,7 +173,8 @@ class OpenAIChatCompletion(BaseLLM):
                   optional_params=None,
                   litellm_params=None,
                   logger_fn=None,
                   headers: Optional[dict]=None,
                   custom_prompt_dict: dict={}):
        super().completion()
        exception_mapping_worked = False
        try:
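Threading custom_prompt_dict into OpenAIChatCompletion.completion presumably lets templates registered via litellm.register_prompt_template (its signature appears in the utils.py hunk below) apply to OpenAI-compatible endpoints. A hedged sketch of such a registration, reusing the pre_message/post_message role shape from alpaca_pt; the model name is illustrative:

import litellm

# register an Alpaca-style template for a hypothetical OpenAI-compatible model;
# register_prompt_template returns the updated litellm.custom_prompt_dict (see the utils.py hunk below)
litellm.register_prompt_template(
    model="my-provider/mythomax-l2-13b",  # illustrative
    roles={
        "system": {"pre_message": "### Instruction:\n", "post_message": "\n\n"},
        "user": {"pre_message": "### Instruction:\n", "post_message": "\n\n"},
        "assistant": {"pre_message": "### Response:\n", "post_message": "\n\n"},
    },
    initial_prompt_value="",
)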

View file

@@ -7,6 +7,29 @@ from typing import Optional

def default_pt(messages):
    return " ".join(message["content"] for message in messages)

# alpaca prompt template - for models like mythomax, etc.
def alpaca_pt(messages):
    prompt = custom_prompt(
        role_dict={
            "system": {
                "pre_message": "### Instruction:\n",
                "post_message": "\n\n",
            },
            "user": {
                "pre_message": "### Instruction:\n",
                "post_message": "\n\n",
            },
            "assistant": {
                "pre_message": "### Response:\n",
                "post_message": "\n\n"
            }
        },
        bos_token="<s>",
        eos_token="</s>",
        messages=messages
    )
    return prompt

# Llama2 prompt template
def llama_2_chat_pt(messages):
    prompt = custom_prompt(
@@ -276,7 +299,6 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="",
def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None):
    original_model_name = model
    model = model.lower()
    if custom_llm_provider == "ollama":
        return ollama_pt(model=model, messages=messages)
    elif custom_llm_provider == "anthropic":
@@ -302,6 +324,8 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str
            return phind_codellama_pt(messages=messages)
        elif "togethercomputer/llama-2" in model and ("instruct" in model or "chat" in model):
            return llama_2_chat_pt(messages=messages)
        elif model in ["gryphe/mythomax-l2-13b", "gryphe/mythomix-l2-13b", "gryphe/mythologic-l2-13b"]:
            return alpaca_pt(messages=messages)
        else:
            return hf_chat_template(original_model_name, messages)
    except:
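For concreteness, the Alpaca layout that alpaca_pt targets wraps each turn in "### Instruction:" / "### Response:" blocks. A rough, standalone sketch of the string it builds for a one-message conversation (this mirrors the role_dict above but simplifies custom_prompt's bos/eos bookkeeping):

ALPACA_ROLES = {
    "system": {"pre_message": "### Instruction:\n", "post_message": "\n\n"},
    "user": {"pre_message": "### Instruction:\n", "post_message": "\n\n"},
    "assistant": {"pre_message": "### Response:\n", "post_message": "\n\n"},
}

def approx_alpaca_prompt(messages, bos_token="<s>"):
    # concatenate each message's content between its role's pre/post markers
    prompt = bos_token
    for m in messages:
        fmt = ALPACA_ROLES[m["role"]]
        prompt += fmt["pre_message"] + m["content"] + fmt["post_message"]
    return prompt

print(approx_alpaca_prompt([{"role": "user", "content": "Write a haiku about the sea."}]))
# <s>### Instruction:
# Write a haiku about the sea.
# (followed by a blank line)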

View file

@@ -556,7 +556,8 @@ def completion(
                optional_params=optional_params,
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                timeout=timeout,
                custom_prompt_dict=custom_prompt_dict
            )
        except Exception as e:
            ## LOGGING - log the original exception returned

View file

@@ -0,0 +1,74 @@
#### What this tests ####
# This tests setting rules before / after making llm api calls

import sys, os, time
import traceback, asyncio
import pytest
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm
from litellm import completion, acompletion

def my_pre_call_rule(input: str):
    print(f"input: {input}")
    print(f"INSIDE MY PRE CALL RULE, len(input) - {len(input)}")
    if len(input) > 10:
        return False
    return True

def my_post_call_rule(input: str):
    input = input.lower()
    print(f"input: {input}")
    print(f"INSIDE MY POST CALL RULE, len(input) - {len(input)}")
    if "sorry" in input:
        return False
    return True

## Test 1: Pre-call rule
def test_pre_call_rule():
    try:
        litellm.pre_call_rules = [my_pre_call_rule]
        ### completion
        response = completion(model="gpt-3.5-turbo",
                              messages=[{"role": "user", "content": "say something inappropriate"}])
        pytest.fail(f"Completion call should have failed.")
    except:
        pass

    ### async completion
    async def test_async_response():
        user_message = "Hello, how are you?"
        messages = [{"content": user_message, "role": "user"}]
        try:
            response = await acompletion(model="gpt-3.5-turbo", messages=messages)
            pytest.fail(f"acompletion call should have failed.")
        except Exception as e:
            pass
    asyncio.run(test_async_response())

# test_pre_call_rule()

## Test 2: Post-call rule
def test_post_call_rule():
    try:
        litellm.pre_call_rules = []
        litellm.post_call_rules = [my_post_call_rule]
        ### completion
        response = completion(model="gpt-3.5-turbo",
                              messages=[{"role": "user", "content": "say sorry"}],
                              fallbacks=["deepinfra/Gryphe/MythoMax-L2-13b"])
        pytest.fail(f"Completion call should have failed.")
    except:
        pass
    print(f"MAKING ACOMPLETION CALL")
    # litellm.set_verbose = True

    ### async completion
    async def test_async_response():
        messages = [{"role": "user", "content": "say sorry"}]
        try:
            response = await acompletion(model="gpt-3.5-turbo", messages=messages)
            pytest.fail(f"acompletion call should have failed.")
        except Exception as e:
            pass
    asyncio.run(test_async_response())

# test_post_call_rule()

View file

@@ -1055,10 +1055,50 @@ def exception_logging(
        pass

####### RULES ###################

class Rules:
    """
    Fail calls based on the input or llm api output

    Example usage:
    import litellm
    def my_custom_rule(input):  # receives the model response
        if "i don't think i can answer" in input:  # trigger fallback if the model refuses to answer
            return False
        return True

    litellm.post_call_rules = [my_custom_rule]  # have these be functions that can be called to fail a call

    response = litellm.completion(model="gpt-3.5-turbo", messages=[{"role": "user",
        "content": "Hey, how's it going?"}], fallbacks=["openrouter/mythomax"])
    """
    def __init__(self) -> None:
        pass

    def pre_call_rules(self, input: str, model: str):
        model, custom_llm_provider, _, _ = get_llm_provider(model=model)
        for rule in litellm.pre_call_rules:
            if callable(rule):
                decision = rule(input)
                if decision is False:
                    raise litellm.APIResponseValidationError("LLM Response failed pre-call-rule check", llm_provider=custom_llm_provider, model=model)
        return True

    def post_call_rules(self, input: str, model: str):
        model, custom_llm_provider, _, _ = get_llm_provider(model=model)
        for rule in litellm.post_call_rules:
            if callable(rule):
                decision = rule(input)
                if decision is False:
                    raise litellm.APIResponseValidationError("LLM Response failed post-call-rule check", llm_provider=custom_llm_provider, model=model)
        return True

####### CLIENT ###################
# make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
def client(original_function):
    global liteDebuggerClient, get_all_keys
    rules_obj = Rules()

    def function_setup(
        start_time, *args, **kwargs
    ):  # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
@@ -1115,18 +1155,28 @@ def client(original_function):
                    messages = args[1]
                elif kwargs.get("messages", None):
                    messages = kwargs["messages"]
                elif kwargs.get("prompt", None):
                    messages = kwargs["prompt"]
                ### PRE-CALL RULES ###
                rules_obj.pre_call_rules(input="".join(m["content"] for m in messages), model=model)
            elif call_type == CallTypes.embedding.value:
                messages = args[1] if len(args) > 1 else kwargs["input"]
            stream = True if "stream" in kwargs and kwargs["stream"] == True else False
            logging_obj = Logging(model=model, messages=messages, stream=stream, litellm_call_id=kwargs["litellm_call_id"], function_id=function_id, call_type=call_type, start_time=start_time)
            return logging_obj
        except Exception as e:
            import logging
            logging.debug(f"[Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}")
            raise e

    def post_call_processing(original_response, model):
        try:
            call_type = original_function.__name__
            if call_type == CallTypes.completion.value or call_type == CallTypes.acompletion.value:
                model_response = original_response['choices'][0]['message']['content']
                ### POST-CALL RULES ###
                rules_obj.post_call_rules(input=model_response, model=model)
        except Exception as e:
            raise e

    def crash_reporting(*args, **kwargs):
        if litellm.telemetry:
            try:
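Note the input each pre-call rule receives: function_setup joins every message's content into one string before calling the rule. A small sketch of a rule written against that shape, echoing my_pre_call_rule from the test file (the length limit is illustrative):

import litellm

def cap_prompt_length(input: str):
    # `input` is "".join(m["content"] for m in messages), as built in function_setup above
    return len(input) <= 4000  # illustrative cap

litellm.pre_call_rules = [cap_prompt_length]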
@@ -1203,6 +1253,9 @@ def client(original_function):
                return result
            elif "acompletion" in kwargs and kwargs["acompletion"] == True:
                return result

            ### POST-CALL RULES ###
            post_call_processing(original_response=result, model=model)

            # [OPTIONAL] ADD TO CACHE
            if litellm.caching or litellm.caching_with_models or litellm.cache != None:  # user init a cache object
@@ -1308,6 +1361,10 @@ def client(original_function):
                    return litellm.stream_chunk_builder(chunks)
                else:
                    return result

            ### POST-CALL RULES ###
            post_call_processing(original_response=result, model=model)

            # [OPTIONAL] ADD TO CACHE
            if litellm.caching or litellm.caching_with_models or litellm.cache != None:  # user init a cache object
                litellm.cache.add_cache(result, *args, **kwargs)
@@ -3272,7 +3329,7 @@ def register_prompt_template(model: str, roles: dict, initial_prompt_value: str
    }
    return litellm.custom_prompt_dict

####### [BETA] HOSTED PRODUCT ################ - https://docs.litellm.ai/docs/debugging/hosted_debugging
####### DEPRECATED ################
def get_all_keys(llm_provider=None):