From 0033613b9e2a2eaeafc7c105b31bcf0c094d75b0 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Tue, 12 Mar 2024 10:30:10 -0700
Subject: [PATCH] fix(openai.py): return model name with custom llm provider
 for openai compatible endpoints

---
 litellm/llms/openai.py           | 2 ++
 litellm/main.py                  | 1 +
 litellm/tests/test_completion.py | 1 +
 litellm/utils.py                 | 2 +-
 4 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index f65d96b11..ecc8d5f70 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -239,6 +239,7 @@ class OpenAIChatCompletion(BaseLLM):
             )
 
             if custom_llm_provider != "openai":
+                model_response.model = f"{custom_llm_provider}/{model}"
                 # process all OpenAI compatible provider logic here
                 if custom_llm_provider == "mistral":
                     # check if message content passed in as list, and not string
@@ -254,6 +255,7 @@
                         messages=messages,
                         custom_llm_provider=custom_llm_provider,
                     )
+
             for _ in range(
                 2
             ):  # if call fails due to alternating messages, retry with reformatted message
diff --git a/litellm/main.py b/litellm/main.py
index 114b46948..4a4d4aaa6 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -875,6 +875,7 @@ def completion(
                 custom_prompt_dict=custom_prompt_dict,
                 client=client,  # pass AsyncOpenAI, OpenAI client
                 organization=organization,
+                custom_llm_provider=custom_llm_provider,
             )
         except Exception as e:
             ## LOGGING - log the original exception returned
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 729bf7bd9..6531f1cb0 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -289,6 +289,7 @@ def test_completion_mistral_api():
         cost = litellm.completion_cost(completion_response=response)
         print("cost to make mistral completion=", cost)
         assert cost > 0.0
+        assert response.model == "mistral/mistral-tiny"
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
diff --git a/litellm/utils.py b/litellm/utils.py
index 3b6169770..4cf99c3bb 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -6422,7 +6422,7 @@ def convert_to_model_response_object(
                 "system_fingerprint"
             ]
 
-        if "model" in response_object:
+        if "model" in response_object and model_response_object.model is None:
             model_response_object.model = response_object["model"]
 
         if start_time is not None and end_time is not None:
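
Note: a minimal sketch of the user-visible behavior this patch targets, mirroring
test_completion_mistral_api above (assumes a live Mistral API key is configured):

    import litellm

    # Before this patch, OpenAI-compatible providers echoed back their bare
    # model name (e.g. "mistral-tiny"). With the patch, completion() passes
    # custom_llm_provider through, openai.py sets model_response.model to
    # "<provider>/<model>" up front, and convert_to_model_response_object()
    # no longer overwrites a model name that is already set.
    response = litellm.completion(
        model="mistral/mistral-tiny",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )
    print(response.model)  # "mistral/mistral-tiny" with this patch applied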