Merge pull request #2473 from BerriAI/litellm_fix_compatible_provider_model_name

fix(openai.py): return model name with custom llm provider for openai-compatible endpoints (e.g. mistral, together ai, etc.)
This commit is contained in:
Krish Dholakia 2024-03-12 12:58:29 -07:00 committed by GitHub
commit 0d18f3c0ca
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 5 additions and 1 deletions

View file

@ -239,6 +239,7 @@ class OpenAIChatCompletion(BaseLLM):
) )
if custom_llm_provider != "openai": if custom_llm_provider != "openai":
model_response.model = f"{custom_llm_provider}/{model}"
# process all OpenAI compatible provider logic here # process all OpenAI compatible provider logic here
if custom_llm_provider == "mistral": if custom_llm_provider == "mistral":
# check if message content passed in as list, and not string # check if message content passed in as list, and not string
@ -254,6 +255,7 @@ class OpenAIChatCompletion(BaseLLM):
messages=messages, messages=messages,
custom_llm_provider=custom_llm_provider, custom_llm_provider=custom_llm_provider,
) )
for _ in range( for _ in range(
2 2
): # if call fails due to alternating messages, retry with reformatted message ): # if call fails due to alternating messages, retry with reformatted message

View file

@ -945,6 +945,7 @@ def completion(
custom_prompt_dict=custom_prompt_dict, custom_prompt_dict=custom_prompt_dict,
client=client, # pass AsyncOpenAI, OpenAI client client=client, # pass AsyncOpenAI, OpenAI client
organization=organization, organization=organization,
custom_llm_provider=custom_llm_provider,
) )
except Exception as e: except Exception as e:
## LOGGING - log the original exception returned ## LOGGING - log the original exception returned

View file

@ -289,6 +289,7 @@ def test_completion_mistral_api():
cost = litellm.completion_cost(completion_response=response) cost = litellm.completion_cost(completion_response=response)
print("cost to make mistral completion=", cost) print("cost to make mistral completion=", cost)
assert cost > 0.0 assert cost > 0.0
assert response.model == "mistral/mistral-tiny"
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")

View file

@ -6427,7 +6427,7 @@ def convert_to_model_response_object(
"system_fingerprint" "system_fingerprint"
] ]
if "model" in response_object: if "model" in response_object and model_response_object.model is None:
model_response_object.model = response_object["model"] model_response_object.model = response_object["model"]
if start_time is not None and end_time is not None: if start_time is not None and end_time is not None: