use dynamic config args

This commit is contained in:
ishaan-jaff 2023-09-01 15:39:28 -07:00
parent 34ed4cc23c
commit 51e486441d
2 changed files with 21 additions and 5 deletions

View file

@ -1934,9 +1934,10 @@ def get_model_split_test(models, completion_call_id):
data = response.json()
# update model list
split_test_models = data["split_test_models"]
model_configs = data.get("model_configs", {})
# update environment - if required
threading.Thread(target=get_all_keys, args=()).start()
return split_test_models
return split_test_models, model_configs
except:
print_verbose(
f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}"
@ -1959,13 +1960,28 @@ def completion_with_split_tests(models={}, messages=[], use_client=False, overri
if "id" not in kwargs or kwargs["id"] is None:
raise ValueError("Please tag this completion call, if you'd like to update it's split test values through the UI. - eg. `completion_with_split_tests(.., id=1234)`.")
# get the most recent model split list from server
models = get_model_split_test(models=models, completion_call_id=kwargs["id"])
models, model_configs = get_model_split_test(models=models, completion_call_id=kwargs["id"])
try:
selected_llm = random.choices(list(models.keys()), weights=list(models.values()))[0]
except:
traceback.print_exc()
raise ValueError("""models does not follow the required format - {'model_name': 'split_percentage'}, e.g. {'gpt-4': 0.7, 'huggingface/wizard-coder': 0.3}""")
# use dynamic model configs if users set
if model_configs!={}:
selected_model_configs = model_configs.get(selected_llm, {})
if "prompt" in selected_model_configs: # special case, add this to messages as system prompt
messages.append({"role": "system", "content": selected_model_configs["prompt"]})
selected_model_configs.pop("prompt")
for param_name in selected_model_configs:
if param_name == "temperature":
kwargs[param_name] = float(selected_model_configs[param_name])
elif param_name == "max_tokens":
kwargs[param_name] = int(selected_model_configs[param_name])
else:
kwargs[param_name] = selected_model_configs[param_name]
return litellm.completion(model=selected_llm, messages=messages, use_client=use_client, **kwargs)
def completion_with_fallbacks(**kwargs):