diff --git a/litellm/tests/test_split_test.py b/litellm/tests/test_split_test.py
index 35ff007be..e26079642 100644
--- a/litellm/tests/test_split_test.py
+++ b/litellm/tests/test_split_test.py
@@ -11,8 +11,8 @@ import litellm
 from litellm import completion_with_split_tests
 litellm.set_verbose = True
 split_per_model = {
-    "gpt-4": 0.7,
-    "claude-instant-1.2": 0.3
+    "gpt-3.5-turbo": 0.8,
+    "claude-instant-1.2": 0.1
 }
 messages = [{ "content": "Hello, how are you?","role": "user"}]
 
@@ -21,4 +21,4 @@ messages = [{ "content": "Hello, how are you?","role": "user"}]
 
 # test with client
 
-print(completion_with_split_tests(models=split_per_model, messages=messages, use_client=True, id=1234))
\ No newline at end of file
+print(completion_with_split_tests(models=split_per_model, messages=messages, use_client=True, id="ishaan_1@berri.ai"))
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index 12260de02..9fda83790 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1934,9 +1934,10 @@ def get_model_split_test(models, completion_call_id):
         data = response.json()
         # update model list
         split_test_models = data["split_test_models"]
+        model_configs = data.get("model_configs", {})
         # update environment - if required
         threading.Thread(target=get_all_keys, args=()).start()
-        return split_test_models
+        return split_test_models, model_configs
     except:
         print_verbose(
             f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}"
@@ -1959,13 +1960,28 @@ def completion_with_split_tests(models={}, messages=[], use_client=False, overri
         if "id" not in kwargs or kwargs["id"] is None:
             raise ValueError("Please tag this completion call, if you'd like to update it's split test values through the UI. - eg. `completion_with_split_tests(.., id=1234)`.")
         # get the most recent model split list from server
-        models = get_model_split_test(models=models, completion_call_id=kwargs["id"])
+        models, model_configs = get_model_split_test(models=models, completion_call_id=kwargs["id"])
 
     try:
         selected_llm = random.choices(list(models.keys()), weights=list(models.values()))[0]
     except:
         traceback.print_exc()
         raise ValueError("""models does not follow the required format - {'model_name': 'split_percentage'}, e.g. {'gpt-4': 0.7, 'huggingface/wizard-coder': 0.3}""")
+
+    # use dynamic model configs if users set
+    if model_configs!={}:
+        selected_model_configs = model_configs.get(selected_llm, {})
+        if "prompt" in selected_model_configs: # special case, add this to messages as system prompt
+            messages.append({"role": "system", "content": selected_model_configs["prompt"]})
+            selected_model_configs.pop("prompt")
+        for param_name in selected_model_configs:
+            if param_name == "temperature":
+                kwargs[param_name] = float(selected_model_configs[param_name])
+            elif param_name == "max_tokens":
+                kwargs[param_name] = int(selected_model_configs[param_name])
+            else:
+                kwargs[param_name] = selected_model_configs[param_name]
+
     return litellm.completion(model=selected_llm, messages=messages, use_client=use_client, **kwargs)
 
 def completion_with_fallbacks(**kwargs):
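
A minimal standalone sketch of how the model_configs overrides introduced in utils.py are applied before dispatch. It assumes the server payload's "model_configs" is keyed by model name and may carry string-valued parameters; the example values below are made up for illustration and are not part of the patch.

# Hypothetical per-model overrides; real values come from the server response
# parsed in get_model_split_test(). Numeric params may arrive as strings and
# are cast before being forwarded to litellm.completion().
model_configs = {
    "gpt-3.5-turbo": {
        "prompt": "You are a concise assistant.",
        "temperature": "0.2",
        "max_tokens": "256",
    }
}

selected_llm = "gpt-3.5-turbo"
messages = [{"role": "user", "content": "Hello, how are you?"}]
kwargs = {}

selected = dict(model_configs.get(selected_llm, {}))
if "prompt" in selected:
    # special case: the prompt override is injected as a system message
    messages.append({"role": "system", "content": selected.pop("prompt")})
for param_name, value in selected.items():
    if param_name == "temperature":
        kwargs[param_name] = float(value)
    elif param_name == "max_tokens":
        kwargs[param_name] = int(value)
    else:
        kwargs[param_name] = value

print(kwargs)    # {'temperature': 0.2, 'max_tokens': 256}
print(messages)  # user message plus the injected system prompt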