From 6729c9ca7f5b3ee95b9dac3a352ba9f2c55a7cd9 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Tue, 8 Oct 2024 01:17:22 -0400 Subject: [PATCH] LiteLLM Minor Fixes & Improvements (10/07/2024) (#6101) * fix(utils.py): support dropping temperature param for azure o1 models * fix(main.py): handle azure o1 streaming requests o1 doesn't support streaming, fake it to ensure code works as expected * feat(utils.py): expose `hosted_vllm/` endpoint, with tool handling for vllm Fixes https://github.com/BerriAI/litellm/issues/6088 * refactor(internal_user_endpoints.py): cleanup unused params + update docstring Closes https://github.com/BerriAI/litellm/issues/6100 * fix(main.py): expose custom image generation api support Fixes https://github.com/BerriAI/litellm/issues/6097 * fix: fix linting errors * docs(custom_llm_server.md): add docs on custom api for image gen calls * fix(types/utils.py): handle dict type * fix(types/utils.py): fix linting errors --- .../docs/providers/custom_llm_server.md | 97 +++++++++++- docs/my-website/docs/providers/vllm.md | 6 +- litellm/__init__.py | 5 + .../get_llm_provider_logic.py | 8 + litellm/llms/AzureOpenAI/chat/o1_handler.py | 97 ++++++++++++ .../AzureOpenAI/chat/o1_transformation.py | 30 ++++ litellm/llms/custom_llm.py | 32 +++- .../llms/hosted_vllm/chat/transformation.py | 34 +++++ litellm/main.py | 138 ++++++++++++++---- .../internal_user_endpoints.py | 12 +- litellm/types/utils.py | 15 +- litellm/utils.py | 96 +++++++++--- tests/llm_translation/test_optional_params.py | 60 +++++++- tests/local_testing/test_completion.py | 8 +- tests/local_testing/test_custom_llm.py | 62 ++++++++ tests/local_testing/test_streaming.py | 11 +- tests/local_testing/test_text_completion.py | 8 +- 17 files changed, 643 insertions(+), 76 deletions(-) create mode 100644 litellm/llms/AzureOpenAI/chat/o1_handler.py create mode 100644 litellm/llms/AzureOpenAI/chat/o1_transformation.py create mode 100644 litellm/llms/hosted_vllm/chat/transformation.py diff --git a/docs/my-website/docs/providers/custom_llm_server.md b/docs/my-website/docs/providers/custom_llm_server.md index 0807e2650..6d2015010 100644 --- a/docs/my-website/docs/providers/custom_llm_server.md +++ b/docs/my-website/docs/providers/custom_llm_server.md @@ -183,11 +183,80 @@ class UnixTimeLLM(CustomLLM): unixtime = UnixTimeLLM() ``` +## Image Generation + +1. Setup your `custom_handler.py` file +```python +import litellm +from litellm import CustomLLM +from litellm.types.utils import ImageResponse, ImageObject + + +class MyCustomLLM(CustomLLM): + async def aimage_generation(self, model: str, prompt: str, model_response: ImageResponse, optional_params: dict, logging_obj: Any, timeout: Optional[Union[float, httpx.Timeout]] = None, client: Optional[AsyncHTTPHandler] = None,) -> ImageResponse: + return ImageResponse( + created=int(time.time()), + data=[ImageObject(url="https://example.com/image.png")], + ) + +my_custom_llm = MyCustomLLM() +``` + + +2. Add to `config.yaml` + +In the config below, we pass + +python_filename: `custom_handler.py` +custom_handler_instance_name: `my_custom_llm`. 
This is defined in Step 1 + +custom_handler: `custom_handler.my_custom_llm` + +```yaml +model_list: + - model_name: "test-model" + litellm_params: + model: "openai/text-embedding-ada-002" + - model_name: "my-custom-model" + litellm_params: + model: "my-custom-llm/my-model" + +litellm_settings: + custom_provider_map: + - {"provider": "my-custom-llm", "custom_handler": custom_handler.my_custom_llm} +``` + +```bash +litellm --config /path/to/config.yaml +``` + +3. Test it! + +```bash +curl -X POST 'http://0.0.0.0:4000/v1/images/generations' \ +-H 'Content-Type: application/json' \ +-H 'Authorization: Bearer sk-1234' \ +-d '{ + "model": "my-custom-model", + "prompt": "A cute baby sea otter", +}' +``` + +Expected Response + +``` +{ + "created": 1721955063, + "data": [{"url": "https://example.com/image.png"}], +} +``` + + ## Custom Handler Spec ```python -from litellm.types.utils import GenericStreamingChunk, ModelResponse -from typing import Iterator, AsyncIterator +from litellm.types.utils import GenericStreamingChunk, ModelResponse, ImageResponse +from typing import Iterator, AsyncIterator, Any, Optional, Union from litellm.llms.base import BaseLLM class CustomLLMError(Exception): # use this for all your exceptions @@ -217,4 +286,28 @@ class CustomLLM(BaseLLM): async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]: raise CustomLLMError(status_code=500, message="Not implemented yet!") + + def image_generation( + self, + model: str, + prompt: str, + model_response: ImageResponse, + optional_params: dict, + logging_obj: Any, + timeout: Optional[Union[float, httpx.Timeout]] = None, + client: Optional[HTTPHandler] = None, + ) -> ImageResponse: + raise CustomLLMError(status_code=500, message="Not implemented yet!") + + async def aimage_generation( + self, + model: str, + prompt: str, + model_response: ImageResponse, + optional_params: dict, + logging_obj: Any, + timeout: Optional[Union[float, httpx.Timeout]] = None, + client: Optional[AsyncHTTPHandler] = None, + ) -> ImageResponse: + raise CustomLLMError(status_code=500, message="Not implemented yet!") ``` diff --git a/docs/my-website/docs/providers/vllm.md b/docs/my-website/docs/providers/vllm.md index 61dd1fffd..5388a0bb7 100644 --- a/docs/my-website/docs/providers/vllm.md +++ b/docs/my-website/docs/providers/vllm.md @@ -12,14 +12,14 @@ vLLM Provides an OpenAI compatible endpoints - here's how to call it with LiteLL In order to use litellm to call a hosted vllm server add the following to your completion call -* `model="openai/"` +* `model="hosted_vllm/"` * `api_base = "your-hosted-vllm-server"` ```python import litellm response = litellm.completion( - model="openai/facebook/opt-125m", # pass the vllm model name + model="hosted_vllm/facebook/opt-125m", # pass the vllm model name messages=messages, api_base="https://hosted-vllm-api.co", temperature=0.2, @@ -39,7 +39,7 @@ Here's how to call an OpenAI-Compatible Endpoint with the LiteLLM Proxy Server model_list: - model_name: my-model litellm_params: - model: openai/facebook/opt-125m # add openai/ prefix to route as OpenAI provider + model: hosted_vllm/facebook/opt-125m # add hosted_vllm/ prefix to route as OpenAI provider api_base: https://hosted-vllm-api.co # add api base for OpenAI compatible provider ``` diff --git a/litellm/__init__.py b/litellm/__init__.py index 02cec3c12..2f8ae7cde 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -504,11 +504,13 @@ openai_compatible_providers: List = [ "azure_ai", "github", "litellm_proxy", + "hosted_vllm", ] 
openai_text_completion_compatible_providers: List = ( [ # providers that support `/v1/completions` "together_ai", "fireworks_ai", + "hosted_vllm", ] ) @@ -758,6 +760,7 @@ class LlmProviders(str, Enum): GITHUB = "github" CUSTOM = "custom" LITELLM_PROXY = "litellm_proxy" + HOSTED_VLLM = "hosted_vllm" provider_list: List[Union[LlmProviders, str]] = list(LlmProviders) @@ -1003,6 +1006,8 @@ from .llms.AzureOpenAI.azure import ( AzureOpenAIError, AzureOpenAIAssistantsAPIConfig, ) +from .llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig +from .llms.AzureOpenAI.chat.o1_transformation import AzureOpenAIO1Config from .llms.watsonx import IBMWatsonXAIConfig from .main import * # type: ignore from .integrations import * diff --git a/litellm/litellm_core_utils/get_llm_provider_logic.py b/litellm/litellm_core_utils/get_llm_provider_logic.py index 41132a39e..d778ce723 100644 --- a/litellm/litellm_core_utils/get_llm_provider_logic.py +++ b/litellm/litellm_core_utils/get_llm_provider_logic.py @@ -206,6 +206,14 @@ def get_llm_provider( or "https://codestral.mistral.ai/v1" ) # type: ignore dynamic_api_key = api_key or get_secret("CODESTRAL_API_KEY") + elif custom_llm_provider == "hosted_vllm": + # vllm is openai compatible, we just need to set this to custom_openai + api_base = api_base or get_secret( + "HOSTED_VLLM_API_BASE" + ) # type: ignore + dynamic_api_key = ( + api_key or get_secret("HOSTED_VLLM_API_KEY") or "" + ) # vllm does not require an api key elif custom_llm_provider == "deepseek": # deepseek is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.deepseek.com/v1 api_base = ( diff --git a/litellm/llms/AzureOpenAI/chat/o1_handler.py b/litellm/llms/AzureOpenAI/chat/o1_handler.py new file mode 100644 index 000000000..45c35d627 --- /dev/null +++ b/litellm/llms/AzureOpenAI/chat/o1_handler.py @@ -0,0 +1,97 @@ +""" +Handler file for calls to Azure OpenAI's o1 family of models + +Written separately to handle faking streaming for o1 models. 
+""" + +import asyncio +from typing import Any, Callable, List, Optional, Union + +from httpx._config import Timeout + +from litellm.litellm_core_utils.litellm_logging import Logging +from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator +from litellm.types.utils import ModelResponse +from litellm.utils import CustomStreamWrapper + +from ..azure import AzureChatCompletion + + +class AzureOpenAIO1ChatCompletion(AzureChatCompletion): + + async def mock_async_streaming( + self, + response: Any, + model: Optional[str], + logging_obj: Any, + ): + model_response = await response + completion_stream = MockResponseIterator(model_response=model_response) + streaming_response = CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider="azure", + logging_obj=logging_obj, + ) + return streaming_response + + def completion( + self, + model: str, + messages: List, + model_response: ModelResponse, + api_key: str, + api_base: str, + api_version: str, + api_type: str, + azure_ad_token: str, + dynamic_params: bool, + print_verbose: Callable[..., Any], + timeout: Union[float, Timeout], + logging_obj: Logging, + optional_params, + litellm_params, + logger_fn, + acompletion: bool = False, + headers: Optional[dict] = None, + client=None, + ): + stream: Optional[bool] = optional_params.pop("stream", False) + response = super().completion( + model, + messages, + model_response, + api_key, + api_base, + api_version, + api_type, + azure_ad_token, + dynamic_params, + print_verbose, + timeout, + logging_obj, + optional_params, + litellm_params, + logger_fn, + acompletion, + headers, + client, + ) + + if stream is True: + if asyncio.iscoroutine(response): + return self.mock_async_streaming( + response=response, model=model, logging_obj=logging_obj # type: ignore + ) + + completion_stream = MockResponseIterator(model_response=response) + streaming_response = CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider="openai", + logging_obj=logging_obj, + ) + + return streaming_response + else: + return response diff --git a/litellm/llms/AzureOpenAI/chat/o1_transformation.py b/litellm/llms/AzureOpenAI/chat/o1_transformation.py new file mode 100644 index 000000000..e1677f681 --- /dev/null +++ b/litellm/llms/AzureOpenAI/chat/o1_transformation.py @@ -0,0 +1,30 @@ +""" +Support for o1 model family + +https://platform.openai.com/docs/guides/reasoning + +Translations handled by LiteLLM: +- modalities: image => drop param (if user opts in to dropping param) +- role: system ==> translate to role 'user' +- streaming => faked by LiteLLM +- Tools, response_format => drop param (if user opts in to dropping param) +- Logprobs => drop param (if user opts in to dropping param) +- Temperature => drop param (if user opts in to dropping param) +""" + +import types +from typing import Any, List, Optional, Union + +import litellm +from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage + +from ...OpenAI.chat.o1_transformation import OpenAIO1Config + + +class AzureOpenAIO1Config(OpenAIO1Config): + def is_o1_model(self, model: str) -> bool: + o1_models = ["o1-mini", "o1-preview"] + for m in o1_models: + if m in model: + return True + return False diff --git a/litellm/llms/custom_llm.py b/litellm/llms/custom_llm.py index 47c5a485c..89798eef5 100644 --- a/litellm/llms/custom_llm.py +++ b/litellm/llms/custom_llm.py @@ -36,7 +36,13 @@ import litellm from litellm.litellm_core_utils.core_helpers import map_finish_reason from 
litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.types.utils import GenericStreamingChunk, ProviderField -from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage +from litellm.utils import ( + CustomStreamWrapper, + EmbeddingResponse, + ImageResponse, + ModelResponse, + Usage, +) from .base import BaseLLM from .prompt_templates.factory import custom_prompt, prompt_factory @@ -143,6 +149,30 @@ class CustomLLM(BaseLLM): ) -> AsyncIterator[GenericStreamingChunk]: raise CustomLLMError(status_code=500, message="Not implemented yet!") + def image_generation( + self, + model: str, + prompt: str, + model_response: ImageResponse, + optional_params: dict, + logging_obj: Any, + timeout: Optional[Union[float, httpx.Timeout]] = None, + client: Optional[HTTPHandler] = None, + ) -> ImageResponse: + raise CustomLLMError(status_code=500, message="Not implemented yet!") + + async def aimage_generation( + self, + model: str, + prompt: str, + model_response: ImageResponse, + optional_params: dict, + logging_obj: Any, + timeout: Optional[Union[float, httpx.Timeout]] = None, + client: Optional[AsyncHTTPHandler] = None, + ) -> ImageResponse: + raise CustomLLMError(status_code=500, message="Not implemented yet!") + def custom_chat_llm_router( async_fn: bool, stream: Optional[bool], custom_llm: CustomLLM diff --git a/litellm/llms/hosted_vllm/chat/transformation.py b/litellm/llms/hosted_vllm/chat/transformation.py new file mode 100644 index 000000000..0b1259dbf --- /dev/null +++ b/litellm/llms/hosted_vllm/chat/transformation.py @@ -0,0 +1,34 @@ +""" +Translate from OpenAI's `/v1/chat/completions` to VLLM's `/v1/chat/completions` +""" + +import types +from typing import List, Optional, Union + +from pydantic import BaseModel + +import litellm +from litellm.types.llms.openai import AllMessageValues, ChatCompletionAssistantMessage + +from ....utils import _remove_additional_properties, _remove_strict_from_schema +from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig + + +class HostedVLLMChatConfig(OpenAIGPTConfig): + def map_openai_params( + self, + non_default_params: dict, + optional_params: dict, + model: str, + drop_params: bool, + ) -> dict: + _tools = non_default_params.pop("tools", None) + if _tools is not None: + # remove 'additionalProperties' from tools + _tools = _remove_additional_properties(_tools) + # remove 'strict' from tools + _tools = _remove_strict_from_schema(_tools) + non_default_params["tools"] = _tools + return super().map_openai_params( + non_default_params, optional_params, model, drop_params + ) diff --git a/litellm/main.py b/litellm/main.py index 87c169f4e..b53db67f4 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -42,6 +42,7 @@ from litellm import ( # type: ignore ) from litellm.integrations.custom_logger import CustomLogger from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.secret_managers.main import get_secret_str from litellm.utils import ( CustomStreamWrapper, @@ -89,6 +90,7 @@ from .llms.azure_ai.embed import AzureAIEmbedding from .llms.azure_text import AzureTextCompletion from .llms.AzureOpenAI.audio_transcriptions import AzureAudioTranscription from .llms.AzureOpenAI.azure import AzureChatCompletion, _check_dynamic_azure_params +from .llms.AzureOpenAI.chat.o1_handler import AzureOpenAIO1ChatCompletion from .llms.bedrock import image_generation as bedrock_image_generation # 
type: ignore from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM from .llms.bedrock.embed.embedding import BedrockEmbedding @@ -178,6 +180,7 @@ azure_ai_embedding = AzureAIEmbedding() anthropic_chat_completions = AnthropicChatCompletion() anthropic_text_completions = AnthropicTextCompletion() azure_chat_completions = AzureChatCompletion() +azure_o1_chat_completions = AzureOpenAIO1ChatCompletion() azure_text_completions = AzureTextCompletion() azure_audio_transcriptions = AzureAudioTranscription() huggingface = Huggingface() @@ -1064,35 +1067,68 @@ def completion( # type: ignore headers = headers or litellm.headers - ## LOAD CONFIG - if set - config = litellm.AzureOpenAIConfig.get_config() - for k, v in config.items(): - if ( - k not in optional_params - ): # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in - optional_params[k] = v + if ( + litellm.enable_preview_features + and litellm.AzureOpenAIO1Config().is_o1_model(model=model) + ): + ## LOAD CONFIG - if set + config = litellm.AzureOpenAIO1Config.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in + optional_params[k] = v - ## COMPLETION CALL - response = azure_chat_completions.completion( - model=model, - messages=messages, - headers=headers, - api_key=api_key, - api_base=api_base, - api_version=api_version, - api_type=api_type, - dynamic_params=dynamic_params, - azure_ad_token=azure_ad_token, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - logging_obj=logging, - acompletion=acompletion, - timeout=timeout, # type: ignore - client=client, # pass AsyncAzureOpenAI, AzureOpenAI client - ) + response = azure_o1_chat_completions.completion( + model=model, + messages=messages, + headers=headers, + api_key=api_key, + api_base=api_base, + api_version=api_version, + api_type=api_type, + dynamic_params=dynamic_params, + azure_ad_token=azure_ad_token, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + logging_obj=logging, + acompletion=acompletion, + timeout=timeout, # type: ignore + client=client, # pass AsyncAzureOpenAI, AzureOpenAI client + ) + else: + ## LOAD CONFIG - if set + config = litellm.AzureOpenAIConfig.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in + optional_params[k] = v + + ## COMPLETION CALL + response = azure_chat_completions.completion( + model=model, + messages=messages, + headers=headers, + api_key=api_key, + api_base=api_base, + api_version=api_version, + api_type=api_type, + dynamic_params=dynamic_params, + azure_ad_token=azure_ad_token, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + logging_obj=logging, + acompletion=acompletion, + timeout=timeout, # type: ignore + client=client, # pass AsyncAzureOpenAI, AzureOpenAI client + ) if optional_params.get("stream", False): ## LOGGING @@ -4582,6 +4618,7 @@ def image_generation( Currently supports just Azure + OpenAI. 
""" try: + args = locals() aimg_generation = kwargs.get("aimg_generation", False) litellm_call_id = kwargs.get("litellm_call_id", None) logger_fn = kwargs.get("logger_fn", None) @@ -4787,6 +4824,51 @@ def image_generation( vertex_credentials=vertex_credentials, aimg_generation=aimg_generation, ) + elif ( + custom_llm_provider in litellm._custom_providers + ): # Assume custom LLM provider + # Get the Custom Handler + custom_handler: Optional[CustomLLM] = None + for item in litellm.custom_provider_map: + if item["provider"] == custom_llm_provider: + custom_handler = item["custom_handler"] + + if custom_handler is None: + raise ValueError( + f"Unable to map your input to a model. Check your input - {args}" + ) + + ## ROUTE LLM CALL ## + if aimg_generation is True: + async_custom_client: Optional[AsyncHTTPHandler] = None + if client is not None and isinstance(client, AsyncHTTPHandler): + async_custom_client = client + + ## CALL FUNCTION + model_response = custom_handler.aimage_generation( # type: ignore + model=model, + prompt=prompt, + model_response=model_response, + optional_params=optional_params, + logging_obj=litellm_logging_obj, + timeout=timeout, + client=async_custom_client, + ) + else: + custom_client: Optional[HTTPHandler] = None + if client is not None and isinstance(client, HTTPHandler): + custom_client = client + + ## CALL FUNCTION + model_response = custom_handler.image_generation( + model=model, + prompt=prompt, + model_response=model_response, + optional_params=optional_params, + logging_obj=litellm_logging_obj, + timeout=timeout, + client=custom_client, + ) return model_response except Exception as e: diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py index 92f129576..e4b0e82a4 100644 --- a/litellm/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -299,17 +299,13 @@ async def user_info( user_id: Optional[str] = fastapi.Query( default=None, description="User ID in the request parameters" ), - page: Optional[int] = fastapi.Query( - default=0, - description="Page number for pagination. Only use when view_all is true", - ), - page_size: Optional[int] = fastapi.Query( - default=25, - description="Number of items per page. Only use when view_all is true", - ), user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): """ + [10/07/2024] + Note: To get all users (+pagination), use `/user/list` endpoint. + + Use this to get user information. (user row + all user key info) Example request diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 4e4699afa..394c1db6c 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -1018,7 +1018,10 @@ class TextCompletionResponse(OpenAIObject): setattr(self, key, value) -class ImageObject(OpenAIObject): +from openai.types.images_response import Image as OpenAIImage + + +class ImageObject(OpenAIImage): """ Represents the url or the content of an image generated by the OpenAI API. 
@@ -1070,7 +1073,7 @@ class ImageResponse(OpenAIImageResponse): def __init__( self, created: Optional[int] = None, - data: Optional[list] = None, + data: Optional[List[ImageObject]] = None, response_ms=None, ): if response_ms: @@ -1087,7 +1090,13 @@ class ImageResponse(OpenAIImageResponse): else: created = int(time.time()) - super().__init__(created=created, data=data) + _data: List[OpenAIImage] = [] + for d in data: + if isinstance(d, dict): + _data.append(ImageObject(**d)) + elif isinstance(d, BaseModel): + _data.append(ImageObject(**d.model_dump())) + super().__init__(created=created, data=_data) self.usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} def __contains__(self, key): diff --git a/litellm/utils.py b/litellm/utils.py index c2d0f0b9f..e33096411 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2788,6 +2788,24 @@ def _remove_additional_properties(schema): return schema +def _remove_strict_from_schema(schema): + if isinstance(schema, dict): + # Remove the 'additionalProperties' key if it exists and is set to False + if "strict" in schema: + del schema["strict"] + + # Recursively process all dictionary values + for key, value in schema.items(): + _remove_strict_from_schema(value) + + elif isinstance(schema, list): + # Recursively process all items in the list + for item in schema: + _remove_strict_from_schema(item) + + return schema + + def get_optional_params( # use the openai defaults # https://platform.openai.com/docs/api-reference/chat/create @@ -2999,13 +3017,19 @@ def get_optional_params( from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import ( _build_vertex_schema, ) + old_schema = copy.deepcopy( non_default_params["response_format"] .get("json_schema", {}) .get("schema") ) new_schema = _remove_additional_properties(schema=old_schema) - new_schema = _build_vertex_schema(parameters=new_schema) + if isinstance(new_schema, list): + for item in new_schema: + if isinstance(item, dict): + item = _build_vertex_schema(parameters=item) + elif isinstance(new_schema, dict): + new_schema = _build_vertex_schema(parameters=new_schema) non_default_params["response_format"]["json_schema"]["schema"] = new_schema if "tools" in non_default_params and isinstance( non_default_params, list @@ -3767,6 +3791,21 @@ def get_optional_params( optional_params=optional_params, model=model, ) + elif custom_llm_provider == "hosted_vllm": + supported_params = get_supported_openai_params( + model=model, custom_llm_provider=custom_llm_provider + ) + _check_valid_arg(supported_params=supported_params) + optional_params = litellm.HostedVLLMChatConfig().map_openai_params( + non_default_params=non_default_params, + optional_params=optional_params, + model=model, + drop_params=( + drop_params + if drop_params is not None and isinstance(drop_params, bool) + else False + ), + ) elif custom_llm_provider == "groq": supported_params = get_supported_openai_params( @@ -3926,24 +3965,36 @@ def get_optional_params( model=model, custom_llm_provider="azure" ) _check_valid_arg(supported_params=supported_params) - verbose_logger.debug( - "Azure optional params - api_version: api_version={}, litellm.api_version={}, os.environ['AZURE_API_VERSION']={}".format( - api_version, litellm.api_version, get_secret("AZURE_API_VERSION") + if litellm.AzureOpenAIO1Config().is_o1_model(model=model): + optional_params = litellm.AzureOpenAIO1Config().map_openai_params( + non_default_params=non_default_params, + optional_params=optional_params, + model=model, + drop_params=( + drop_params + if 
drop_params is not None and isinstance(drop_params, bool) + else False + ), + ) + else: + verbose_logger.debug( + "Azure optional params - api_version: api_version={}, litellm.api_version={}, os.environ['AZURE_API_VERSION']={}".format( + api_version, litellm.api_version, get_secret("AZURE_API_VERSION") + ) + ) + api_version = ( + api_version + or litellm.api_version + or get_secret("AZURE_API_VERSION") + or litellm.AZURE_DEFAULT_API_VERSION + ) + optional_params = litellm.AzureOpenAIConfig().map_openai_params( + non_default_params=non_default_params, + optional_params=optional_params, + model=model, + api_version=api_version, # type: ignore + drop_params=drop_params, ) - ) - api_version = ( - api_version - or litellm.api_version - or get_secret("AZURE_API_VERSION") - or litellm.AZURE_DEFAULT_API_VERSION - ) - optional_params = litellm.AzureOpenAIConfig().map_openai_params( - non_default_params=non_default_params, - optional_params=optional_params, - model=model, - api_version=api_version, # type: ignore - drop_params=drop_params, - ) else: # assume passing in params for text-completion openai supported_params = get_supported_openai_params( model=model, custom_llm_provider="custom_openai" @@ -4409,6 +4460,8 @@ def get_supported_openai_params( "extra_headers", "extra_body", ] + elif custom_llm_provider == "hosted_vllm": + return litellm.HostedVLLMChatConfig().get_supported_openai_params(model=model) elif custom_llm_provider == "deepseek": return [ # https://platform.deepseek.com/api-docs/api/create-chat-completion @@ -4465,7 +4518,12 @@ def get_supported_openai_params( elif custom_llm_provider == "openai": return litellm.OpenAIConfig().get_supported_openai_params(model=model) elif custom_llm_provider == "azure": - return litellm.AzureOpenAIConfig().get_supported_openai_params() + if litellm.AzureOpenAIO1Config().is_o1_model(model=model): + return litellm.AzureOpenAIO1Config().get_supported_openai_params( + model=model + ) + else: + return litellm.AzureOpenAIConfig().get_supported_openai_params() elif custom_llm_provider == "openrouter": return [ "temperature", diff --git a/tests/llm_translation/test_optional_params.py b/tests/llm_translation/test_optional_params.py index 877880e3d..f8f90fb6d 100644 --- a/tests/llm_translation/test_optional_params.py +++ b/tests/llm_translation/test_optional_params.py @@ -620,16 +620,28 @@ def test_o1_model_params(): assert optional_params["user"] == "John" +def test_azure_o1_model_params(): + optional_params = get_optional_params( + model="o1-preview", + custom_llm_provider="azure", + seed=10, + user="John", + ) + assert optional_params["seed"] == 10 + assert optional_params["user"] == "John" + + @pytest.mark.parametrize( "temperature, expected_error", [(0.2, True), (1, False)], ) -def test_o1_model_temperature_params(temperature, expected_error): +@pytest.mark.parametrize("provider", ["openai", "azure"]) +def test_o1_model_temperature_params(provider, temperature, expected_error): if expected_error: with pytest.raises(litellm.UnsupportedParamsError): get_optional_params( - model="o1-preview-2024-09-12", - custom_llm_provider="openai", + model="o1-preview", + custom_llm_provider=provider, temperature=temperature, ) else: @@ -650,3 +662,45 @@ def test_unmapped_gemini_model_params(): stop="stop_word", ) assert optional_params["stop_sequences"] == ["stop_word"] + + +def test_drop_nested_params_vllm(): + """ + Relevant issue - https://github.com/BerriAI/litellm/issues/5288 + """ + tools = [ + { + "type": "function", + "function": { + "name": "structure_output", 
+ "description": "Send structured output back to the user", + "strict": True, + "parameters": { + "type": "object", + "properties": { + "reasoning": {"type": "string"}, + "sentiment": {"type": "string"}, + }, + "required": ["reasoning", "sentiment"], + "additionalProperties": False, + }, + "additionalProperties": False, + }, + } + ] + tool_choice = {"type": "function", "function": {"name": "structure_output"}} + optional_params = get_optional_params( + model="my-vllm-model", + custom_llm_provider="hosted_vllm", + temperature=0.2, + tools=tools, + tool_choice=tool_choice, + additional_drop_params=[ + ["tools", "function", "strict"], + ["tools", "function", "additionalProperties"], + ], + ) + print(optional_params["tools"][0]["function"]) + + assert "additionalProperties" not in optional_params["tools"][0]["function"] + assert "strict" not in optional_params["tools"][0]["function"] diff --git a/tests/local_testing/test_completion.py b/tests/local_testing/test_completion.py index b573a688b..3600bccab 100644 --- a/tests/local_testing/test_completion.py +++ b/tests/local_testing/test_completion.py @@ -1929,7 +1929,7 @@ def test_hf_test_completion_tgi(): # hf_test_completion_tgi() -@pytest.mark.parametrize("provider", ["vertex_ai_beta"]) # "vertex_ai", +@pytest.mark.parametrize("provider", ["openai", "hosted_vllm"]) # "vertex_ai", @pytest.mark.asyncio async def test_openai_compatible_custom_api_base(provider): litellm.set_verbose = True @@ -1947,15 +1947,15 @@ async def test_openai_compatible_custom_api_base(provider): openai_client.chat.completions, "create", new=MagicMock() ) as mock_call: try: - response = completion( - model="openai/my-vllm-model", + completion( + model="{provider}/my-vllm-model".format(provider=provider), messages=messages, response_format={"type": "json_object"}, client=openai_client, api_base="my-custom-api-base", hello="world", ) - except Exception as e: + except Exception: pass mock_call.assert_called_once() diff --git a/tests/local_testing/test_custom_llm.py b/tests/local_testing/test_custom_llm.py index a0f8b569e..c9edde4a8 100644 --- a/tests/local_testing/test_custom_llm.py +++ b/tests/local_testing/test_custom_llm.py @@ -42,8 +42,11 @@ from litellm import ( acompletion, completion, get_llm_provider, + image_generation, ) from litellm.utils import ModelResponseIterator +from litellm.types.utils import ImageResponse, ImageObject +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler class CustomModelResponseIterator: @@ -219,6 +222,38 @@ class MyCustomLLM(CustomLLM): yield generic_streaming_chunk # type: ignore + def image_generation( + self, + model: str, + prompt: str, + model_response: ImageResponse, + optional_params: dict, + logging_obj: Any, + timeout=None, + client: Optional[HTTPHandler] = None, + ): + return ImageResponse( + created=int(time.time()), + data=[ImageObject(url="https://example.com/image.png")], + response_ms=1000, + ) + + async def aimage_generation( + self, + model: str, + prompt: str, + model_response: ImageResponse, + optional_params: dict, + logging_obj: Any, + timeout=None, + client: Optional[AsyncHTTPHandler] = None, + ): + return ImageResponse( + created=int(time.time()), + data=[ImageObject(url="https://example.com/image.png")], + response_ms=1000, + ) + def test_get_llm_provider(): """""" @@ -300,3 +335,30 @@ async def test_simple_completion_async_streaming(): assert isinstance(chunk.choices[0].delta.content, str) else: assert chunk.choices[0].finish_reason == "stop" + + +def test_simple_image_generation(): + 
my_custom_llm = MyCustomLLM() + litellm.custom_provider_map = [ + {"provider": "custom_llm", "custom_handler": my_custom_llm} + ] + resp = image_generation( + model="custom_llm/my-fake-model", + prompt="Hello world", + ) + + print(resp) + + +@pytest.mark.asyncio +async def test_simple_image_generation_async(): + my_custom_llm = MyCustomLLM() + litellm.custom_provider_map = [ + {"provider": "custom_llm", "custom_handler": my_custom_llm} + ] + resp = await litellm.aimage_generation( + model="custom_llm/my-fake-model", + prompt="Hello world", + ) + + print(resp) diff --git a/tests/local_testing/test_streaming.py b/tests/local_testing/test_streaming.py index 593d41c17..d64134aa8 100644 --- a/tests/local_testing/test_streaming.py +++ b/tests/local_testing/test_streaming.py @@ -2156,7 +2156,13 @@ def test_openai_chat_completion_complete_response_call(): # test_openai_chat_completion_complete_response_call() @pytest.mark.parametrize( "model", - ["gpt-3.5-turbo", "azure/chatgpt-v-2", "claude-3-haiku-20240307", "o1-preview"], # + [ + "gpt-3.5-turbo", + "azure/chatgpt-v-2", + "claude-3-haiku-20240307", + "o1-preview", + "azure/fake-o1-mini", + ], ) @pytest.mark.parametrize( "sync", @@ -2164,6 +2170,7 @@ def test_openai_chat_completion_complete_response_call(): ) @pytest.mark.asyncio async def test_openai_stream_options_call(model, sync): + litellm.enable_preview_features = True litellm.set_verbose = True usage = None chunks = [] @@ -2175,7 +2182,6 @@ async def test_openai_stream_options_call(model, sync): ], stream=True, stream_options={"include_usage": True}, - max_tokens=10, ) for chunk in response: print("chunk: ", chunk) @@ -2186,7 +2192,6 @@ async def test_openai_stream_options_call(model, sync): messages=[{"role": "user", "content": "say GM - we're going to make it "}], stream=True, stream_options={"include_usage": True}, - max_tokens=10, ) async for chunk in response: diff --git a/tests/local_testing/test_text_completion.py b/tests/local_testing/test_text_completion.py index c4d2305fc..76d1dbb19 100644 --- a/tests/local_testing/test_text_completion.py +++ b/tests/local_testing/test_text_completion.py @@ -4223,7 +4223,8 @@ def mock_post(*args, **kwargs): return mock_response -def test_completion_vllm(): +@pytest.mark.parametrize("provider", ["openai", "hosted_vllm"]) +def test_completion_vllm(provider): """ Asserts a text completion call for vllm actually goes to the text completion endpoint """ @@ -4235,7 +4236,10 @@ def test_completion_vllm(): client.completions.with_raw_response, "create", side_effect=mock_post ) as mock_call: response = text_completion( - model="openai/gemini-1.5-flash", prompt="ping", client=client, hello="world" + model="{provider}/gemini-1.5-flash".format(provider=provider), + prompt="ping", + client=client, + hello="world", ) print("raw response", response)
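

---

A minimal usage sketch of the surfaces this patch adds (`hosted_vllm/` chat completions, custom image-generation handlers, and faked Azure o1 streaming), mirroring the added docs and tests. The vLLM model name, API base, custom provider name, and Azure deployment name are placeholders; a reachable vLLM server and Azure credentials are assumed where relevant.

```python
"""
Usage sketch for this patch, based on docs/my-website/docs/providers/vllm.md,
docs/my-website/docs/providers/custom_llm_server.md and
tests/local_testing/test_custom_llm.py. Names and endpoints are placeholders.
"""
import time
from typing import Any

import litellm
from litellm import CustomLLM, completion, image_generation
from litellm.types.utils import ImageObject, ImageResponse


class MyImageLLM(CustomLLM):
    # Implements the new CustomLLM.image_generation hook (sync variant).
    def image_generation(
        self,
        model: str,
        prompt: str,
        model_response: ImageResponse,
        optional_params: dict,
        logging_obj: Any,
        timeout=None,
        client=None,
    ) -> ImageResponse:
        return ImageResponse(
            created=int(time.time()),
            data=[ImageObject(url="https://example.com/image.png")],
        )


# 1. The hosted_vllm/ prefix routes chat completions to an OpenAI-compatible vLLM server.
response = completion(
    model="hosted_vllm/facebook/opt-125m",
    messages=[{"role": "user", "content": "Hello"}],
    api_base="https://hosted-vllm-api.co",  # or export HOSTED_VLLM_API_BASE
)

# 2. Custom handlers registered via custom_provider_map now serve image generation too.
litellm.custom_provider_map = [
    {"provider": "custom_llm", "custom_handler": MyImageLLM()}
]
img = image_generation(model="custom_llm/my-fake-model", prompt="A cute baby sea otter")

# 3. Azure o1 streaming is faked (the full response is wrapped as a mock stream)
#    behind the preview-features flag.
litellm.enable_preview_features = True
stream = completion(
    model="azure/o1-preview",  # placeholder Azure deployment name
    messages=[{"role": "user", "content": "Hi"}],
    stream=True,
)
for chunk in stream:
    print(chunk)
```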