LiteLLM Minor Fixes & Improvements (10/07/2024) (#6101)

* fix(utils.py): support dropping temperature param for azure o1 models

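A minimal usage sketch of the new behavior (model name and messages are placeholders; assumes Azure credentials are configured via the usual env vars): with `drop_params=True`, the unsupported temperature is dropped instead of raising `litellm.UnsupportedParamsError`, which is what the updated test below asserts for `temperature=0.2`.

    import litellm

    response = litellm.completion(
        model="azure/o1-preview",  # placeholder deployment name
        messages=[{"role": "user", "content": "Hello"}],
        temperature=0.2,   # o1 only accepts the default temperature of 1
        drop_params=True,  # drop it instead of raising UnsupportedParamsError
    )
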
* fix(main.py): handle azure o1 streaming requests

o1 doesn't support streaming, so fake it to ensure calling code works as expected

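A sketch of what a caller sees (placeholder model/messages; per the updated streaming test below, this faked stream appears to be gated behind `litellm.enable_preview_features`): the full o1 response is fetched once and replayed as chunks.

    import litellm

    litellm.enable_preview_features = True  # the updated streaming test enables this for the faked o1 stream

    response = litellm.completion(
        model="azure/o1-preview",  # placeholder deployment name
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,  # o1 has no server-side streaming; the response is replayed as chunks
    )
    for chunk in response:
        print(chunk.choices[0].delta.content)
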
* feat(utils.py): expose `hosted_vllm/` endpoint, with tool handling for vllm

Fixes https://github.com/BerriAI/litellm/issues/6088

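A hedged example of the new provider route (the model name, `api_base`, and tool schema are placeholders; assumes `additional_drop_params` is accepted on `completion()` the same way the new `get_optional_params` test below uses it): nested tool fields such as `strict` can be stripped before the request reaches vLLM.

    import litellm

    response = litellm.completion(
        model="hosted_vllm/my-vllm-model",    # placeholder model name
        api_base="http://localhost:8000/v1",  # placeholder vLLM server URL
        messages=[{"role": "user", "content": "Classify the sentiment of this sentence."}],
        tools=[
            {
                "type": "function",
                "function": {
                    "name": "structure_output",
                    "strict": True,  # not wanted by this backend; dropped below
                    "parameters": {
                        "type": "object",
                        "properties": {"sentiment": {"type": "string"}},
                    },
                },
            }
        ],
        additional_drop_params=[["tools", "function", "strict"]],
    )
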
* refactor(internal_user_endpoints.py): clean up unused params + update docstring

Closes https://github.com/BerriAI/litellm/issues/6100

* fix(main.py): expose custom image generation api support

Fixes https://github.com/BerriAI/litellm/issues/6097

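Roughly how it is wired up (a sketch mirroring the new test further down; the provider key `my_image_llm` and class name are hypothetical): subclass `CustomLLM`, implement `image_generation`, register the handler in `litellm.custom_provider_map`, and call `litellm.image_generation()` with that prefix.

    import time
    import litellm
    from litellm import CustomLLM
    from litellm.types.utils import ImageResponse, ImageObject

    class MyImageLLM(CustomLLM):  # hypothetical handler class
        def image_generation(self, model, prompt, model_response, optional_params,
                             logging_obj, timeout=None, client=None):
            # call your own image API here; a canned response keeps the sketch self-contained
            return ImageResponse(
                created=int(time.time()),
                data=[ImageObject(url="https://example.com/image.png")],
            )

    litellm.custom_provider_map = [
        {"provider": "my_image_llm", "custom_handler": MyImageLLM()}  # hypothetical provider key
    ]
    resp = litellm.image_generation(model="my_image_llm/my-fake-model", prompt="Hello world")
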
* fix: fix linting errors

* docs(custom_llm_server.md): add docs on custom api for image gen calls

* fix(types/utils.py): handle dict type

* fix(types/utils.py): fix linting errors
Krish Dholakia, 2024-10-08 01:17:22 -04:00 (committed by GitHub)
commit 6729c9ca7f (parent 5de69cb1b2)
17 changed files with 643 additions and 76 deletions


@@ -620,16 +620,28 @@ def test_o1_model_params():
     assert optional_params["user"] == "John"
 
 
+def test_azure_o1_model_params():
+    optional_params = get_optional_params(
+        model="o1-preview",
+        custom_llm_provider="azure",
+        seed=10,
+        user="John",
+    )
+    assert optional_params["seed"] == 10
+    assert optional_params["user"] == "John"
+
+
 @pytest.mark.parametrize(
     "temperature, expected_error",
     [(0.2, True), (1, False)],
 )
-def test_o1_model_temperature_params(temperature, expected_error):
+@pytest.mark.parametrize("provider", ["openai", "azure"])
+def test_o1_model_temperature_params(provider, temperature, expected_error):
     if expected_error:
         with pytest.raises(litellm.UnsupportedParamsError):
             get_optional_params(
-                model="o1-preview-2024-09-12",
-                custom_llm_provider="openai",
+                model="o1-preview",
+                custom_llm_provider=provider,
                 temperature=temperature,
             )
     else:
@@ -650,3 +662,45 @@ def test_unmapped_gemini_model_params():
         stop="stop_word",
     )
     assert optional_params["stop_sequences"] == ["stop_word"]
+
+
+def test_drop_nested_params_vllm():
+    """
+    Relevant issue - https://github.com/BerriAI/litellm/issues/5288
+    """
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "structure_output",
+                "description": "Send structured output back to the user",
+                "strict": True,
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "reasoning": {"type": "string"},
+                        "sentiment": {"type": "string"},
+                    },
+                    "required": ["reasoning", "sentiment"],
+                    "additionalProperties": False,
+                },
+                "additionalProperties": False,
+            },
+        }
+    ]
+    tool_choice = {"type": "function", "function": {"name": "structure_output"}}
+    optional_params = get_optional_params(
+        model="my-vllm-model",
+        custom_llm_provider="hosted_vllm",
+        temperature=0.2,
+        tools=tools,
+        tool_choice=tool_choice,
+        additional_drop_params=[
+            ["tools", "function", "strict"],
+            ["tools", "function", "additionalProperties"],
+        ],
+    )
+    print(optional_params["tools"][0]["function"])
+    assert "additionalProperties" not in optional_params["tools"][0]["function"]
+    assert "strict" not in optional_params["tools"][0]["function"]


@@ -1929,7 +1929,7 @@ def test_hf_test_completion_tgi():
 # hf_test_completion_tgi()
 
 
-@pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",
+@pytest.mark.parametrize("provider", ["openai", "hosted_vllm"])  # "vertex_ai",
 @pytest.mark.asyncio
 async def test_openai_compatible_custom_api_base(provider):
     litellm.set_verbose = True
@@ -1947,15 +1947,15 @@ async def test_openai_compatible_custom_api_base(provider):
         openai_client.chat.completions, "create", new=MagicMock()
     ) as mock_call:
         try:
-            response = completion(
-                model="openai/my-vllm-model",
+            completion(
+                model="{provider}/my-vllm-model".format(provider=provider),
                 messages=messages,
                 response_format={"type": "json_object"},
                 client=openai_client,
                 api_base="my-custom-api-base",
                 hello="world",
             )
-        except Exception as e:
+        except Exception:
             pass
 
         mock_call.assert_called_once()


@@ -42,8 +42,11 @@ from litellm import (
     acompletion,
     completion,
     get_llm_provider,
+    image_generation,
 )
 from litellm.utils import ModelResponseIterator
+from litellm.types.utils import ImageResponse, ImageObject
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 
 
 class CustomModelResponseIterator:
@@ -219,6 +222,38 @@ class MyCustomLLM(CustomLLM):
             yield generic_streaming_chunk  # type: ignore
 
+    def image_generation(
+        self,
+        model: str,
+        prompt: str,
+        model_response: ImageResponse,
+        optional_params: dict,
+        logging_obj: Any,
+        timeout=None,
+        client: Optional[HTTPHandler] = None,
+    ):
+        return ImageResponse(
+            created=int(time.time()),
+            data=[ImageObject(url="https://example.com/image.png")],
+            response_ms=1000,
+        )
+
+    async def aimage_generation(
+        self,
+        model: str,
+        prompt: str,
+        model_response: ImageResponse,
+        optional_params: dict,
+        logging_obj: Any,
+        timeout=None,
+        client: Optional[AsyncHTTPHandler] = None,
+    ):
+        return ImageResponse(
+            created=int(time.time()),
+            data=[ImageObject(url="https://example.com/image.png")],
+            response_ms=1000,
+        )
+
 
 def test_get_llm_provider():
     """"""
@@ -300,3 +335,30 @@ async def test_simple_completion_async_streaming():
             assert isinstance(chunk.choices[0].delta.content, str)
         else:
             assert chunk.choices[0].finish_reason == "stop"
+
+
+def test_simple_image_generation():
+    my_custom_llm = MyCustomLLM()
+    litellm.custom_provider_map = [
+        {"provider": "custom_llm", "custom_handler": my_custom_llm}
+    ]
+    resp = image_generation(
+        model="custom_llm/my-fake-model",
+        prompt="Hello world",
+    )
+
+    print(resp)
+
+
+@pytest.mark.asyncio
+async def test_simple_image_generation_async():
+    my_custom_llm = MyCustomLLM()
+    litellm.custom_provider_map = [
+        {"provider": "custom_llm", "custom_handler": my_custom_llm}
+    ]
+    resp = await litellm.aimage_generation(
+        model="custom_llm/my-fake-model",
+        prompt="Hello world",
+    )
+
+    print(resp)


@@ -2156,7 +2156,13 @@ def test_openai_chat_completion_complete_response_call():
 # test_openai_chat_completion_complete_response_call()
 @pytest.mark.parametrize(
     "model",
-    ["gpt-3.5-turbo", "azure/chatgpt-v-2", "claude-3-haiku-20240307", "o1-preview"],  #
+    [
+        "gpt-3.5-turbo",
+        "azure/chatgpt-v-2",
+        "claude-3-haiku-20240307",
+        "o1-preview",
+        "azure/fake-o1-mini",
+    ],
 )
 @pytest.mark.parametrize(
     "sync",
@@ -2164,6 +2170,7 @@ def test_openai_chat_completion_complete_response_call():
 )
 @pytest.mark.asyncio
 async def test_openai_stream_options_call(model, sync):
+    litellm.enable_preview_features = True
     litellm.set_verbose = True
     usage = None
     chunks = []
@@ -2175,7 +2182,6 @@ async def test_openai_stream_options_call(model, sync):
             ],
             stream=True,
             stream_options={"include_usage": True},
-            max_tokens=10,
         )
         for chunk in response:
             print("chunk: ", chunk)
@@ -2186,7 +2192,6 @@ async def test_openai_stream_options_call(model, sync):
             messages=[{"role": "user", "content": "say GM - we're going to make it "}],
             stream=True,
             stream_options={"include_usage": True},
-            max_tokens=10,
         )
 
         async for chunk in response:


@@ -4223,7 +4223,8 @@ def mock_post(*args, **kwargs):
     return mock_response
 
 
-def test_completion_vllm():
+@pytest.mark.parametrize("provider", ["openai", "hosted_vllm"])
+def test_completion_vllm(provider):
     """
     Asserts a text completion call for vllm actually goes to the text completion endpoint
     """
@@ -4235,7 +4236,10 @@ def test_completion_vllm():
         client.completions.with_raw_response, "create", side_effect=mock_post
     ) as mock_call:
         response = text_completion(
-            model="openai/gemini-1.5-flash", prompt="ping", client=client, hello="world"
+            model="{provider}/gemini-1.5-flash".format(provider=provider),
+            prompt="ping",
+            client=client,
+            hello="world",
         )
 
         print("raw response", response)