fix(utils.py): adding support for rules + mythomax/alpaca prompt template
parent 4f46ac4ab5
commit 855964ed45

7 changed files with 186 additions and 8 deletions
@@ -8,6 +8,8 @@ input_callback: List[Union[str, Callable]] = []
 success_callback: List[Union[str, Callable]] = []
 failure_callback: List[Union[str, Callable]] = []
 callbacks: List[Callable] = []
+pre_call_rules: List[Callable] = []
+post_call_rules: List[Callable] = []
 set_verbose = False
 email: Optional[
     str
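These two new module-level lists are the user-facing hook points: any callable placed in them is run by litellm's client wrapper before the request is sent (pre_call_rules, on the prompt text) or after the response arrives (post_call_rules, on the model output), and a rule that returns False fails the call. A minimal registration sketch; the rule names and the length threshold are illustrative, not part of the commit:

import litellm

# Each rule receives a string and returns True (allow) or False (fail the call).
def block_long_prompts(text: str) -> bool:
    return len(text) <= 4000            # illustrative cap on prompt length

def block_apologies(text: str) -> bool:
    return "sorry" not in text.lower()  # reject responses that only apologize

litellm.pre_call_rules = [block_long_prompts]
litellm.post_call_rules = [block_apologies]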
@@ -386,7 +388,8 @@ from .exceptions import (
     BudgetExceededError,
     APIError,
     Timeout,
-    APIConnectionError
+    APIConnectionError,
+    APIResponseValidationError
 )
 from .budget_manager import BudgetManager
 from .proxy.proxy_cli import run_server
@@ -18,6 +18,7 @@ from openai import (
     APIError,
     APITimeoutError,
     APIConnectionError,
+    APIResponseValidationError
 )
 import httpx

@@ -119,6 +120,22 @@ class APIConnectionError(APIConnectionError): # type: ignore
             request=request
         )

+
+# raised if an invalid request (not get, delete, put, post) is made
+class APIResponseValidationError(APIResponseValidationError): # type: ignore
+    def __init__(self, message, llm_provider, model):
+        self.message = message
+        self.llm_provider = llm_provider
+        self.model = model
+        request = httpx.Request(method="POST", url="https://api.openai.com/v1")
+        response = httpx.Response(status_code=500, request=request)
+        super().__init__(
+            response=response,
+            body=None,
+            message=message
+        )
+
+
 class OpenAIError(OpenAIError): # type: ignore
     def __init__(self, original_exception):
         self.status_code = original_exception.http_status
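Because the package-level import hunk above re-exports this mapped exception, callers can catch a failed rule explicitly. litellm's exception mapping may still wrap provider errors differently, so treat this as a sketch rather than a guarantee:

import litellm

litellm.post_call_rules = [lambda output: "sorry" not in output.lower()]

try:
    litellm.completion(model="gpt-3.5-turbo",
                       messages=[{"role": "user", "content": "say sorry"}])
except litellm.APIResponseValidationError as e:
    # a failed rule surfaces with the provider and model attached
    print(e.llm_provider, e.model, e.message)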
@@ -6,6 +6,7 @@ from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper,
 from typing import Callable, Optional
 import aiohttp, requests
 import litellm
+from .prompt_templates.factory import prompt_factory, custom_prompt
 from openai import OpenAI, AsyncOpenAI

 class OpenAIError(Exception):
@@ -172,7 +173,8 @@ class OpenAIChatCompletion(BaseLLM):
                 optional_params=None,
                 litellm_params=None,
                 logger_fn=None,
-                headers: Optional[dict]=None):
+                headers: Optional[dict]=None,
+                custom_prompt_dict: dict={}):
         super().completion()
         exception_mapping_worked = False
         try:
@@ -7,6 +7,29 @@ from typing import Optional
 def default_pt(messages):
     return " ".join(message["content"] for message in messages)

+# alpaca prompt template - for models like mythomax, etc.
+def alpaca_pt(messages):
+    prompt = custom_prompt(
+        role_dict={
+            "system": {
+                "pre_message": "### Instruction:\n",
+                "post_message": "\n\n",
+            },
+            "user": {
+                "pre_message": "### Instruction:\n",
+                "post_message": "\n\n",
+            },
+            "assistant": {
+                "pre_message": "### Response:\n",
+                "post_message": "\n\n"
+            }
+        },
+        bos_token="<s>",
+        eos_token="</s>",
+        messages=messages
+    )
+    return prompt
+
 # Llama2 prompt template
 def llama_2_chat_pt(messages):
     prompt = custom_prompt(
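alpaca_pt is a thin wrapper over the existing custom_prompt builder, so it can be exercised directly. A quick sketch, assuming the factory module lives at litellm.llms.prompt_templates.factory (the relative import in the OpenAI handler hunk points there); exact bos/eos placement is whatever custom_prompt does with those tokens:

from litellm.llms.prompt_templates.factory import alpaca_pt

messages = [
    {"role": "system", "content": "You are a terse assistant."},
    {"role": "user", "content": "Tell me a joke."},
]
# system/user turns render as "### Instruction:\n<content>\n\n",
# assistant turns as "### Response:\n<content>\n\n"
print(alpaca_pt(messages=messages))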
@ -276,7 +299,6 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="",
|
|||
def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None):
|
||||
original_model_name = model
|
||||
model = model.lower()
|
||||
|
||||
if custom_llm_provider == "ollama":
|
||||
return ollama_pt(model=model, messages=messages)
|
||||
elif custom_llm_provider == "anthropic":
|
||||
|
@@ -302,6 +324,8 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str
             return phind_codellama_pt(messages=messages)
         elif "togethercomputer/llama-2" in model and ("instruct" in model or "chat" in model):
             return llama_2_chat_pt(messages=messages)
+        elif model in ["gryphe/mythomax-l2-13b", "gryphe/mythomix-l2-13b", "gryphe/mythologic-l2-13b"]:
+            return alpaca_pt(messages=messages)
         else:
             return hf_chat_template(original_model_name, messages)
     except:
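With this routing, the three mythomax-family model names get the alpaca template on the default (non-provider-specific) path, while unknown models still fall through to hf_chat_template. A small sketch of driving the factory directly, under the same module-path assumption as above:

from litellm.llms.prompt_templates.factory import prompt_factory

messages = [{"role": "user", "content": "Write a haiku about rules."}]
prompt = prompt_factory(model="gryphe/mythomax-l2-13b", messages=messages)
print(prompt)  # alpaca-formatted string produced by alpaca_pt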
@@ -556,7 +556,8 @@ def completion(
                 optional_params=optional_params,
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
-                timeout=timeout
+                timeout=timeout,
+                custom_prompt_dict=custom_prompt_dict
             )
         except Exception as e:
             ## LOGGING - log the original exception returned
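The extra argument threads the completion call's custom_prompt_dict into the OpenAI-compatible handler, so models served through that path can pick up a registered template. A hedged sketch of calling such a model; the model name is the one the new test uses as a fallback, and running it requires the matching provider credentials:

import litellm

response = litellm.completion(
    model="deepinfra/Gryphe/MythoMax-L2-13b",  # same model as in test_rules.py
    messages=[{"role": "user", "content": "Tell me a joke."}],
)
print(response["choices"][0]["message"]["content"])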
litellm/tests/test_rules.py (new file, 74 lines)
@@ -0,0 +1,74 @@
#### What this tests ####
# This tests setting rules before / after making llm api calls
import sys, os, time
import traceback, asyncio
import pytest
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm
from litellm import completion, acompletion

def my_pre_call_rule(input: str):
    print(f"input: {input}")
    print(f"INSIDE MY PRE CALL RULE, len(input) - {len(input)}")
    if len(input) > 10:
        return False
    return True

def my_post_call_rule(input: str):
    input = input.lower()
    print(f"input: {input}")
    print(f"INSIDE MY POST CALL RULE, len(input) - {len(input)}")
    if "sorry" in input:
        return False
    return True

## Test 1: Pre-call rule
def test_pre_call_rule():
    try:
        litellm.pre_call_rules = [my_pre_call_rule]
        ### completion
        response = completion(model="gpt-3.5-turbo",
                              messages=[{"role": "user", "content": "say something inappropriate"}])
        pytest.fail(f"Completion call should have been failed. ")
    except:
        pass
    ### async completion
    async def test_async_response():
        user_message = "Hello, how are you?"
        messages = [{"content": user_message, "role": "user"}]
        try:
            response = await acompletion(model="gpt-3.5-turbo", messages=messages)
            pytest.fail(f"acompletion call should have been failed. ")
        except Exception as e:
            pass
    asyncio.run(test_async_response())

# test_pre_call_rule()
## Test 2: Post-call rule

def test_post_call_rule():
    try:
        litellm.pre_call_rules = []
        litellm.post_call_rules = [my_post_call_rule]
        ### completion
        response = completion(model="gpt-3.5-turbo",
                              messages=[{"role": "user", "content": "say sorry"}],
                              fallbacks=["deepinfra/Gryphe/MythoMax-L2-13b"])
        pytest.fail(f"Completion call should have been failed. ")
    except:
        pass
    print(f"MAKING ACOMPLETION CALL")
    # litellm.set_verbose = True
    ### async completion
    async def test_async_response():
        messages=[{"role": "user", "content": "say sorry"}]
        try:
            response = await acompletion(model="gpt-3.5-turbo", messages=messages)
            pytest.fail(f"acompletion call should have been failed.")
        except Exception as e:
            pass
    asyncio.run(test_async_response())

# test_post_call_rule()
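Both tests exercise the same flow: a rule returns False, the wrapper raises, and the caller either catches the error or routes around it. The fallback model in test_post_call_rule is the deepinfra-hosted MythoMax that the new alpaca template targets, which is what ties the two halves of this commit together. A hedged sketch of that intent, following the pattern in the Rules docstring below:

import litellm

litellm.post_call_rules = [lambda output: "sorry" not in output.lower()]

# If gpt-3.5-turbo's answer trips the rule, the listed fallback can be tried instead;
# it is the model the alpaca prompt template was added for.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "say sorry"}],
    fallbacks=["deepinfra/Gryphe/MythoMax-L2-13b"],
)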
@@ -1055,10 +1055,50 @@ def exception_logging(
         pass


+####### RULES ###################
+
+class Rules:
+    """
+    Fail calls based on the input or llm api output
+
+    Example usage:
+    import litellm
+    def my_custom_rule(input): # receives the model response
+        if "i don't think i can answer" in input: # trigger fallback if the model refuses to answer
+            return False
+        return True
+
+    litellm.post_call_rules = [my_custom_rule] # have these be functions that can be called to fail a call
+
+    response = litellm.completion(model="gpt-3.5-turbo", messages=[{"role": "user",
+        "content": "Hey, how's it going?"}], fallbacks=["openrouter/mythomax"])
+    """
+    def __init__(self) -> None:
+        pass
+
+    def pre_call_rules(self, input: str, model: str):
+        model, custom_llm_provider, _, _ = get_llm_provider(model=model)
+        for rule in litellm.pre_call_rules:
+            if callable(rule):
+                decision = rule(input)
+                if decision is False:
+                    raise litellm.APIResponseValidationError("LLM Response failed post-call-rule check", llm_provider=custom_llm_provider, model=model)
+        return True
+
+    def post_call_rules(self, input: str, model: str):
+        model, custom_llm_provider, _, _ = get_llm_provider(model=model)
+        for rule in litellm.post_call_rules:
+            if callable(rule):
+                decision = rule(input)
+                if decision is False:
+                    raise litellm.APIResponseValidationError("LLM Response failed post-call-rule check", llm_provider=custom_llm_provider, model=model)
+        return True
+
 ####### CLIENT ###################
 # make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
 def client(original_function):
     global liteDebuggerClient, get_all_keys
+    rules_obj = Rules()
     def function_setup(
         start_time, *args, **kwargs
     ): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
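A standalone sketch of how the wrapper uses this class, assuming Rules is importable from litellm.utils (the commit title points at utils.py); the ten-character rule is illustrative:

import litellm
from litellm.utils import Rules

litellm.pre_call_rules = [lambda text: len(text) <= 10]

rules_obj = Rules()
try:
    # the client wrapper passes the concatenated message contents as `input`
    rules_obj.pre_call_rules(input="this is well over ten characters",
                             model="gpt-3.5-turbo")
except litellm.APIResponseValidationError:
    print("pre-call rule rejected the input")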
@@ -1115,18 +1155,28 @@ def client(original_function):
                     messages = args[1]
                 elif kwargs.get("messages", None):
                     messages = kwargs["messages"]
                 elif kwargs.get("prompt", None):
                     messages = kwargs["prompt"]
+                ### PRE-CALL RULES ###
+                rules_obj.pre_call_rules(input="".join(m["content"] for m in messages), model=model)
             elif call_type == CallTypes.embedding.value:
                 messages = args[1] if len(args) > 1 else kwargs["input"]
             stream = True if "stream" in kwargs and kwargs["stream"] == True else False
             logging_obj = Logging(model=model, messages=messages, stream=stream, litellm_call_id=kwargs["litellm_call_id"], function_id=function_id, call_type=call_type, start_time=start_time)
             return logging_obj
-        except Exception as e: # DO NOT BLOCK running the function because of this
+        except Exception as e:
+            import logging
+            logging.debug(f"[Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}")
+            raise e

+    def post_call_processing(original_response, model):
+        try:
+            call_type = original_function.__name__
+            if call_type == CallTypes.completion.value or call_type == CallTypes.acompletion.value:
+                model_response = original_response['choices'][0]['message']['content']
+                ### POST-CALL RULES ###
+                rules_obj.post_call_rules(input=model_response, model=model)
+        except Exception as e:
+            raise e
+
     def crash_reporting(*args, **kwargs):
         if litellm.telemetry:
             try:
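Note that the two hooks see different strings in this hunk: pre-call rules receive the concatenation of every message's content, while post-call rules receive the first choice's message content. A tiny illustration with plain dicts, no API call involved:

messages = [{"role": "system", "content": "Be terse."},
            {"role": "user", "content": "Hi"}]
pre_call_input = "".join(m["content"] for m in messages)  # "Be terse.Hi"

# response shaped like an OpenAI-style completion payload
original_response = {"choices": [{"message": {"content": "Hello!"}}]}
post_call_input = original_response["choices"][0]["message"]["content"]  # "Hello!"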
@@ -1203,6 +1253,9 @@ def client(original_function):
             return result
         elif "acompletion" in kwargs and kwargs["acompletion"] == True:
             return result
+
+        ### POST-CALL RULES ###
+        post_call_processing(original_response=result, model=model)

         # [OPTIONAL] ADD TO CACHE
         if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
@@ -1308,6 +1361,10 @@ def client(original_function):
                 return litellm.stream_chunk_builder(chunks)
             else:
                 return result
+
+        ### POST-CALL RULES ###
+        post_call_processing(original_response=result, model=model)
+
         # [OPTIONAL] ADD TO CACHE
         if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
             litellm.cache.add_cache(result, *args, **kwargs)
@@ -3272,7 +3329,7 @@ def register_prompt_template(model: str, roles: dict, initial_prompt_value: str
     }
     return litellm.custom_prompt_dict

-####### [BETA] HOSTED PRODUCT ################ - https://docs.litellm.ai/docs/debugging/hosted_debugging
+####### DEPRECATED ################


 def get_all_keys(llm_provider=None):
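register_prompt_template stores a per-model template in litellm.custom_prompt_dict using the same pre_message/post_message shape as the role_dict in alpaca_pt. A hedged sketch of registering an alpaca-style template for the deepinfra-hosted model used in the new test (any other provider/model string would work the same way):

import litellm

litellm.register_prompt_template(
    model="deepinfra/Gryphe/MythoMax-L2-13b",
    roles={
        "system": {"pre_message": "### Instruction:\n", "post_message": "\n\n"},
        "user": {"pre_message": "### Instruction:\n", "post_message": "\n\n"},
        "assistant": {"pre_message": "### Response:\n", "post_message": "\n\n"},
    },
    initial_prompt_value="",
)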