diff --git a/litellm/__init__.py b/litellm/__init__.py
index 1602f8ffdc..57ba28b998 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -286,6 +286,7 @@ from .utils import (
     Logging,
     acreate,
     get_model_list,
+    completion_with_split_tests
 )
 from .main import * # type: ignore
 from .integrations import *
diff --git a/litellm/__pycache__/__init__.cpython-311.pyc b/litellm/__pycache__/__init__.cpython-311.pyc
index 61bd76f596..8740f2ca7b 100644
Binary files a/litellm/__pycache__/__init__.cpython-311.pyc and b/litellm/__pycache__/__init__.cpython-311.pyc differ
diff --git a/litellm/__pycache__/main.cpython-311.pyc b/litellm/__pycache__/main.cpython-311.pyc
index 0d73a670da..d9973c270e 100644
Binary files a/litellm/__pycache__/main.cpython-311.pyc and b/litellm/__pycache__/main.cpython-311.pyc differ
diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc
index e6cb1a5594..27f57b6dae 100644
Binary files a/litellm/__pycache__/utils.cpython-311.pyc and b/litellm/__pycache__/utils.cpython-311.pyc differ
diff --git a/litellm/main.py b/litellm/main.py
index 1b9b7ca155..62a8900932 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -94,7 +94,7 @@ def completion(
     custom_api_base=None,
     litellm_call_id=None,
     litellm_logging_obj=None,
-    completion_call_id=None, # this is an optional param to tag individual completion calls
+    id=None, # this is an optional param to tag individual completion calls
     # model specific optional params
     # used by text-bison only
     top_k=40,
@@ -154,7 +154,7 @@ def completion(
         custom_api_base=custom_api_base,
         litellm_call_id=litellm_call_id,
         model_alias_map=litellm.model_alias_map,
-        completion_call_id=completion_call_id
+        completion_call_id=id
     )
     logging.update_environment_variables(optional_params=optional_params, litellm_params=litellm_params)
     if custom_llm_provider == "azure":
diff --git a/litellm/tests/test_split_test.py b/litellm/tests/test_split_test.py
new file mode 100644
index 0000000000..35ff007be3
--- /dev/null
+++ b/litellm/tests/test_split_test.py
@@ -0,0 +1,24 @@
+#### What this tests ####
+# This tests the 'completion_with_split_tests' function to enable a/b testing between llm models
+
+import sys, os
+import traceback
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+) # Adds the parent directory to the system path
+import litellm
+from litellm import completion_with_split_tests
+litellm.set_verbose = True
+split_per_model = {
+    "gpt-4": 0.7,
+    "claude-instant-1.2": 0.3
+}
+
+messages = [{ "content": "Hello, how are you?","role": "user"}]
+
+# print(completion_with_split_tests(models=split_per_model, messages=messages))
+
+# test with client
+
+print(completion_with_split_tests(models=split_per_model, messages=messages, use_client=True, id=1234))
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index def92ec093..3aaf00b9ec 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1898,6 +1898,58 @@ async def stream_to_string(generator):
     return response
 
 
+########## experimental completion variants ############################
+
+def get_model_split_test(models, completion_call_id):
+    global last_fetched_at
+    try:
+        # make the api call
+        last_fetched_at = time.time()
+        print(f"last_fetched_at: {last_fetched_at}")
+        response = requests.post(
+            # http://api.litellm.ai
+            url="http://api.litellm.ai/get_model_split_test", # get the updated dict from table or update the table with the dict
+            headers={"content-type": "application/json"},
+            data=json.dumps({"completion_call_id": completion_call_id, "models": models}),
completion_call_id, "models": models}), + ) + print_verbose(f"get_model_list response: {response.text}") + data = response.json() + # update model list + split_test_models = data["split_test_models"] + # update environment - if required + threading.Thread(target=get_all_keys, args=()).start() + return split_test_models + except: + print_verbose( + f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}" + ) + + +def completion_with_split_tests(models={}, messages=[], use_client=False, **kwargs): + """ + Example Usage: + + models = { + "gpt-4": 0.7, + "huggingface/wizard-coder": 0.3 + } + messages = [{ "content": "Hello, how are you?","role": "user"}] + completion_with_split_tests(models=models, messages=messages) + """ + import random + if use_client: + if "id" not in kwargs or kwargs["id"] is None: + raise ValueError("Please tag this completion call, if you'd like to update it's split test values through the UI. - eg. `completion_with_split_tests(.., id=1234)`.") + # get the most recent model split list from server + models = get_model_split_test(models=models, completion_call_id=kwargs["id"]) + + try: + selected_llm = random.choices(list(models.keys()), weights=list(models.values()))[0] + except: + traceback.print_exc() + raise ValueError("""models does not follow the required format - {'model_name': 'split_percentage'}, e.g. {'gpt-4': 0.7, 'huggingface/wizard-coder': 0.3}""") + return litellm.completion(model=selected_llm, messages=messages, **kwargs) + def completion_with_fallbacks(**kwargs): response = None rate_limited_models = set() diff --git a/pyproject.toml b/pyproject.toml index eb5c9be7b6..d226ddfeba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "litellm" -version = "0.1.511" +version = "0.1.512" description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT License"