LiteLLM Minor Fixes & Improvements (10/07/2024) (#6101)

* fix(utils.py): support dropping temperature param for azure o1 models

* fix(main.py): handle azure o1 streaming requests

o1 doesn't support streaming, so we fake it to ensure code works as expected
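
A minimal sketch of the resulting behavior from the SDK (deployment name is a placeholder; per this diff, the Azure o1 fake-streaming path is gated behind `litellm.enable_preview_features`):

```python
import litellm

litellm.enable_preview_features = True  # the Azure o1 fake-streaming path is preview-gated in this PR

# o1 has no native streaming; LiteLLM fetches the full response and replays it as a mock stream.
# Assumes AZURE_API_KEY / AZURE_API_BASE / AZURE_API_VERSION are set in the environment.
response = litellm.completion(
    model="azure/o1-preview",  # placeholder Azure deployment name
    messages=[{"role": "user", "content": "Summarize this PR in one line."}],
    stream=True,
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")
```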

* feat(utils.py): expose `hosted_vllm/` endpoint, with tool handling for vllm

Fixes https://github.com/BerriAI/litellm/issues/6088
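
A hedged usage sketch (server URL, model, and tool are illustrative); the new `hosted_vllm/` prefix routes through the OpenAI-compatible path, and the `HostedVLLMChatConfig` added in this PR strips `strict` / `additionalProperties` from tool schemas before they reach vLLM:

```python
import litellm

response = litellm.completion(
    model="hosted_vllm/facebook/opt-125m",  # illustrative vLLM model name
    api_base="https://hosted-vllm-api.co",  # illustrative server URL
    messages=[{"role": "user", "content": "What's the weather in SF?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_weather",  # hypothetical tool
                "strict": True,  # removed by the hosted_vllm transformation
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                    "additionalProperties": False,  # also removed
                },
            },
        }
    ],
)
print(response.choices[0].message)
```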

* refactor(internal_user_endpoints.py): cleanup unused params + update docstring

Closes https://github.com/BerriAI/litellm/issues/6100

* fix(main.py): expose custom image generation api support

Fixes https://github.com/BerriAI/litellm/issues/6097

* fix: fix linting errors

* docs(custom_llm_server.md): add docs on custom api for image gen calls

* fix(types/utils.py): handle dict type

* fix(types/utils.py): fix linting errors
Krish Dholakia 2024-10-08 01:17:22 -04:00 committed by GitHub
parent 5de69cb1b2
commit 6729c9ca7f
17 changed files with 643 additions and 76 deletions


@ -183,11 +183,80 @@ class UnixTimeLLM(CustomLLM):
unixtime = UnixTimeLLM()
```
## Image Generation
1. Set up your `custom_handler.py` file
```python
import time
from typing import Any, Optional, Union

import httpx

import litellm
from litellm import CustomLLM
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.types.utils import ImageResponse, ImageObject

class MyCustomLLM(CustomLLM):
    async def aimage_generation(self, model: str, prompt: str, model_response: ImageResponse, optional_params: dict, logging_obj: Any, timeout: Optional[Union[float, httpx.Timeout]] = None, client: Optional[AsyncHTTPHandler] = None,) -> ImageResponse:
        return ImageResponse(
            created=int(time.time()),
            data=[ImageObject(url="https://example.com/image.png")],
        )

my_custom_llm = MyCustomLLM()
```
2. Add to `config.yaml`

In the config below, we pass:

* python_filename: `custom_handler.py`
* custom_handler_instance_name: `my_custom_llm`. This is defined in Step 1
* custom_handler: `custom_handler.my_custom_llm`
```yaml
model_list:
- model_name: "test-model"
litellm_params:
model: "openai/text-embedding-ada-002"
- model_name: "my-custom-model"
litellm_params:
model: "my-custom-llm/my-model"
litellm_settings:
custom_provider_map:
- {"provider": "my-custom-llm", "custom_handler": custom_handler.my_custom_llm}
```
```bash
litellm --config /path/to/config.yaml
```
3. Test it!
```bash
curl -X POST 'http://0.0.0.0:4000/v1/images/generations' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
"model": "my-custom-model",
"prompt": "A cute baby sea otter"
}'
```
Expected Response
```
{
"created": 1721955063,
"data": [{"url": "https://example.com/image.png"}]
}
```
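
The same handler can also be called directly from the Python SDK, without the proxy (a minimal sketch mirroring the test added in this PR; the provider name is whatever you register in `custom_provider_map`):

```python
import litellm
from custom_handler import my_custom_llm  # the instance from Step 1

litellm.custom_provider_map = [
    {"provider": "my-custom-llm", "custom_handler": my_custom_llm}
]

resp = litellm.image_generation(
    model="my-custom-llm/my-model",
    prompt="A cute baby sea otter",
)
print(resp.data[0].url)
```
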
## Custom Handler Spec

```python
-from litellm.types.utils import GenericStreamingChunk, ModelResponse
+from litellm.types.utils import GenericStreamingChunk, ModelResponse, ImageResponse
-from typing import Iterator, AsyncIterator
+from typing import Iterator, AsyncIterator, Any, Optional, Union
from litellm.llms.base import BaseLLM

class CustomLLMError(Exception):  # use this for all your exceptions
@ -217,4 +286,28 @@ class CustomLLM(BaseLLM):
    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
        raise CustomLLMError(status_code=500, message="Not implemented yet!")

    def image_generation(
        self,
        model: str,
        prompt: str,
        model_response: ImageResponse,
        optional_params: dict,
        logging_obj: Any,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        client: Optional[HTTPHandler] = None,
    ) -> ImageResponse:
        raise CustomLLMError(status_code=500, message="Not implemented yet!")

    async def aimage_generation(
        self,
        model: str,
        prompt: str,
        model_response: ImageResponse,
        optional_params: dict,
        logging_obj: Any,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        client: Optional[AsyncHTTPHandler] = None,
    ) -> ImageResponse:
        raise CustomLLMError(status_code=500, message="Not implemented yet!")
```


@ -12,14 +12,14 @@ vLLM Provides an OpenAI compatible endpoints - here's how to call it with LiteLL
In order to use litellm to call a hosted vllm server add the following to your completion call

-* `model="openai/<your-vllm-model-name>"`
+* `model="hosted_vllm/<your-vllm-model-name>"`
* `api_base = "your-hosted-vllm-server"`

```python
import litellm

response = litellm.completion(
-    model="openai/facebook/opt-125m",       # pass the vllm model name
+    model="hosted_vllm/facebook/opt-125m",  # pass the vllm model name
    messages=messages,
    api_base="https://hosted-vllm-api.co",
    temperature=0.2,
@ -39,7 +39,7 @@ Here's how to call an OpenAI-Compatible Endpoint with the LiteLLM Proxy Server
model_list:
  - model_name: my-model
    litellm_params:
-      model: openai/facebook/opt-125m       # add openai/ prefix to route as OpenAI provider
+      model: hosted_vllm/facebook/opt-125m  # add hosted_vllm/ prefix to route as OpenAI provider
      api_base: https://hosted-vllm-api.co   # add api base for OpenAI compatible provider
```
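
A hedged client-side sketch for the proxy entry above (base URL and key are placeholders for your deployment):

```python
from openai import OpenAI

client = OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")  # placeholders

response = client.chat.completions.create(
    model="my-model",  # the model_name from the proxy config above
    messages=[{"role": "user", "content": "Hello from the LiteLLM proxy!"}],
)
print(response.choices[0].message.content)
```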


@ -504,11 +504,13 @@ openai_compatible_providers: List = [
"azure_ai",
"github",
"litellm_proxy",
"hosted_vllm",
]
openai_text_completion_compatible_providers: List = (
[  # providers that support `/v1/completions`
"together_ai",
"fireworks_ai",
"hosted_vllm",
]
)
@ -758,6 +760,7 @@ class LlmProviders(str, Enum):
GITHUB = "github"
CUSTOM = "custom"
LITELLM_PROXY = "litellm_proxy"
HOSTED_VLLM = "hosted_vllm"

provider_list: List[Union[LlmProviders, str]] = list(LlmProviders)
@ -1003,6 +1006,8 @@ from .llms.AzureOpenAI.azure import (
AzureOpenAIError,
AzureOpenAIAssistantsAPIConfig,
)
from .llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig
from .llms.AzureOpenAI.chat.o1_transformation import AzureOpenAIO1Config
from .llms.watsonx import IBMWatsonXAIConfig
from .main import *  # type: ignore
from .integrations import *


@ -206,6 +206,14 @@ def get_llm_provider(
or "https://codestral.mistral.ai/v1"
)  # type: ignore
dynamic_api_key = api_key or get_secret("CODESTRAL_API_KEY")
elif custom_llm_provider == "hosted_vllm":
# vllm is openai compatible, we just need to set this to custom_openai
api_base = api_base or get_secret(
"HOSTED_VLLM_API_BASE"
) # type: ignore
dynamic_api_key = (
api_key or get_secret("HOSTED_VLLM_API_KEY") or ""
) # vllm does not require an api key
elif custom_llm_provider == "deepseek":
# deepseek is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.deepseek.com/v1
api_base = (
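
For reference, the new provider can also be configured through environment variables (a sketch using the variable names introduced above; the URL is a placeholder):

```python
import os

import litellm

# vLLM servers are OpenAI-compatible and usually unauthenticated, so the key may stay empty
os.environ["HOSTED_VLLM_API_BASE"] = "https://hosted-vllm-api.co"  # placeholder URL
os.environ["HOSTED_VLLM_API_KEY"] = ""  # optional

response = litellm.completion(
    model="hosted_vllm/facebook/opt-125m",  # illustrative model name
    messages=[{"role": "user", "content": "Hello!"}],
)
```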


@ -0,0 +1,97 @@
"""
Handler file for calls to Azure OpenAI's o1 family of models
Written separately to handle faking streaming for o1 models.
"""
import asyncio
from typing import Any, Callable, List, Optional, Union
from httpx._config import Timeout
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
from litellm.types.utils import ModelResponse
from litellm.utils import CustomStreamWrapper
from ..azure import AzureChatCompletion
class AzureOpenAIO1ChatCompletion(AzureChatCompletion):
async def mock_async_streaming(
self,
response: Any,
model: Optional[str],
logging_obj: Any,
):
model_response = await response
completion_stream = MockResponseIterator(model_response=model_response)
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="azure",
logging_obj=logging_obj,
)
return streaming_response
def completion(
self,
model: str,
messages: List,
model_response: ModelResponse,
api_key: str,
api_base: str,
api_version: str,
api_type: str,
azure_ad_token: str,
dynamic_params: bool,
print_verbose: Callable[..., Any],
timeout: Union[float, Timeout],
logging_obj: Logging,
optional_params,
litellm_params,
logger_fn,
acompletion: bool = False,
headers: Optional[dict] = None,
client=None,
):
stream: Optional[bool] = optional_params.pop("stream", False)
response = super().completion(
model,
messages,
model_response,
api_key,
api_base,
api_version,
api_type,
azure_ad_token,
dynamic_params,
print_verbose,
timeout,
logging_obj,
optional_params,
litellm_params,
logger_fn,
acompletion,
headers,
client,
)
if stream is True:
if asyncio.iscoroutine(response):
return self.mock_async_streaming(
response=response, model=model, logging_obj=logging_obj # type: ignore
)
completion_stream = MockResponseIterator(model_response=response)
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="openai",
logging_obj=logging_obj,
)
return streaming_response
else:
return response


@ -0,0 +1,30 @@
"""
Support for o1 model family
https://platform.openai.com/docs/guides/reasoning
Translations handled by LiteLLM:
- modalities: image => drop param (if user opts in to dropping param)
- role: system ==> translate to role 'user'
- streaming => faked by LiteLLM
- Tools, response_format => drop param (if user opts in to dropping param)
- Logprobs => drop param (if user opts in to dropping param)
- Temperature => drop param (if user opts in to dropping param)
"""
import types
from typing import Any, List, Optional, Union
import litellm
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
from ...OpenAI.chat.o1_transformation import OpenAIO1Config
class AzureOpenAIO1Config(OpenAIO1Config):
def is_o1_model(self, model: str) -> bool:
o1_models = ["o1-mini", "o1-preview"]
for m in o1_models:
if m in model:
return True
return False
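
A hedged sketch of how the translations listed in the docstring surface to callers (deployment name is a placeholder; `drop_params` opts in to dropping unsupported params such as `temperature`):

```python
import litellm

# Without drop_params, temperature=0.2 on an o1 model raises litellm.UnsupportedParamsError.
# Assumes Azure credentials are set in the environment.
response = litellm.completion(
    model="azure/o1-mini",  # placeholder deployment name
    messages=[{"role": "user", "content": "hi"}],
    temperature=0.2,  # unsupported by o1 - dropped instead of erroring
    drop_params=True,
)
```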


@ -36,7 +36,13 @@ import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.utils import GenericStreamingChunk, ProviderField
-from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage
+from litellm.utils import (
CustomStreamWrapper,
EmbeddingResponse,
ImageResponse,
ModelResponse,
Usage,
)
from .base import BaseLLM
from .prompt_templates.factory import custom_prompt, prompt_factory
@ -143,6 +149,30 @@ class CustomLLM(BaseLLM):
) -> AsyncIterator[GenericStreamingChunk]:
raise CustomLLMError(status_code=500, message="Not implemented yet!")
def image_generation(
self,
model: str,
prompt: str,
model_response: ImageResponse,
optional_params: dict,
logging_obj: Any,
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[HTTPHandler] = None,
) -> ImageResponse:
raise CustomLLMError(status_code=500, message="Not implemented yet!")
async def aimage_generation(
self,
model: str,
prompt: str,
model_response: ImageResponse,
optional_params: dict,
logging_obj: Any,
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[AsyncHTTPHandler] = None,
) -> ImageResponse:
raise CustomLLMError(status_code=500, message="Not implemented yet!")
def custom_chat_llm_router(
async_fn: bool, stream: Optional[bool], custom_llm: CustomLLM


@ -0,0 +1,34 @@
"""
Translate from OpenAI's `/v1/chat/completions` to VLLM's `/v1/chat/completions`
"""
import types
from typing import List, Optional, Union
from pydantic import BaseModel
import litellm
from litellm.types.llms.openai import AllMessageValues, ChatCompletionAssistantMessage
from ....utils import _remove_additional_properties, _remove_strict_from_schema
from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig
class HostedVLLMChatConfig(OpenAIGPTConfig):
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
_tools = non_default_params.pop("tools", None)
if _tools is not None:
# remove 'additionalProperties' from tools
_tools = _remove_additional_properties(_tools)
# remove 'strict' from tools
_tools = _remove_strict_from_schema(_tools)
non_default_params["tools"] = _tools
return super().map_openai_params(
non_default_params, optional_params, model, drop_params
)


@ -42,6 +42,7 @@ from litellm import ( # type: ignore
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.secret_managers.main import get_secret_str
from litellm.utils import (
CustomStreamWrapper,
@ -89,6 +90,7 @@ from .llms.azure_ai.embed import AzureAIEmbedding
from .llms.azure_text import AzureTextCompletion
from .llms.AzureOpenAI.audio_transcriptions import AzureAudioTranscription
from .llms.AzureOpenAI.azure import AzureChatCompletion, _check_dynamic_azure_params
from .llms.AzureOpenAI.chat.o1_handler import AzureOpenAIO1ChatCompletion
from .llms.bedrock import image_generation as bedrock_image_generation  # type: ignore
from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
from .llms.bedrock.embed.embedding import BedrockEmbedding
@ -178,6 +180,7 @@ azure_ai_embedding = AzureAIEmbedding()
anthropic_chat_completions = AnthropicChatCompletion()
anthropic_text_completions = AnthropicTextCompletion()
azure_chat_completions = AzureChatCompletion()
azure_o1_chat_completions = AzureOpenAIO1ChatCompletion()
azure_text_completions = AzureTextCompletion()
azure_audio_transcriptions = AzureAudioTranscription()
huggingface = Huggingface()
@ -1064,35 +1067,68 @@ def completion( # type: ignore
headers = headers or litellm.headers

-## LOAD CONFIG - if set
-config = litellm.AzureOpenAIConfig.get_config()
-for k, v in config.items():
-    if (
-        k not in optional_params
-    ):  # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in
-        optional_params[k] = v
-## COMPLETION CALL
-response = azure_chat_completions.completion(
-    model=model,
-    messages=messages,
-    headers=headers,
-    api_key=api_key,
-    api_base=api_base,
-    api_version=api_version,
-    api_type=api_type,
-    dynamic_params=dynamic_params,
-    azure_ad_token=azure_ad_token,
-    model_response=model_response,
-    print_verbose=print_verbose,
-    optional_params=optional_params,
-    litellm_params=litellm_params,
-    logger_fn=logger_fn,
-    logging_obj=logging,
-    acompletion=acompletion,
-    timeout=timeout,  # type: ignore
-    client=client,  # pass AsyncAzureOpenAI, AzureOpenAI client
-)
+if (
+    litellm.enable_preview_features
+    and litellm.AzureOpenAIO1Config().is_o1_model(model=model)
+):
+    ## LOAD CONFIG - if set
+    config = litellm.AzureOpenAIO1Config.get_config()
+    for k, v in config.items():
+        if (
+            k not in optional_params
+        ):  # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in
+            optional_params[k] = v
+    response = azure_o1_chat_completions.completion(
+        model=model,
+        messages=messages,
+        headers=headers,
+        api_key=api_key,
+        api_base=api_base,
+        api_version=api_version,
+        api_type=api_type,
+        dynamic_params=dynamic_params,
+        azure_ad_token=azure_ad_token,
+        model_response=model_response,
+        print_verbose=print_verbose,
+        optional_params=optional_params,
+        litellm_params=litellm_params,
+        logger_fn=logger_fn,
+        logging_obj=logging,
+        acompletion=acompletion,
+        timeout=timeout,  # type: ignore
+        client=client,  # pass AsyncAzureOpenAI, AzureOpenAI client
+    )
+else:
+    ## LOAD CONFIG - if set
+    config = litellm.AzureOpenAIConfig.get_config()
+    for k, v in config.items():
+        if (
+            k not in optional_params
+        ):  # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in
+            optional_params[k] = v
+    ## COMPLETION CALL
+    response = azure_chat_completions.completion(
+        model=model,
+        messages=messages,
+        headers=headers,
+        api_key=api_key,
+        api_base=api_base,
+        api_version=api_version,
+        api_type=api_type,
+        dynamic_params=dynamic_params,
+        azure_ad_token=azure_ad_token,
+        model_response=model_response,
+        print_verbose=print_verbose,
+        optional_params=optional_params,
+        litellm_params=litellm_params,
+        logger_fn=logger_fn,
+        logging_obj=logging,
+        acompletion=acompletion,
+        timeout=timeout,  # type: ignore
+        client=client,  # pass AsyncAzureOpenAI, AzureOpenAI client
+    )

if optional_params.get("stream", False):
    ## LOGGING
@ -4582,6 +4618,7 @@ def image_generation(
Currently supports just Azure + OpenAI.
"""
try:
args = locals()
aimg_generation = kwargs.get("aimg_generation", False)
litellm_call_id = kwargs.get("litellm_call_id", None)
logger_fn = kwargs.get("logger_fn", None)
@ -4787,6 +4824,51 @@ def image_generation(
vertex_credentials=vertex_credentials,
aimg_generation=aimg_generation,
)
elif (
custom_llm_provider in litellm._custom_providers
): # Assume custom LLM provider
# Get the Custom Handler
custom_handler: Optional[CustomLLM] = None
for item in litellm.custom_provider_map:
if item["provider"] == custom_llm_provider:
custom_handler = item["custom_handler"]
if custom_handler is None:
raise ValueError(
f"Unable to map your input to a model. Check your input - {args}"
)
## ROUTE LLM CALL ##
if aimg_generation is True:
async_custom_client: Optional[AsyncHTTPHandler] = None
if client is not None and isinstance(client, AsyncHTTPHandler):
async_custom_client = client
## CALL FUNCTION
model_response = custom_handler.aimage_generation( # type: ignore
model=model,
prompt=prompt,
model_response=model_response,
optional_params=optional_params,
logging_obj=litellm_logging_obj,
timeout=timeout,
client=async_custom_client,
)
else:
custom_client: Optional[HTTPHandler] = None
if client is not None and isinstance(client, HTTPHandler):
custom_client = client
## CALL FUNCTION
model_response = custom_handler.image_generation(
model=model,
prompt=prompt,
model_response=model_response,
optional_params=optional_params,
logging_obj=litellm_logging_obj,
timeout=timeout,
client=custom_client,
)
return model_response
except Exception as e:


@ -299,17 +299,13 @@ async def user_info(
user_id: Optional[str] = fastapi.Query(
default=None, description="User ID in the request parameters"
),
-page: Optional[int] = fastapi.Query(
-    default=0,
-    description="Page number for pagination. Only use when view_all is true",
-),
-page_size: Optional[int] = fastapi.Query(
-    default=25,
-    description="Number of items per page. Only use when view_all is true",
-),
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
[10/07/2024]
Note: To get all users (+pagination), use `/user/list` endpoint.

Use this to get user information. (user row + all user key info)

Example request


@ -1018,7 +1018,10 @@ class TextCompletionResponse(OpenAIObject):
setattr(self, key, value)

-class ImageObject(OpenAIObject):
+from openai.types.images_response import Image as OpenAIImage
+class ImageObject(OpenAIImage):
"""
Represents the url or the content of an image generated by the OpenAI API.
@ -1070,7 +1073,7 @@ class ImageResponse(OpenAIImageResponse):
def __init__(
self,
created: Optional[int] = None,
-data: Optional[list] = None,
+data: Optional[List[ImageObject]] = None,
response_ms=None,
):
if response_ms:
@ -1087,7 +1090,13 @@ class ImageResponse(OpenAIImageResponse):
else:
created = int(time.time())

-super().__init__(created=created, data=data)
+_data: List[OpenAIImage] = []
for d in data:
if isinstance(d, dict):
_data.append(ImageObject(**d))
elif isinstance(d, BaseModel):
_data.append(ImageObject(**d.model_dump()))
super().__init__(created=created, data=_data)
self.usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}

def __contains__(self, key):


@ -2788,6 +2788,24 @@ def _remove_additional_properties(schema):
return schema
def _remove_strict_from_schema(schema):
if isinstance(schema, dict):
# Remove the 'additionalProperties' key if it exists and is set to False
if "strict" in schema:
del schema["strict"]
# Recursively process all dictionary values
for key, value in schema.items():
_remove_strict_from_schema(value)
elif isinstance(schema, list):
# Recursively process all items in the list
for item in schema:
_remove_strict_from_schema(item)
return schema
def get_optional_params(
# use the openai defaults
# https://platform.openai.com/docs/api-reference/chat/create
@ -2999,13 +3017,19 @@ def get_optional_params(
from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import (
_build_vertex_schema,
)
old_schema = copy.deepcopy(
non_default_params["response_format"]
.get("json_schema", {})
.get("schema")
)
new_schema = _remove_additional_properties(schema=old_schema)
-new_schema = _build_vertex_schema(parameters=new_schema)
+if isinstance(new_schema, list):
for item in new_schema:
if isinstance(item, dict):
item = _build_vertex_schema(parameters=item)
elif isinstance(new_schema, dict):
new_schema = _build_vertex_schema(parameters=new_schema)
non_default_params["response_format"]["json_schema"]["schema"] = new_schema
if "tools" in non_default_params and isinstance(
non_default_params, list
@ -3767,6 +3791,21 @@ def get_optional_params(
optional_params=optional_params,
model=model,
)
elif custom_llm_provider == "hosted_vllm":
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
optional_params = litellm.HostedVLLMChatConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
elif custom_llm_provider == "groq":
supported_params = get_supported_openai_params(
@ -3926,24 +3965,36 @@ def get_optional_params(
model=model, custom_llm_provider="azure"
)
_check_valid_arg(supported_params=supported_params)
-verbose_logger.debug(
-    "Azure optional params - api_version: api_version={}, litellm.api_version={}, os.environ['AZURE_API_VERSION']={}".format(
-        api_version, litellm.api_version, get_secret("AZURE_API_VERSION")
-    )
-)
-api_version = (
-    api_version
-    or litellm.api_version
-    or get_secret("AZURE_API_VERSION")
-    or litellm.AZURE_DEFAULT_API_VERSION
-)
-optional_params = litellm.AzureOpenAIConfig().map_openai_params(
-    non_default_params=non_default_params,
-    optional_params=optional_params,
-    model=model,
-    api_version=api_version,  # type: ignore
-    drop_params=drop_params,
-)
+if litellm.AzureOpenAIO1Config().is_o1_model(model=model):
+    optional_params = litellm.AzureOpenAIO1Config().map_openai_params(
+        non_default_params=non_default_params,
+        optional_params=optional_params,
+        model=model,
+        drop_params=(
+            drop_params
+            if drop_params is not None and isinstance(drop_params, bool)
+            else False
+        ),
+    )
+else:
+    verbose_logger.debug(
+        "Azure optional params - api_version: api_version={}, litellm.api_version={}, os.environ['AZURE_API_VERSION']={}".format(
+            api_version, litellm.api_version, get_secret("AZURE_API_VERSION")
+        )
+    )
+    api_version = (
+        api_version
+        or litellm.api_version
+        or get_secret("AZURE_API_VERSION")
+        or litellm.AZURE_DEFAULT_API_VERSION
+    )
+    optional_params = litellm.AzureOpenAIConfig().map_openai_params(
+        non_default_params=non_default_params,
+        optional_params=optional_params,
+        model=model,
+        api_version=api_version,  # type: ignore
+        drop_params=drop_params,
+    )
else:  # assume passing in params for text-completion openai
supported_params = get_supported_openai_params(
model=model, custom_llm_provider="custom_openai"
@ -4409,6 +4460,8 @@ def get_supported_openai_params(
"extra_headers",
"extra_body",
]
elif custom_llm_provider == "hosted_vllm":
return litellm.HostedVLLMChatConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "deepseek":
return [
# https://platform.deepseek.com/api-docs/api/create-chat-completion
@ -4465,7 +4518,12 @@ def get_supported_openai_params(
elif custom_llm_provider == "openai":
return litellm.OpenAIConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "azure":
-return litellm.AzureOpenAIConfig().get_supported_openai_params()
+if litellm.AzureOpenAIO1Config().is_o1_model(model=model):
return litellm.AzureOpenAIO1Config().get_supported_openai_params(
model=model
)
else:
return litellm.AzureOpenAIConfig().get_supported_openai_params()
elif custom_llm_provider == "openrouter":
return [
"temperature",


@ -620,16 +620,28 @@ def test_o1_model_params():
assert optional_params["user"] == "John"
def test_azure_o1_model_params():
optional_params = get_optional_params(
model="o1-preview",
custom_llm_provider="azure",
seed=10,
user="John",
)
assert optional_params["seed"] == 10
assert optional_params["user"] == "John"
@pytest.mark.parametrize(
"temperature, expected_error",
[(0.2, True), (1, False)],
)
-def test_o1_model_temperature_params(temperature, expected_error):
+@pytest.mark.parametrize("provider", ["openai", "azure"])
+def test_o1_model_temperature_params(provider, temperature, expected_error):
if expected_error:
with pytest.raises(litellm.UnsupportedParamsError):
get_optional_params(
-model="o1-preview-2024-09-12",
+model="o1-preview",
-custom_llm_provider="openai",
+custom_llm_provider=provider,
temperature=temperature,
)
else:
@ -650,3 +662,45 @@ def test_unmapped_gemini_model_params():
stop="stop_word",
)
assert optional_params["stop_sequences"] == ["stop_word"]
def test_drop_nested_params_vllm():
"""
Relevant issue - https://github.com/BerriAI/litellm/issues/5288
"""
tools = [
{
"type": "function",
"function": {
"name": "structure_output",
"description": "Send structured output back to the user",
"strict": True,
"parameters": {
"type": "object",
"properties": {
"reasoning": {"type": "string"},
"sentiment": {"type": "string"},
},
"required": ["reasoning", "sentiment"],
"additionalProperties": False,
},
"additionalProperties": False,
},
}
]
tool_choice = {"type": "function", "function": {"name": "structure_output"}}
optional_params = get_optional_params(
model="my-vllm-model",
custom_llm_provider="hosted_vllm",
temperature=0.2,
tools=tools,
tool_choice=tool_choice,
additional_drop_params=[
["tools", "function", "strict"],
["tools", "function", "additionalProperties"],
],
)
print(optional_params["tools"][0]["function"])
assert "additionalProperties" not in optional_params["tools"][0]["function"]
assert "strict" not in optional_params["tools"][0]["function"]


@ -1929,7 +1929,7 @@ def test_hf_test_completion_tgi():
# hf_test_completion_tgi()

-@pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",
+@pytest.mark.parametrize("provider", ["openai", "hosted_vllm"])  # "vertex_ai",
@pytest.mark.asyncio
async def test_openai_compatible_custom_api_base(provider):
litellm.set_verbose = True
@ -1947,15 +1947,15 @@ async def test_openai_compatible_custom_api_base(provider):
openai_client.chat.completions, "create", new=MagicMock()
) as mock_call:
try:
-response = completion(
+completion(
-model="openai/my-vllm-model",
+model="{provider}/my-vllm-model".format(provider=provider),
messages=messages,
response_format={"type": "json_object"},
client=openai_client,
api_base="my-custom-api-base",
hello="world",
)
-except Exception as e:
+except Exception:
pass

mock_call.assert_called_once()


@ -42,8 +42,11 @@ from litellm import (
acompletion,
completion,
get_llm_provider,
image_generation,
)
from litellm.utils import ModelResponseIterator
from litellm.types.utils import ImageResponse, ImageObject
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
class CustomModelResponseIterator:
@ -219,6 +222,38 @@ class MyCustomLLM(CustomLLM):
yield generic_streaming_chunk  # type: ignore
def image_generation(
self,
model: str,
prompt: str,
model_response: ImageResponse,
optional_params: dict,
logging_obj: Any,
timeout=None,
client: Optional[HTTPHandler] = None,
):
return ImageResponse(
created=int(time.time()),
data=[ImageObject(url="https://example.com/image.png")],
response_ms=1000,
)
async def aimage_generation(
self,
model: str,
prompt: str,
model_response: ImageResponse,
optional_params: dict,
logging_obj: Any,
timeout=None,
client: Optional[AsyncHTTPHandler] = None,
):
return ImageResponse(
created=int(time.time()),
data=[ImageObject(url="https://example.com/image.png")],
response_ms=1000,
)
def test_get_llm_provider():
""""
@ -300,3 +335,30 @@ async def test_simple_completion_async_streaming():
assert isinstance(chunk.choices[0].delta.content, str)
else:
assert chunk.choices[0].finish_reason == "stop"
def test_simple_image_generation():
my_custom_llm = MyCustomLLM()
litellm.custom_provider_map = [
{"provider": "custom_llm", "custom_handler": my_custom_llm}
]
resp = image_generation(
model="custom_llm/my-fake-model",
prompt="Hello world",
)
print(resp)
@pytest.mark.asyncio
async def test_simple_image_generation_async():
my_custom_llm = MyCustomLLM()
litellm.custom_provider_map = [
{"provider": "custom_llm", "custom_handler": my_custom_llm}
]
resp = await litellm.aimage_generation(
model="custom_llm/my-fake-model",
prompt="Hello world",
)
print(resp)


@ -2156,7 +2156,13 @@ def test_openai_chat_completion_complete_response_call():
# test_openai_chat_completion_complete_response_call()
@pytest.mark.parametrize(
"model",
-["gpt-3.5-turbo", "azure/chatgpt-v-2", "claude-3-haiku-20240307", "o1-preview"],  #
+[
"gpt-3.5-turbo",
"azure/chatgpt-v-2",
"claude-3-haiku-20240307",
"o1-preview",
"azure/fake-o1-mini",
],
)
@pytest.mark.parametrize(
"sync",
@ -2164,6 +2170,7 @@ def test_openai_chat_completion_complete_response_call():
)
@pytest.mark.asyncio
async def test_openai_stream_options_call(model, sync):
litellm.enable_preview_features = True
litellm.set_verbose = True
usage = None
chunks = []
@ -2175,7 +2182,6 @@ async def test_openai_stream_options_call(model, sync):
],
stream=True,
stream_options={"include_usage": True},
-max_tokens=10,
)
for chunk in response:
print("chunk: ", chunk)
@ -2186,7 +2192,6 @@ async def test_openai_stream_options_call(model, sync):
messages=[{"role": "user", "content": "say GM - we're going to make it "}],
stream=True,
stream_options={"include_usage": True},
-max_tokens=10,
)
async for chunk in response:


@ -4223,7 +4223,8 @@ def mock_post(*args, **kwargs):
return mock_response

-def test_completion_vllm():
+@pytest.mark.parametrize("provider", ["openai", "hosted_vllm"])
+def test_completion_vllm(provider):
"""
Asserts a text completion call for vllm actually goes to the text completion endpoint
"""
@ -4235,7 +4236,10 @@ def test_completion_vllm():
client.completions.with_raw_response, "create", side_effect=mock_post
) as mock_call:
response = text_completion(
-model="openai/gemini-1.5-flash", prompt="ping", client=client, hello="world"
+model="{provider}/gemini-1.5-flash".format(provider=provider),
prompt="ping",
client=client,
hello="world",
)
print("raw response", response)