update model split tests with ui

This commit is contained in:
Krrish Dholakia 2023-08-31 16:42:40 -07:00
parent 66bfd70253
commit b44299cce7
8 changed files with 80 additions and 3 deletions

View file

@ -286,6 +286,7 @@ from .utils import (
Logging, Logging,
acreate, acreate,
get_model_list, get_model_list,
completion_with_split_tests
) )
from .main import * # type: ignore from .main import * # type: ignore
from .integrations import * from .integrations import *

View file

@ -94,7 +94,7 @@ def completion(
custom_api_base=None, custom_api_base=None,
litellm_call_id=None, litellm_call_id=None,
litellm_logging_obj=None, litellm_logging_obj=None,
completion_call_id=None, # this is an optional param to tag individual completion calls id=None, # this is an optional param to tag individual completion calls
# model specific optional params # model specific optional params
# used by text-bison only # used by text-bison only
top_k=40, top_k=40,
@ -154,7 +154,7 @@ def completion(
custom_api_base=custom_api_base, custom_api_base=custom_api_base,
litellm_call_id=litellm_call_id, litellm_call_id=litellm_call_id,
model_alias_map=litellm.model_alias_map, model_alias_map=litellm.model_alias_map,
completion_call_id=completion_call_id completion_call_id=id
) )
logging.update_environment_variables(optional_params=optional_params, litellm_params=litellm_params) logging.update_environment_variables(optional_params=optional_params, litellm_params=litellm_params)
if custom_llm_provider == "azure": if custom_llm_provider == "azure":

View file

@ -0,0 +1,24 @@
#### What this tests ####
# This tests the 'completion_with_split_tests' function to enable a/b testing between llm models
import sys, os
import traceback
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm import completion_with_split_tests
litellm.set_verbose = True
split_per_model = {
"gpt-4": 0.7,
"claude-instant-1.2": 0.3
}
messages = [{ "content": "Hello, how are you?","role": "user"}]
# print(completion_with_split_tests(models=split_per_model, messages=messages))
# test with client
print(completion_with_split_tests(models=split_per_model, messages=messages, use_client=True, id=1234))

View file

@ -1898,6 +1898,58 @@ async def stream_to_string(generator):
return response return response
########## experimental completion variants ############################
def get_model_split_test(models, completion_call_id):
global last_fetched_at
try:
# make the api call
last_fetched_at = time.time()
print(f"last_fetched_at: {last_fetched_at}")
response = requests.post(
#http://api.litellm.ai
url="http://api.litellm.ai/get_model_split_test", # get the updated dict from table or update the table with the dict
headers={"content-type": "application/json"},
data=json.dumps({"completion_call_id": completion_call_id, "models": models}),
)
print_verbose(f"get_model_list response: {response.text}")
data = response.json()
# update model list
split_test_models = data["split_test_models"]
# update environment - if required
threading.Thread(target=get_all_keys, args=()).start()
return split_test_models
except:
print_verbose(
f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}"
)
def completion_with_split_tests(models={}, messages=[], use_client=False, **kwargs):
"""
Example Usage:
models = {
"gpt-4": 0.7,
"huggingface/wizard-coder": 0.3
}
messages = [{ "content": "Hello, how are you?","role": "user"}]
completion_with_split_tests(models=models, messages=messages)
"""
import random
if use_client:
if "id" not in kwargs or kwargs["id"] is None:
raise ValueError("Please tag this completion call, if you'd like to update it's split test values through the UI. - eg. `completion_with_split_tests(.., id=1234)`.")
# get the most recent model split list from server
models = get_model_split_test(models=models, completion_call_id=kwargs["id"])
try:
selected_llm = random.choices(list(models.keys()), weights=list(models.values()))[0]
except:
traceback.print_exc()
raise ValueError("""models does not follow the required format - {'model_name': 'split_percentage'}, e.g. {'gpt-4': 0.7, 'huggingface/wizard-coder': 0.3}""")
return litellm.completion(model=selected_llm, messages=messages, **kwargs)
def completion_with_fallbacks(**kwargs): def completion_with_fallbacks(**kwargs):
response = None response = None
rate_limited_models = set() rate_limited_models = set()

View file

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "litellm" name = "litellm"
version = "0.1.511" version = "0.1.512"
description = "Library to easily interface with LLM API providers" description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"] authors = ["BerriAI"]
license = "MIT License" license = "MIT License"