LiteLLM Minor Fixes & Improvements (10/07/2024) (#6101)

* fix(utils.py): support dropping temperature param for azure o1 models
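
A minimal caller-side sketch, assuming an Azure deployment named `o1-preview` (illustrative): with `drop_params` enabled, the `temperature` argument — which o1 models reject — is dropped instead of raising an error.

```python
import litellm

litellm.drop_params = True  # allow litellm to drop params the model rejects

response = litellm.completion(
    model="azure/o1-preview",  # illustrative deployment name
    messages=[{"role": "user", "content": "Hello"}],
    temperature=0.2,  # o1 rejects temperature; dropped rather than erroring
)
```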

* fix(main.py): handle azure o1 streaming requests

o1 doesn't support streaming, so fake it to ensure code expecting a stream works as expected
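
A rough sketch of the resulting behavior (deployment name again illustrative): callers can pass `stream=True` and iterate chunks even though the upstream API returns a single non-streamed response.

```python
import litellm

response = litellm.completion(
    model="azure/o1-preview",  # illustrative deployment name
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,  # faked client-side: the full response is re-chunked
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")
```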

* feat(utils.py): expose `hosted_vllm/` endpoint, with tool handling for vllm

Fixes https://github.com/BerriAI/litellm/issues/6088
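
A hedged usage sketch — the model name, `api_base`, and tool schema below are all illustrative, assuming a vLLM server exposing an OpenAI-compatible endpoint:

```python
import litellm

response = litellm.completion(
    model="hosted_vllm/meta-llama/Meta-Llama-3-8B-Instruct",  # illustrative
    api_base="http://localhost:8000/v1",  # your vLLM server
    messages=[{"role": "user", "content": "What's the weather in Boston?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_weather",  # hypothetical tool
                "description": "Get the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
)
print(response.choices[0].message.tool_calls)
```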

* refactor(internal_user_endpoints.py): cleanup unused params + update docstring

Closes https://github.com/BerriAI/litellm/issues/6100

* fix(main.py): expose custom image generation api support

Fixes https://github.com/BerriAI/litellm/issues/6097

* fix: fix linting errors

* docs(custom_llm_server.md): add docs on custom api for image gen calls

* fix(types/utils.py): handle dict type

* fix(types/utils.py): fix linting errors
Commit: 6729c9ca7f (parent: 5de69cb1b2)
Author: Krish Dholakia
Date: 2024-10-08 01:17:22 -04:00 (committed by GitHub)
GPG key ID: B5690EEEBB952194 (no known key found for this signature in database)
17 changed files with 643 additions and 76 deletions


@@ -42,8 +42,11 @@ from litellm import (
    acompletion,
    completion,
    get_llm_provider,
    image_generation,
)
from litellm.utils import ModelResponseIterator
from litellm.types.utils import ImageResponse, ImageObject
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler

class CustomModelResponseIterator:

@@ -219,6 +222,38 @@ class MyCustomLLM(CustomLLM):
        yield generic_streaming_chunk  # type: ignore

    def image_generation(
        self,
        model: str,
        prompt: str,
        model_response: ImageResponse,
        optional_params: dict,
        logging_obj: Any,
        timeout=None,
        client: Optional[HTTPHandler] = None,
    ):
        return ImageResponse(
            created=int(time.time()),
            data=[ImageObject(url="https://example.com/image.png")],
            response_ms=1000,
        )

    async def aimage_generation(
        self,
        model: str,
        prompt: str,
        model_response: ImageResponse,
        optional_params: dict,
        logging_obj: Any,
        timeout=None,
        client: Optional[AsyncHTTPHandler] = None,
    ):
        return ImageResponse(
            created=int(time.time()),
            data=[ImageObject(url="https://example.com/image.png")],
            response_ms=1000,
        )

def test_get_llm_provider():
    """"""

@@ -300,3 +335,30 @@ async def test_simple_completion_async_streaming():
            assert isinstance(chunk.choices[0].delta.content, str)
        else:
            assert chunk.choices[0].finish_reason == "stop"

def test_simple_image_generation():
    my_custom_llm = MyCustomLLM()
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": my_custom_llm}
    ]

    resp = image_generation(
        model="custom_llm/my-fake-model",
        prompt="Hello world",
    )

    print(resp)

@pytest.mark.asyncio
async def test_simple_image_generation_async():
    my_custom_llm = MyCustomLLM()
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": my_custom_llm}
    ]

    resp = await litellm.aimage_generation(
        model="custom_llm/my-fake-model",
        prompt="Hello world",
    )

    print(resp)