From fbca0351bc1cb67f1e75391fd57f775a419c45d4 Mon Sep 17 00:00:00 2001 From: zeeland Date: Fri, 6 Dec 2024 18:22:01 +0800 Subject: [PATCH] refactor: add type annotations and overloads to completion functions - Add type annotations to 'acompletion' and 'completion' functions - Implement overloads for different stream types - Enhance type safety and improve code readability These changes improve the robustness of the code by ensuring proper type checking and making the function signatures more explicit. --- litellm/main.py | 173 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 173 insertions(+) diff --git a/litellm/main.py b/litellm/main.py index d33436a96f..d947b4ca4f 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -33,6 +33,7 @@ from typing import ( Type, Union, cast, + overload ) import dotenv @@ -298,6 +299,89 @@ class AsyncCompletions: return response +@overload +async def acompletion( + model: str, + # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create + stream: Literal[True] = True, + messages: List = [], + timeout: Optional[Union[float, int]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + n: Optional[int] = None, + stream_options: Optional[dict] = None, + stop=None, + max_tokens: Optional[int] = None, + max_completion_tokens: Optional[int] = None, + modalities: Optional[List[ChatCompletionModality]] = None, + prediction: Optional[ChatCompletionPredictionContentParam] = None, + audio: Optional[ChatCompletionAudioParam] = None, + presence_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[dict] = None, + user: Optional[str] = None, + # openai v1.0+ new params + response_format: Optional[Union[dict, Type[BaseModel]]] = None, + seed: Optional[int] = None, + tools: Optional[List] = None, + tool_choice: Optional[str] = None, + parallel_tool_calls: Optional[bool] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, + deployment_id=None, + # set api_base, api_version, api_key + base_url: Optional[str] = None, + api_version: Optional[str] = None, + api_key: Optional[str] = None, + model_list: Optional[list] = None, # pass in a list of api_base,keys, etc. + extra_headers: Optional[dict] = None, + # Optional liteLLM function params + **kwargs, +) -> CustomStreamWrapper: + ... + + +@overload +async def acompletion( + model: str, + # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create + stream: Literal[False] = False, + messages: List = [], + timeout: Optional[Union[float, int]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + n: Optional[int] = None, + stream_options: Optional[dict] = None, + stop=None, + max_tokens: Optional[int] = None, + max_completion_tokens: Optional[int] = None, + modalities: Optional[List[ChatCompletionModality]] = None, + prediction: Optional[ChatCompletionPredictionContentParam] = None, + audio: Optional[ChatCompletionAudioParam] = None, + presence_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[dict] = None, + user: Optional[str] = None, + # openai v1.0+ new params + response_format: Optional[Union[dict, Type[BaseModel]]] = None, + seed: Optional[int] = None, + tools: Optional[List] = None, + tool_choice: Optional[str] = None, + parallel_tool_calls: Optional[bool] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, + deployment_id=None, + # set api_base, api_version, api_key + base_url: Optional[str] = None, + api_version: Optional[str] = None, + api_key: Optional[str] = None, + model_list: Optional[list] = None, # pass in a list of api_base,keys, etc. + extra_headers: Optional[dict] = None, + # Optional liteLLM function params + **kwargs, +) -> ModelResponse: + ... + @client async def acompletion( model: str, @@ -690,6 +774,95 @@ def mock_completion( raise e raise Exception("Mock completion response failed") +@overload +def completion( # type: ignore # noqa: PLR0915 + model: str, + # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create + stream: Literal[True] = True, + messages: List = [], + timeout: Optional[Union[float, str, httpx.Timeout]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + n: Optional[int] = None, + stream_options: Optional[dict] = None, + stop=None, + max_completion_tokens: Optional[int] = None, + max_tokens: Optional[int] = None, + modalities: Optional[List[ChatCompletionModality]] = None, + prediction: Optional[ChatCompletionPredictionContentParam] = None, + audio: Optional[ChatCompletionAudioParam] = None, + presence_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[dict] = None, + user: Optional[str] = None, + # openai v1.0+ new params + response_format: Optional[Union[dict, Type[BaseModel]]] = None, + seed: Optional[int] = None, + tools: Optional[List] = None, + tool_choice: Optional[Union[str, dict]] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, + parallel_tool_calls: Optional[bool] = None, + deployment_id=None, + extra_headers: Optional[dict] = None, + # soon to be deprecated params by OpenAI + functions: Optional[List] = None, + function_call: Optional[str] = None, + # set api_base, api_version, api_key + base_url: Optional[str] = None, + api_version: Optional[str] = None, + api_key: Optional[str] = None, + model_list: Optional[list] = None, # pass in a list of api_base,keys, etc. + # Optional liteLLM function params + **kwargs, +) -> CustomStreamWrapper: + ... + + +@overload +def completion( # type: ignore # noqa: PLR0915 + model: str, + # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create + stream: Literal[False] = False, + messages: List = [], + timeout: Optional[Union[float, str, httpx.Timeout]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + n: Optional[int] = None, + stream_options: Optional[dict] = None, + stop=None, + max_completion_tokens: Optional[int] = None, + max_tokens: Optional[int] = None, + modalities: Optional[List[ChatCompletionModality]] = None, + prediction: Optional[ChatCompletionPredictionContentParam] = None, + audio: Optional[ChatCompletionAudioParam] = None, + presence_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[dict] = None, + user: Optional[str] = None, + # openai v1.0+ new params + response_format: Optional[Union[dict, Type[BaseModel]]] = None, + seed: Optional[int] = None, + tools: Optional[List] = None, + tool_choice: Optional[Union[str, dict]] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, + parallel_tool_calls: Optional[bool] = None, + deployment_id=None, + extra_headers: Optional[dict] = None, + # soon to be deprecated params by OpenAI + functions: Optional[List] = None, + function_call: Optional[str] = None, + # set api_base, api_version, api_key + base_url: Optional[str] = None, + api_version: Optional[str] = None, + api_key: Optional[str] = None, + model_list: Optional[list] = None, # pass in a list of api_base,keys, etc. + # Optional liteLLM function params + **kwargs, +) -> ModelResponse: + ... + @client def completion( # type: ignore # noqa: PLR0915