allow custom provider via model name

This commit is contained in:
Krrish Dholakia 2023-08-19 16:34:52 -07:00
parent 245b00f5f5
commit 6cee3b22ab
11 changed files with 15 additions and 0 deletions

View file

@@ -94,6 +94,9 @@ def completion(
model_response = ModelResponse()
if azure: # this flag is deprecated, remove once notebooks are also updated.
custom_llm_provider = "azure"
elif model.split("/", 1)[0] in litellm.provider_list: # allow custom provider to be passed in via the model name "azure/chatgpt-test"
custom_llm_provider = model.split("/", 1)[0]
model = model.split("/", 1)[1]
args = locals()
# check if user passed in any of the OpenAI optional params
optional_params = get_optional_params(

View file

@@ -25,6 +25,18 @@ def logger_fn(user_model_dict):
print(f"user_model_dict: {user_model_dict}")
def test_completion_custom_provider_model_name():
    """Smoke-test that a provider prefix embedded in the model name is accepted.

    Passes a "provider/model" style identifier to completion(); any raised
    exception is converted into a pytest failure instead of an error.
    """
    try:
        result = completion(
            model="together_ai/togethercomputer/llama-2-70b-chat",
            messages=messages,
            logger_fn=logger_fn,
        )
        # Add any assertions here to check the response
        print(result)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

test_completion_custom_provider_model_name()
def test_completion_claude():
try:
response = completion(