add model load testing functionality

This commit is contained in:
Krrish Dholakia 2023-08-11 17:59:51 -07:00
parent a2efd32f5c
commit 211e1edfcb
8 changed files with 35 additions and 8 deletions

View file

@ -59,7 +59,7 @@ def completion(
# params to identify the model
model=model, replicate=replicate, hugging_face=hugging_face, together_ai=together_ai
)
if azure == True:
if azure == True or custom_llm_provider == "azure": # [TODO]: remove azure=True flag, move to 'custom_llm_provider' approach
# azure configs
openai.api_type = "azure"
openai.api_base = litellm.api_base if litellm.api_base is not None else get_secret("AZURE_API_BASE")
@ -153,7 +153,7 @@ def completion(
model_response["model"] = model
model_response["usage"] = response["usage"]
response = model_response
elif "replicate" in model or replicate == True:
elif "replicate" in model or replicate == True or custom_llm_provider == "replicate":
# import replicate/if it fails then pip install replicate
install_and_import("replicate")
import replicate
@ -256,7 +256,7 @@ def completion(
}
response = model_response
elif model in litellm.openrouter_models:
elif model in litellm.openrouter_models or custom_llm_provider == "openrouter":
openai.api_type = "openai"
# not sure if this will work after someone first uses another API
openai.api_base = litellm.api_base if litellm.api_base is not None else "https://openrouter.ai/api/v1"
@ -338,7 +338,7 @@ def completion(
"total_tokens": prompt_tokens + completion_tokens
}
response = model_response
elif hugging_face == True:
elif hugging_face == True or custom_llm_provider == "huggingface":
import requests
API_URL = f"https://api-inference.huggingface.co/models/{model}"
HF_TOKEN = get_secret("HF_TOKEN")
@ -364,7 +364,7 @@ def completion(
"total_tokens": prompt_tokens + completion_tokens
}
response = model_response
elif together_ai == True:
elif together_ai == True or custom_llm_provider == "together_ai":
import requests
TOGETHER_AI_TOKEN = get_secret("TOGETHER_AI_TOKEN")
headers = {"Authorization": f"Bearer {TOGETHER_AI_TOKEN}"}
@ -430,7 +430,7 @@ def completion(
## LOGGING
logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
args = locals()
raise ValueError(f"No valid completion model args passed in - {args}")
raise ValueError(f"Invalid completion model args passed in. Check your input - {args}")
return response
except Exception as e:
## LOGGING