forked from phoenix/litellm-mirror
fix(utils.py): add extra body params for text completion calls
This commit is contained in:
parent
12f4fb3a42
commit
fdb7101aaf
3 changed files with 45 additions and 4 deletions
|
@ -1436,6 +1436,43 @@ def test_hf_test_completion_tgi():
|
||||||
|
|
||||||
# hf_test_completion_tgi()
|
# hf_test_completion_tgi()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("provider", ["vertex_ai_beta"]) # "vertex_ai",
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_compatible_custom_api_base(provider):
|
||||||
|
litellm.set_verbose = True
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hello world",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
openai_client = OpenAI(api_key="fake-key")
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
openai_client.chat.completions, "create", new=MagicMock()
|
||||||
|
) as mock_call:
|
||||||
|
try:
|
||||||
|
response = completion(
|
||||||
|
model="openai/my-vllm-model",
|
||||||
|
messages=messages,
|
||||||
|
response_format={"type": "json_object"},
|
||||||
|
client=openai_client,
|
||||||
|
api_base="my-custom-api-base",
|
||||||
|
hello="world",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
mock_call.assert_called_once()
|
||||||
|
|
||||||
|
print("Call KWARGS - {}".format(mock_call.call_args.kwargs))
|
||||||
|
|
||||||
|
assert "hello" in mock_call.call_args.kwargs["extra_body"]
|
||||||
|
|
||||||
|
|
||||||
# ################### Hugging Face Conversational models ########################
|
# ################### Hugging Face Conversational models ########################
|
||||||
# def hf_test_completion_conv():
|
# def hf_test_completion_conv():
|
||||||
# try:
|
# try:
|
||||||
|
|
|
@ -4189,12 +4189,12 @@ def test_completion_vllm():
|
||||||
|
|
||||||
with patch.object(client.completions, "create", side_effect=mock_post) as mock_call:
|
with patch.object(client.completions, "create", side_effect=mock_post) as mock_call:
|
||||||
response = text_completion(
|
response = text_completion(
|
||||||
model="openai/gemini-1.5-flash",
|
model="openai/gemini-1.5-flash", prompt="ping", client=client, hello="world"
|
||||||
prompt="ping",
|
|
||||||
client=client,
|
|
||||||
)
|
)
|
||||||
print(response)
|
print(response)
|
||||||
|
|
||||||
assert response.usage.prompt_tokens == 2
|
assert response.usage.prompt_tokens == 2
|
||||||
|
|
||||||
mock_call.assert_called_once()
|
mock_call.assert_called_once()
|
||||||
|
|
||||||
|
assert "hello" in mock_call.call_args.kwargs["extra_body"]
|
||||||
|
|
|
@ -3265,7 +3265,11 @@ def get_optional_params(
|
||||||
optional_params["top_logprobs"] = top_logprobs
|
optional_params["top_logprobs"] = top_logprobs
|
||||||
if extra_headers is not None:
|
if extra_headers is not None:
|
||||||
optional_params["extra_headers"] = extra_headers
|
optional_params["extra_headers"] = extra_headers
|
||||||
if custom_llm_provider in ["openai", "azure"] + litellm.openai_compatible_providers:
|
if (
|
||||||
|
custom_llm_provider
|
||||||
|
in ["openai", "azure", "text-completion-openai"]
|
||||||
|
+ litellm.openai_compatible_providers
|
||||||
|
):
|
||||||
# for openai, azure we should pass the extra/passed params within `extra_body` https://github.com/openai/openai-python/blob/ac33853ba10d13ac149b1fa3ca6dba7d613065c9/src/openai/resources/models.py#L46
|
# for openai, azure we should pass the extra/passed params within `extra_body` https://github.com/openai/openai-python/blob/ac33853ba10d13ac149b1fa3ca6dba7d613065c9/src/openai/resources/models.py#L46
|
||||||
extra_body = passed_params.pop("extra_body", {})
|
extra_body = passed_params.pop("extra_body", {})
|
||||||
for k in passed_params.keys():
|
for k in passed_params.keys():
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue