forked from phoenix/litellm-mirror
fix(utils.py): add extra body params for text completion calls
This commit is contained in:
parent
12f4fb3a42
commit
fdb7101aaf
3 changed files with 45 additions and 4 deletions
|
@ -1436,6 +1436,43 @@ def test_hf_test_completion_tgi():
|
|||
|
||||
# hf_test_completion_tgi()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("provider", ["vertex_ai_beta"]) # "vertex_ai",
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_compatible_custom_api_base(provider):
|
||||
litellm.set_verbose = True
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello world",
|
||||
}
|
||||
]
|
||||
from openai import OpenAI
|
||||
|
||||
openai_client = OpenAI(api_key="fake-key")
|
||||
|
||||
with patch.object(
|
||||
openai_client.chat.completions, "create", new=MagicMock()
|
||||
) as mock_call:
|
||||
try:
|
||||
response = completion(
|
||||
model="openai/my-vllm-model",
|
||||
messages=messages,
|
||||
response_format={"type": "json_object"},
|
||||
client=openai_client,
|
||||
api_base="my-custom-api-base",
|
||||
hello="world",
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
mock_call.assert_called_once()
|
||||
|
||||
print("Call KWARGS - {}".format(mock_call.call_args.kwargs))
|
||||
|
||||
assert "hello" in mock_call.call_args.kwargs["extra_body"]
|
||||
|
||||
|
||||
# ################### Hugging Face Conversational models ########################
|
||||
# def hf_test_completion_conv():
|
||||
# try:
|
||||
|
|
|
@ -4189,12 +4189,12 @@ def test_completion_vllm():
|
|||
|
||||
with patch.object(client.completions, "create", side_effect=mock_post) as mock_call:
|
||||
response = text_completion(
|
||||
model="openai/gemini-1.5-flash",
|
||||
prompt="ping",
|
||||
client=client,
|
||||
model="openai/gemini-1.5-flash", prompt="ping", client=client, hello="world"
|
||||
)
|
||||
print(response)
|
||||
|
||||
assert response.usage.prompt_tokens == 2
|
||||
|
||||
mock_call.assert_called_once()
|
||||
|
||||
assert "hello" in mock_call.call_args.kwargs["extra_body"]
|
||||
|
|
|
@ -3265,7 +3265,11 @@ def get_optional_params(
|
|||
optional_params["top_logprobs"] = top_logprobs
|
||||
if extra_headers is not None:
|
||||
optional_params["extra_headers"] = extra_headers
|
||||
if custom_llm_provider in ["openai", "azure"] + litellm.openai_compatible_providers:
|
||||
if (
|
||||
custom_llm_provider
|
||||
in ["openai", "azure", "text-completion-openai"]
|
||||
+ litellm.openai_compatible_providers
|
||||
):
|
||||
# for openai, azure we should pass the extra/passed params within `extra_body` https://github.com/openai/openai-python/blob/ac33853ba10d13ac149b1fa3ca6dba7d613065c9/src/openai/resources/models.py#L46
|
||||
extra_body = passed_params.pop("extra_body", {})
|
||||
for k in passed_params.keys():
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue