fix(utils.py): add extra body params for text completion calls

This commit is contained in:
Krrish Dholakia 2024-06-21 08:28:08 -07:00
parent 12f4fb3a42
commit fdb7101aaf
3 changed files with 45 additions and 4 deletions

View file

@ -1436,6 +1436,43 @@ def test_hf_test_completion_tgi():
# hf_test_completion_tgi()
@pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",
@pytest.mark.asyncio
async def test_openai_compatible_custom_api_base(provider):
    """Check that an unrecognised kwarg is forwarded inside ``extra_body``.

    The OpenAI client's ``chat.completions.create`` is replaced with a
    ``MagicMock`` so no request is sent. ``completion`` is then called with
    the extra keyword ``hello="world"``, and the test asserts that the mocked
    ``create`` was invoked exactly once with ``"hello"`` present in its
    ``extra_body`` keyword argument.

    NOTE(review): the ``provider`` parameter is not referenced in the body;
    it only determines the parametrized test id.
    """
    litellm.set_verbose = True

    from openai import OpenAI

    fake_client = OpenAI(api_key="fake-key")
    chat_messages = [{"role": "user", "content": "Hello world"}]

    with patch.object(
        fake_client.chat.completions, "create", new=MagicMock()
    ) as mock_call:
        # Any error raised by the call is deliberately ignored: the
        # assertions below inspect how the mock was invoked, not the
        # returned response.
        # NOTE(review): presumably the MagicMock return value cannot be
        # parsed into a real response object -- confirm.
        try:
            completion(
                model="openai/my-vllm-model",
                messages=chat_messages,
                response_format={"type": "json_object"},
                client=fake_client,
                api_base="my-custom-api-base",
                hello="world",
            )
        except Exception:
            pass

        mock_call.assert_called_once()

        sent_kwargs = mock_call.call_args.kwargs
        print("Call KWARGS - {}".format(sent_kwargs))
        assert "hello" in sent_kwargs["extra_body"]
# ################### Hugging Face Conversational models ########################
# def hf_test_completion_conv():
# try:

View file

@ -4189,12 +4189,12 @@ def test_completion_vllm():
with patch.object(client.completions, "create", side_effect=mock_post) as mock_call:
response = text_completion(
model="openai/gemini-1.5-flash",
prompt="ping",
client=client,
model="openai/gemini-1.5-flash", prompt="ping", client=client, hello="world"
)
print(response)
assert response.usage.prompt_tokens == 2
mock_call.assert_called_once()
assert "hello" in mock_call.call_args.kwargs["extra_body"]

View file

@ -3265,7 +3265,11 @@ def get_optional_params(
optional_params["top_logprobs"] = top_logprobs
if extra_headers is not None:
optional_params["extra_headers"] = extra_headers
if custom_llm_provider in ["openai", "azure"] + litellm.openai_compatible_providers:
if (
custom_llm_provider
in ["openai", "azure", "text-completion-openai"]
+ litellm.openai_compatible_providers
):
# for openai, azure we should pass the extra/passed params within `extra_body` https://github.com/openai/openai-python/blob/ac33853ba10d13ac149b1fa3ca6dba7d613065c9/src/openai/resources/models.py#L46
extra_body = passed_params.pop("extra_body", {})
for k in passed_params.keys():