fix(openai.py): return model name with custom llm provider for openai compatible endpoints

This commit is contained in:
Krrish Dholakia 2024-03-12 10:30:10 -07:00
parent 10f5f342bd
commit 0033613b9e
4 changed files with 5 additions and 1 deletion

View file

@@ -239,6 +239,7 @@ class OpenAIChatCompletion(BaseLLM):
)
if custom_llm_provider != "openai":
model_response.model = f"{custom_llm_provider}/{model}"
# process all OpenAI compatible provider logic here
if custom_llm_provider == "mistral":
# check if message content passed in as list, and not string
@@ -254,6 +255,7 @@ class OpenAIChatCompletion(BaseLLM):
messages=messages,
custom_llm_provider=custom_llm_provider,
)
for _ in range(
2
): # if call fails due to alternating messages, retry with reformatted message

View file

@@ -875,6 +875,7 @@ def completion(
custom_prompt_dict=custom_prompt_dict,
client=client, # pass AsyncOpenAI, OpenAI client
organization=organization,
custom_llm_provider=custom_llm_provider,
)
except Exception as e:
## LOGGING - log the original exception returned

View file

@@ -289,6 +289,7 @@ def test_completion_mistral_api():
cost = litellm.completion_cost(completion_response=response)
print("cost to make mistral completion=", cost)
assert cost > 0.0
assert response.model == "mistral/mistral-tiny"
except Exception as e:
pytest.fail(f"Error occurred: {e}")

View file

@@ -6422,7 +6422,7 @@ def convert_to_model_response_object(
"system_fingerprint"
]
if "model" in response_object:
if "model" in response_object and model_response_object.model is None:
model_response_object.model = response_object["model"]
if start_time is not None and end_time is not None: