LiteLLM Minor Fixes & Improvements (10/07/2024) (#6101)
* fix(utils.py): support dropping the temperature param for Azure o1 models
* fix(main.py): handle Azure o1 streaming requests. o1 doesn't support streaming, so it is faked to keep callers working as expected
* feat(utils.py): expose the `hosted_vllm/` endpoint, with tool handling for vLLM. Fixes https://github.com/BerriAI/litellm/issues/6088
* refactor(internal_user_endpoints.py): clean up unused params + update docstring. Closes https://github.com/BerriAI/litellm/issues/6100
* fix(main.py): expose custom image generation API support. Fixes https://github.com/BerriAI/litellm/issues/6097
* fix: fix linting errors
* docs(custom_llm_server.md): add docs on the custom API for image generation calls
* fix(types/utils.py): handle dict type
* fix(types/utils.py): fix linting errors
This commit is contained in: parent 5de69cb1b2, commit 6729c9ca7f

17 changed files with 643 additions and 76 deletions
custom_llm_server.md

@@ -183,11 +183,80 @@ class UnixTimeLLM(CustomLLM):
unixtime = UnixTimeLLM()
```

## Image Generation

1. Setup your `custom_handler.py` file

```python
import time
from typing import Any, Optional, Union

import httpx

import litellm
from litellm import CustomLLM
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.types.utils import ImageResponse, ImageObject


class MyCustomLLM(CustomLLM):
    async def aimage_generation(
        self,
        model: str,
        prompt: str,
        model_response: ImageResponse,
        optional_params: dict,
        logging_obj: Any,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        client: Optional[AsyncHTTPHandler] = None,
    ) -> ImageResponse:
        return ImageResponse(
            created=int(time.time()),
            data=[ImageObject(url="https://example.com/image.png")],
        )


my_custom_llm = MyCustomLLM()
```

2. Add to `config.yaml`

In the config below, we pass:

* python_filename: `custom_handler.py`
* custom_handler_instance_name: `my_custom_llm`. This is defined in Step 1.
* custom_handler: `custom_handler.my_custom_llm`

```yaml
model_list:
  - model_name: "test-model"
    litellm_params:
      model: "openai/text-embedding-ada-002"
  - model_name: "my-custom-model"
    litellm_params:
      model: "my-custom-llm/my-model"

litellm_settings:
  custom_provider_map:
    - {"provider": "my-custom-llm", "custom_handler": custom_handler.my_custom_llm}
```

```bash
litellm --config /path/to/config.yaml
```

3. Test it!

```bash
curl -X POST 'http://0.0.0.0:4000/v1/images/generations' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
    "model": "my-custom-model",
    "prompt": "A cute baby sea otter"
}'
```

Expected Response

```
{
    "created": 1721955063,
    "data": [{"url": "https://example.com/image.png"}]
}
```

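For reference, the same image generation route can also be exercised from Python with the OpenAI SDK pointed at the proxy. This is a minimal sketch, assuming the proxy from step 2 is running locally on port 4000 with the `sk-1234` key used above:

```python
from openai import OpenAI

# Point the OpenAI SDK at the LiteLLM proxy (values mirror the curl example above).
client = OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

image = client.images.generate(
    model="my-custom-model",  # routed to my_custom_llm via custom_provider_map
    prompt="A cute baby sea otter",
)
print(image.data[0].url)  # https://example.com/image.png from MyCustomLLM
```
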
## Custom Handler Spec

```python
from litellm.types.utils import GenericStreamingChunk, ModelResponse, ImageResponse
from typing import Iterator, AsyncIterator, Any, Optional, Union
from litellm.llms.base import BaseLLM

class CustomLLMError(Exception):  # use this for all your exceptions

@@ -217,4 +286,28 @@ class CustomLLM(BaseLLM):

    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
        raise CustomLLMError(status_code=500, message="Not implemented yet!")

    def image_generation(
        self,
        model: str,
        prompt: str,
        model_response: ImageResponse,
        optional_params: dict,
        logging_obj: Any,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        client: Optional[HTTPHandler] = None,
    ) -> ImageResponse:
        raise CustomLLMError(status_code=500, message="Not implemented yet!")

    async def aimage_generation(
        self,
        model: str,
        prompt: str,
        model_response: ImageResponse,
        optional_params: dict,
        logging_obj: Any,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        client: Optional[AsyncHTTPHandler] = None,
    ) -> ImageResponse:
        raise CustomLLMError(status_code=500, message="Not implemented yet!")
```

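The handler can also be registered purely in Python, without the proxy, by setting `litellm.custom_provider_map` directly; this is a sketch that mirrors the new `test_simple_image_generation` test further down, with an illustrative provider name:

```python
import litellm
from litellm import image_generation

from custom_handler import my_custom_llm  # the instance from Step 1 above

# Register the handler under a provider name ("my-custom-llm" is illustrative),
# then call it through the normal litellm.image_generation() entrypoint.
litellm.custom_provider_map = [
    {"provider": "my-custom-llm", "custom_handler": my_custom_llm}
]

resp = image_generation(
    model="my-custom-llm/my-fake-model",
    prompt="Hello world",
)
print(resp.data[0].url)
```
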
@@ -12,14 +12,14 @@ vLLM provides OpenAI-compatible endpoints - here's how to call them with LiteLLM

In order to use litellm to call a hosted vllm server, add the following to your completion call:

* `model="hosted_vllm/<your-vllm-model-name>"`
* `api_base = "your-hosted-vllm-server"`

```python
import litellm

response = litellm.completion(
    model="hosted_vllm/facebook/opt-125m",  # pass the vllm model name
    messages=messages,
    api_base="https://hosted-vllm-api.co",
    temperature=0.2,

@@ -39,7 +39,7 @@ Here's how to call an OpenAI-Compatible Endpoint with the LiteLLM Proxy Server

model_list:
  - model_name: my-model
    litellm_params:
      model: hosted_vllm/facebook/opt-125m      # add hosted_vllm/ prefix to route as OpenAI provider
      api_base: https://hosted-vllm-api.co      # add api base for OpenAI compatible provider
```

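The new `hosted_vllm` provider also reads its connection details from the environment (see the `get_llm_provider` change below). A minimal sketch, reusing the example server URL from the docs above:

```python
import os

import litellm

# Illustrative values: point LiteLLM at your own vLLM server.
os.environ["HOSTED_VLLM_API_BASE"] = "https://hosted-vllm-api.co"
os.environ["HOSTED_VLLM_API_KEY"] = ""  # vLLM does not require an API key

response = litellm.completion(
    model="hosted_vllm/facebook/opt-125m",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)
```
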
@@ -504,11 +504,13 @@ openai_compatible_providers: List = [
    "azure_ai",
    "github",
    "litellm_proxy",
    "hosted_vllm",
]
openai_text_completion_compatible_providers: List = (
    [  # providers that support `/v1/completions`
        "together_ai",
        "fireworks_ai",
        "hosted_vllm",
    ]
)

@@ -758,6 +760,7 @@ class LlmProviders(str, Enum):
    GITHUB = "github"
    CUSTOM = "custom"
    LITELLM_PROXY = "litellm_proxy"
    HOSTED_VLLM = "hosted_vllm"


provider_list: List[Union[LlmProviders, str]] = list(LlmProviders)

@@ -1003,6 +1006,8 @@ from .llms.AzureOpenAI.azure import (
    AzureOpenAIError,
    AzureOpenAIAssistantsAPIConfig,
)
from .llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig
from .llms.AzureOpenAI.chat.o1_transformation import AzureOpenAIO1Config
from .llms.watsonx import IBMWatsonXAIConfig
from .main import *  # type: ignore
from .integrations import *

@@ -206,6 +206,14 @@ def get_llm_provider(
                or "https://codestral.mistral.ai/v1"
            )  # type: ignore
            dynamic_api_key = api_key or get_secret("CODESTRAL_API_KEY")
        elif custom_llm_provider == "hosted_vllm":
            # vllm is openai compatible, we just need to set this to custom_openai
            api_base = api_base or get_secret(
                "HOSTED_VLLM_API_BASE"
            )  # type: ignore
            dynamic_api_key = (
                api_key or get_secret("HOSTED_VLLM_API_KEY") or ""
            )  # vllm does not require an api key
        elif custom_llm_provider == "deepseek":
            # deepseek is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.deepseek.com/v1
            api_base = (

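A quick sketch of how the provider resolution above behaves; the model name is the example one used throughout, and the four-tuple return shape is the existing `get_llm_provider` contract:

```python
from litellm import get_llm_provider

model, provider, dynamic_api_key, api_base = get_llm_provider(
    model="hosted_vllm/facebook/opt-125m",
    api_base="https://hosted-vllm-api.co",  # or set HOSTED_VLLM_API_BASE instead
)
# provider == "hosted_vllm"; dynamic_api_key falls back to "" because
# vLLM does not require an API key.
```
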
litellm/llms/AzureOpenAI/chat/o1_handler.py (new file, 97 lines)

@@ -0,0 +1,97 @@
"""
Handler file for calls to Azure OpenAI's o1 family of models

Written separately to handle faking streaming for o1 models.
"""

import asyncio
from typing import Any, Callable, List, Optional, Union

from httpx._config import Timeout

from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
from litellm.types.utils import ModelResponse
from litellm.utils import CustomStreamWrapper

from ..azure import AzureChatCompletion


class AzureOpenAIO1ChatCompletion(AzureChatCompletion):

    async def mock_async_streaming(
        self,
        response: Any,
        model: Optional[str],
        logging_obj: Any,
    ):
        # Await the full (non-streaming) response, then wrap it in a mock
        # iterator so callers can consume it like a real stream.
        model_response = await response
        completion_stream = MockResponseIterator(model_response=model_response)
        streaming_response = CustomStreamWrapper(
            completion_stream=completion_stream,
            model=model,
            custom_llm_provider="azure",
            logging_obj=logging_obj,
        )
        return streaming_response

    def completion(
        self,
        model: str,
        messages: List,
        model_response: ModelResponse,
        api_key: str,
        api_base: str,
        api_version: str,
        api_type: str,
        azure_ad_token: str,
        dynamic_params: bool,
        print_verbose: Callable[..., Any],
        timeout: Union[float, Timeout],
        logging_obj: Logging,
        optional_params,
        litellm_params,
        logger_fn,
        acompletion: bool = False,
        headers: Optional[dict] = None,
        client=None,
    ):
        # o1 doesn't support streaming, so pop the flag and fake the stream
        # from the completed response below.
        stream: Optional[bool] = optional_params.pop("stream", False)
        response = super().completion(
            model,
            messages,
            model_response,
            api_key,
            api_base,
            api_version,
            api_type,
            azure_ad_token,
            dynamic_params,
            print_verbose,
            timeout,
            logging_obj,
            optional_params,
            litellm_params,
            logger_fn,
            acompletion,
            headers,
            client,
        )

        if stream is True:
            if asyncio.iscoroutine(response):
                return self.mock_async_streaming(
                    response=response, model=model, logging_obj=logging_obj  # type: ignore
                )

            completion_stream = MockResponseIterator(model_response=response)
            streaming_response = CustomStreamWrapper(
                completion_stream=completion_stream,
                model=model,
                custom_llm_provider="openai",
                logging_obj=logging_obj,
            )

            return streaming_response
        else:
            return response

litellm/llms/AzureOpenAI/chat/o1_transformation.py (new file, 30 lines)

@@ -0,0 +1,30 @@
"""
Support for the o1 model family

https://platform.openai.com/docs/guides/reasoning

Translations handled by LiteLLM:
- modalities: image => drop param (if user opts in to dropping param)
- role: system ==> translate to role 'user'
- streaming => faked by LiteLLM
- Tools, response_format => drop param (if user opts in to dropping param)
- Logprobs => drop param (if user opts in to dropping param)
- Temperature => drop param (if user opts in to dropping param)
"""

import types
from typing import Any, List, Optional, Union

import litellm
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage

from ...OpenAI.chat.o1_transformation import OpenAIO1Config


class AzureOpenAIO1Config(OpenAIO1Config):
    def is_o1_model(self, model: str) -> bool:
        o1_models = ["o1-mini", "o1-preview"]
        for m in o1_models:
            if m in model:
                return True
        return False

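A sketch of how the o1 handling surfaces to a caller, assuming an Azure `o1-mini` deployment and the usual `AZURE_API_KEY` / `AZURE_API_BASE` / `AZURE_API_VERSION` environment variables (all illustrative values):

```python
import litellm

litellm.enable_preview_features = True  # route Azure o1 calls through the new handler

response = litellm.completion(
    model="azure/o1-mini",
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
    temperature=0.2,  # unsupported by o1; dropped because drop_params=True
    stream=True,      # o1 has no native streaming; LiteLLM fakes the stream
    drop_params=True,
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")
```
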
@@ -36,7 +36,13 @@ import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.utils import GenericStreamingChunk, ProviderField
from litellm.utils import (
    CustomStreamWrapper,
    EmbeddingResponse,
    ImageResponse,
    ModelResponse,
    Usage,
)

from .base import BaseLLM
from .prompt_templates.factory import custom_prompt, prompt_factory

@@ -143,6 +149,30 @@ class CustomLLM(BaseLLM):
    ) -> AsyncIterator[GenericStreamingChunk]:
        raise CustomLLMError(status_code=500, message="Not implemented yet!")

    def image_generation(
        self,
        model: str,
        prompt: str,
        model_response: ImageResponse,
        optional_params: dict,
        logging_obj: Any,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        client: Optional[HTTPHandler] = None,
    ) -> ImageResponse:
        raise CustomLLMError(status_code=500, message="Not implemented yet!")

    async def aimage_generation(
        self,
        model: str,
        prompt: str,
        model_response: ImageResponse,
        optional_params: dict,
        logging_obj: Any,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        client: Optional[AsyncHTTPHandler] = None,
    ) -> ImageResponse:
        raise CustomLLMError(status_code=500, message="Not implemented yet!")


def custom_chat_llm_router(
    async_fn: bool, stream: Optional[bool], custom_llm: CustomLLM

litellm/llms/hosted_vllm/chat/transformation.py (new file, 34 lines)

@@ -0,0 +1,34 @@
"""
Translate from OpenAI's `/v1/chat/completions` to VLLM's `/v1/chat/completions`
"""

import types
from typing import List, Optional, Union

from pydantic import BaseModel

import litellm
from litellm.types.llms.openai import AllMessageValues, ChatCompletionAssistantMessage

from ....utils import _remove_additional_properties, _remove_strict_from_schema
from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig


class HostedVLLMChatConfig(OpenAIGPTConfig):
    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        _tools = non_default_params.pop("tools", None)
        if _tools is not None:
            # remove 'additionalProperties' from tools
            _tools = _remove_additional_properties(_tools)
            # remove 'strict' from tools
            _tools = _remove_strict_from_schema(_tools)
            non_default_params["tools"] = _tools
        return super().map_openai_params(
            non_default_params, optional_params, model, drop_params
        )

litellm/main.py

@@ -42,6 +42,7 @@ from litellm import (  # type: ignore
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.secret_managers.main import get_secret_str
from litellm.utils import (
    CustomStreamWrapper,

@@ -89,6 +90,7 @@ from .llms.azure_ai.embed import AzureAIEmbedding
from .llms.azure_text import AzureTextCompletion
from .llms.AzureOpenAI.audio_transcriptions import AzureAudioTranscription
from .llms.AzureOpenAI.azure import AzureChatCompletion, _check_dynamic_azure_params
from .llms.AzureOpenAI.chat.o1_handler import AzureOpenAIO1ChatCompletion
from .llms.bedrock import image_generation as bedrock_image_generation  # type: ignore
from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
from .llms.bedrock.embed.embedding import BedrockEmbedding

@@ -178,6 +180,7 @@ azure_ai_embedding = AzureAIEmbedding()
anthropic_chat_completions = AnthropicChatCompletion()
anthropic_text_completions = AnthropicTextCompletion()
azure_chat_completions = AzureChatCompletion()
azure_o1_chat_completions = AzureOpenAIO1ChatCompletion()
azure_text_completions = AzureTextCompletion()
azure_audio_transcriptions = AzureAudioTranscription()
huggingface = Huggingface()

@@ -1064,35 +1067,68 @@ def completion(  # type: ignore

        headers = headers or litellm.headers

        if (
            litellm.enable_preview_features
            and litellm.AzureOpenAIO1Config().is_o1_model(model=model)
        ):
            ## LOAD CONFIG - if set
            config = litellm.AzureOpenAIO1Config.get_config()
            for k, v in config.items():
                if (
                    k not in optional_params
                ):  # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in
                    optional_params[k] = v

            ## COMPLETION CALL
            response = azure_o1_chat_completions.completion(
                model=model,
                messages=messages,
                headers=headers,
                api_key=api_key,
                api_base=api_base,
                api_version=api_version,
                api_type=api_type,
                dynamic_params=dynamic_params,
                azure_ad_token=azure_ad_token,
                model_response=model_response,
                print_verbose=print_verbose,
                optional_params=optional_params,
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                logging_obj=logging,
                acompletion=acompletion,
                timeout=timeout,  # type: ignore
                client=client,  # pass AsyncAzureOpenAI, AzureOpenAI client
            )
        else:
            ## LOAD CONFIG - if set
            config = litellm.AzureOpenAIConfig.get_config()
            for k, v in config.items():
                if (
                    k not in optional_params
                ):  # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in
                    optional_params[k] = v

            ## COMPLETION CALL
            response = azure_chat_completions.completion(
                model=model,
                messages=messages,
                headers=headers,
                api_key=api_key,
                api_base=api_base,
                api_version=api_version,
                api_type=api_type,
                dynamic_params=dynamic_params,
                azure_ad_token=azure_ad_token,
                model_response=model_response,
                print_verbose=print_verbose,
                optional_params=optional_params,
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                logging_obj=logging,
                acompletion=acompletion,
                timeout=timeout,  # type: ignore
                client=client,  # pass AsyncAzureOpenAI, AzureOpenAI client
            )

        if optional_params.get("stream", False):
            ## LOGGING

@@ -4582,6 +4618,7 @@ def image_generation(
    Currently supports just Azure + OpenAI.
    """
    try:
        args = locals()
        aimg_generation = kwargs.get("aimg_generation", False)
        litellm_call_id = kwargs.get("litellm_call_id", None)
        logger_fn = kwargs.get("logger_fn", None)

@@ -4787,6 +4824,51 @@ def image_generation(
                vertex_credentials=vertex_credentials,
                aimg_generation=aimg_generation,
            )
        elif (
            custom_llm_provider in litellm._custom_providers
        ):  # Assume custom LLM provider
            # Get the Custom Handler
            custom_handler: Optional[CustomLLM] = None
            for item in litellm.custom_provider_map:
                if item["provider"] == custom_llm_provider:
                    custom_handler = item["custom_handler"]

            if custom_handler is None:
                raise ValueError(
                    f"Unable to map your input to a model. Check your input - {args}"
                )

            ## ROUTE LLM CALL ##
            if aimg_generation is True:
                async_custom_client: Optional[AsyncHTTPHandler] = None
                if client is not None and isinstance(client, AsyncHTTPHandler):
                    async_custom_client = client

                ## CALL FUNCTION
                model_response = custom_handler.aimage_generation(  # type: ignore
                    model=model,
                    prompt=prompt,
                    model_response=model_response,
                    optional_params=optional_params,
                    logging_obj=litellm_logging_obj,
                    timeout=timeout,
                    client=async_custom_client,
                )
            else:
                custom_client: Optional[HTTPHandler] = None
                if client is not None and isinstance(client, HTTPHandler):
                    custom_client = client

                ## CALL FUNCTION
                model_response = custom_handler.image_generation(
                    model=model,
                    prompt=prompt,
                    model_response=model_response,
                    optional_params=optional_params,
                    logging_obj=litellm_logging_obj,
                    timeout=timeout,
                    client=custom_client,
                )

        return model_response
    except Exception as e:

internal_user_endpoints.py

@@ -299,17 +299,13 @@ async def user_info(
    user_id: Optional[str] = fastapi.Query(
        default=None, description="User ID in the request parameters"
    ),
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [10/07/2024]
    Note: To get all users (+pagination), use the `/user/list` endpoint.

    Use this to get user information. (user row + all user key info)

    Example request

litellm/types/utils.py

@@ -1018,7 +1018,10 @@ class TextCompletionResponse(OpenAIObject):
            setattr(self, key, value)


from openai.types.images_response import Image as OpenAIImage


class ImageObject(OpenAIImage):
    """
    Represents the url or the content of an image generated by the OpenAI API.

@@ -1070,7 +1073,7 @@ class ImageResponse(OpenAIImageResponse):
    def __init__(
        self,
        created: Optional[int] = None,
        data: Optional[List[ImageObject]] = None,
        response_ms=None,
    ):
        if response_ms:

@@ -1087,7 +1090,13 @@ class ImageResponse(OpenAIImageResponse):
        else:
            created = int(time.time())

        _data: List[OpenAIImage] = []
        for d in data:
            if isinstance(d, dict):
                _data.append(ImageObject(**d))
            elif isinstance(d, BaseModel):
                _data.append(ImageObject(**d.model_dump()))
        super().__init__(created=created, data=_data)
        self.usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}

    def __contains__(self, key):

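A short sketch of what the new `data` handling allows: plain dicts and `ImageObject` models are both coerced into `ImageObject` entries (values are illustrative):

```python
from litellm.types.utils import ImageObject, ImageResponse

resp = ImageResponse(
    created=1721955063,
    data=[
        {"url": "https://example.com/image.png"},           # dict input
        ImageObject(url="https://example.com/image2.png"),  # pydantic input
    ],
)
print(resp.data[0].url)
```
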
litellm/utils.py

@@ -2788,6 +2788,24 @@ def _remove_additional_properties(schema):
    return schema


def _remove_strict_from_schema(schema):
    if isinstance(schema, dict):
        # Remove the 'strict' key if it exists
        if "strict" in schema:
            del schema["strict"]

        # Recursively process all dictionary values
        for key, value in schema.items():
            _remove_strict_from_schema(value)

    elif isinstance(schema, list):
        # Recursively process all items in the list
        for item in schema:
            _remove_strict_from_schema(item)

    return schema


def get_optional_params(
    # use the openai defaults
    # https://platform.openai.com/docs/api-reference/chat/create

@@ -2999,13 +3017,19 @@ def get_optional_params(
            from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import (
                _build_vertex_schema,
            )

            old_schema = copy.deepcopy(
                non_default_params["response_format"]
                .get("json_schema", {})
                .get("schema")
            )
            new_schema = _remove_additional_properties(schema=old_schema)
            if isinstance(new_schema, list):
                for item in new_schema:
                    if isinstance(item, dict):
                        item = _build_vertex_schema(parameters=item)
            elif isinstance(new_schema, dict):
                new_schema = _build_vertex_schema(parameters=new_schema)
            non_default_params["response_format"]["json_schema"]["schema"] = new_schema
        if "tools" in non_default_params and isinstance(
            non_default_params, list

@@ -3767,6 +3791,21 @@ def get_optional_params(
            optional_params=optional_params,
            model=model,
        )
    elif custom_llm_provider == "hosted_vllm":
        supported_params = get_supported_openai_params(
            model=model, custom_llm_provider=custom_llm_provider
        )
        _check_valid_arg(supported_params=supported_params)
        optional_params = litellm.HostedVLLMChatConfig().map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=(
                drop_params
                if drop_params is not None and isinstance(drop_params, bool)
                else False
            ),
        )

    elif custom_llm_provider == "groq":
        supported_params = get_supported_openai_params(

@@ -3926,24 +3965,36 @@ def get_optional_params(
            model=model, custom_llm_provider="azure"
        )
        _check_valid_arg(supported_params=supported_params)
        if litellm.AzureOpenAIO1Config().is_o1_model(model=model):
            optional_params = litellm.AzureOpenAIO1Config().map_openai_params(
                non_default_params=non_default_params,
                optional_params=optional_params,
                model=model,
                drop_params=(
                    drop_params
                    if drop_params is not None and isinstance(drop_params, bool)
                    else False
                ),
            )
        else:
            verbose_logger.debug(
                "Azure optional params - api_version: api_version={}, litellm.api_version={}, os.environ['AZURE_API_VERSION']={}".format(
                    api_version, litellm.api_version, get_secret("AZURE_API_VERSION")
                )
            )
            api_version = (
                api_version
                or litellm.api_version
                or get_secret("AZURE_API_VERSION")
                or litellm.AZURE_DEFAULT_API_VERSION
            )
            optional_params = litellm.AzureOpenAIConfig().map_openai_params(
                non_default_params=non_default_params,
                optional_params=optional_params,
                model=model,
                api_version=api_version,  # type: ignore
                drop_params=drop_params,
            )
    else:  # assume passing in params for text-completion openai
        supported_params = get_supported_openai_params(
            model=model, custom_llm_provider="custom_openai"

@@ -4409,6 +4460,8 @@ def get_supported_openai_params(
            "extra_headers",
            "extra_body",
        ]
    elif custom_llm_provider == "hosted_vllm":
        return litellm.HostedVLLMChatConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "deepseek":
        return [
            # https://platform.deepseek.com/api-docs/api/create-chat-completion

@@ -4465,7 +4518,12 @@ def get_supported_openai_params(
    elif custom_llm_provider == "openai":
        return litellm.OpenAIConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "azure":
        if litellm.AzureOpenAIO1Config().is_o1_model(model=model):
            return litellm.AzureOpenAIO1Config().get_supported_openai_params(
                model=model
            )
        else:
            return litellm.AzureOpenAIConfig().get_supported_openai_params()
    elif custom_llm_provider == "openrouter":
        return [
            "temperature",

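To make the `hosted_vllm` tool handling concrete, here is a small sketch of the intended behavior of `HostedVLLMChatConfig.map_openai_params` (it mirrors the new `test_drop_nested_params_vllm` test below; the tool definition is illustrative):

```python
from litellm.llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig

tools = [
    {
        "type": "function",
        "function": {
            "name": "structure_output",
            "strict": True,  # stripped by _remove_strict_from_schema
            "parameters": {
                "type": "object",
                "properties": {"sentiment": {"type": "string"}},
                "additionalProperties": False,  # stripped by _remove_additional_properties
            },
        },
    }
]

optional_params = HostedVLLMChatConfig().map_openai_params(
    non_default_params={"tools": tools, "temperature": 0.2},
    optional_params={},
    model="facebook/opt-125m",
    drop_params=False,
)
# optional_params["tools"][0]["function"] no longer contains "strict", and its
# "parameters" no longer contain "additionalProperties".
```
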
@@ -620,16 +620,28 @@ def test_o1_model_params():
    assert optional_params["user"] == "John"


def test_azure_o1_model_params():
    optional_params = get_optional_params(
        model="o1-preview",
        custom_llm_provider="azure",
        seed=10,
        user="John",
    )
    assert optional_params["seed"] == 10
    assert optional_params["user"] == "John"


@pytest.mark.parametrize(
    "temperature, expected_error",
    [(0.2, True), (1, False)],
)
@pytest.mark.parametrize("provider", ["openai", "azure"])
def test_o1_model_temperature_params(provider, temperature, expected_error):
    if expected_error:
        with pytest.raises(litellm.UnsupportedParamsError):
            get_optional_params(
                model="o1-preview",
                custom_llm_provider=provider,
                temperature=temperature,
            )
    else:

@@ -650,3 +662,45 @@ def test_unmapped_gemini_model_params():
        stop="stop_word",
    )
    assert optional_params["stop_sequences"] == ["stop_word"]


def test_drop_nested_params_vllm():
    """
    Relevant issue - https://github.com/BerriAI/litellm/issues/5288
    """
    tools = [
        {
            "type": "function",
            "function": {
                "name": "structure_output",
                "description": "Send structured output back to the user",
                "strict": True,
                "parameters": {
                    "type": "object",
                    "properties": {
                        "reasoning": {"type": "string"},
                        "sentiment": {"type": "string"},
                    },
                    "required": ["reasoning", "sentiment"],
                    "additionalProperties": False,
                },
                "additionalProperties": False,
            },
        }
    ]
    tool_choice = {"type": "function", "function": {"name": "structure_output"}}
    optional_params = get_optional_params(
        model="my-vllm-model",
        custom_llm_provider="hosted_vllm",
        temperature=0.2,
        tools=tools,
        tool_choice=tool_choice,
        additional_drop_params=[
            ["tools", "function", "strict"],
            ["tools", "function", "additionalProperties"],
        ],
    )
    print(optional_params["tools"][0]["function"])

    assert "additionalProperties" not in optional_params["tools"][0]["function"]
    assert "strict" not in optional_params["tools"][0]["function"]

@@ -1929,7 +1929,7 @@ def test_hf_test_completion_tgi():
# hf_test_completion_tgi()


@pytest.mark.parametrize("provider", ["openai", "hosted_vllm"])  # "vertex_ai",
@pytest.mark.asyncio
async def test_openai_compatible_custom_api_base(provider):
    litellm.set_verbose = True

@@ -1947,15 +1947,15 @@ async def test_openai_compatible_custom_api_base(provider):
        openai_client.chat.completions, "create", new=MagicMock()
    ) as mock_call:
        try:
            completion(
                model="{provider}/my-vllm-model".format(provider=provider),
                messages=messages,
                response_format={"type": "json_object"},
                client=openai_client,
                api_base="my-custom-api-base",
                hello="world",
            )
        except Exception:
            pass

        mock_call.assert_called_once()

@@ -42,8 +42,11 @@ from litellm import (
    acompletion,
    completion,
    get_llm_provider,
    image_generation,
)
from litellm.utils import ModelResponseIterator
from litellm.types.utils import ImageResponse, ImageObject
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler


class CustomModelResponseIterator:

@@ -219,6 +222,38 @@ class MyCustomLLM(CustomLLM):

        yield generic_streaming_chunk  # type: ignore

    def image_generation(
        self,
        model: str,
        prompt: str,
        model_response: ImageResponse,
        optional_params: dict,
        logging_obj: Any,
        timeout=None,
        client: Optional[HTTPHandler] = None,
    ):
        return ImageResponse(
            created=int(time.time()),
            data=[ImageObject(url="https://example.com/image.png")],
            response_ms=1000,
        )

    async def aimage_generation(
        self,
        model: str,
        prompt: str,
        model_response: ImageResponse,
        optional_params: dict,
        logging_obj: Any,
        timeout=None,
        client: Optional[AsyncHTTPHandler] = None,
    ):
        return ImageResponse(
            created=int(time.time()),
            data=[ImageObject(url="https://example.com/image.png")],
            response_ms=1000,
        )


def test_get_llm_provider():
    """"""

@@ -300,3 +335,30 @@ async def test_simple_completion_async_streaming():
            assert isinstance(chunk.choices[0].delta.content, str)
        else:
            assert chunk.choices[0].finish_reason == "stop"


def test_simple_image_generation():
    my_custom_llm = MyCustomLLM()
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": my_custom_llm}
    ]
    resp = image_generation(
        model="custom_llm/my-fake-model",
        prompt="Hello world",
    )

    print(resp)


@pytest.mark.asyncio
async def test_simple_image_generation_async():
    my_custom_llm = MyCustomLLM()
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": my_custom_llm}
    ]
    resp = await litellm.aimage_generation(
        model="custom_llm/my-fake-model",
        prompt="Hello world",
    )

    print(resp)

@@ -2156,7 +2156,13 @@ def test_openai_chat_completion_complete_response_call():
# test_openai_chat_completion_complete_response_call()
@pytest.mark.parametrize(
    "model",
    [
        "gpt-3.5-turbo",
        "azure/chatgpt-v-2",
        "claude-3-haiku-20240307",
        "o1-preview",
        "azure/fake-o1-mini",
    ],
)
@pytest.mark.parametrize(
    "sync",

@@ -2164,6 +2170,7 @@ def test_openai_chat_completion_complete_response_call():
)
@pytest.mark.asyncio
async def test_openai_stream_options_call(model, sync):
    litellm.enable_preview_features = True
    litellm.set_verbose = True
    usage = None
    chunks = []

@@ -2175,7 +2182,6 @@ async def test_openai_stream_options_call(model, sync):
            ],
            stream=True,
            stream_options={"include_usage": True},
        )
        for chunk in response:
            print("chunk: ", chunk)

@@ -2186,7 +2192,6 @@ async def test_openai_stream_options_call(model, sync):
            messages=[{"role": "user", "content": "say GM - we're going to make it "}],
            stream=True,
            stream_options={"include_usage": True},
        )

        async for chunk in response:

@@ -4223,7 +4223,8 @@ def mock_post(*args, **kwargs):
    return mock_response


@pytest.mark.parametrize("provider", ["openai", "hosted_vllm"])
def test_completion_vllm(provider):
    """
    Asserts a text completion call for vllm actually goes to the text completion endpoint
    """

@@ -4235,7 +4236,10 @@
        client.completions.with_raw_response, "create", side_effect=mock_post
    ) as mock_call:
        response = text_completion(
            model="{provider}/gemini-1.5-flash".format(provider=provider),
            prompt="ping",
            client=client,
            hello="world",
        )
        print("raw response", response)
