Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-24 10:14:26 +00:00
refactor: add type annotations and overloads to completion functions
- Add type annotations to 'acompletion' and 'completion' functions
- Implement overloads for different stream types
- Enhance type safety and improve code readability

These changes improve the robustness of the code by ensuring proper type checking and making the function signatures more explicit.
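The commit applies the standard typing.overload pattern of narrowing on a Literal-typed flag. As a minimal, self-contained sketch of the idea (the fetch function below is illustrative, not part of litellm):

from typing import Iterator, Literal, Union, overload

@overload
def fetch(stream: Literal[True]) -> Iterator[str]: ...
@overload
def fetch(stream: Literal[False] = ...) -> str: ...
def fetch(stream: bool = False) -> Union[str, Iterator[str]]:
    # Single runtime implementation; the overloads above exist only for
    # type checkers, which select the first signature matching the call.
    if stream:
        return iter(["chunk-1", "chunk-2"])
    return "full response"

A checker then infers Iterator[str] for fetch(stream=True) and str for fetch(), which is exactly the CustomStreamWrapper / ModelResponse split the diff below introduces for acompletion and completion.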
This commit is contained in:
parent 7e0680cf94
commit fbca0351bc

1 changed file with 173 additions and 0 deletions
litellm/main.py  +173
@@ -33,6 +33,7 @@ from typing import (
     Type,
     Union,
     cast,
+    overload
 )
 
 import dotenv
@@ -298,6 +299,89 @@ class AsyncCompletions:
         return response
 
 
+@overload
+async def acompletion(
+    model: str,
+    # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
+    stream: Literal[True] = True,
+    messages: List = [],
+    timeout: Optional[Union[float, int]] = None,
+    temperature: Optional[float] = None,
+    top_p: Optional[float] = None,
+    n: Optional[int] = None,
+    stream_options: Optional[dict] = None,
+    stop=None,
+    max_tokens: Optional[int] = None,
+    max_completion_tokens: Optional[int] = None,
+    modalities: Optional[List[ChatCompletionModality]] = None,
+    prediction: Optional[ChatCompletionPredictionContentParam] = None,
+    audio: Optional[ChatCompletionAudioParam] = None,
+    presence_penalty: Optional[float] = None,
+    frequency_penalty: Optional[float] = None,
+    logit_bias: Optional[dict] = None,
+    user: Optional[str] = None,
+    # openai v1.0+ new params
+    response_format: Optional[Union[dict, Type[BaseModel]]] = None,
+    seed: Optional[int] = None,
+    tools: Optional[List] = None,
+    tool_choice: Optional[str] = None,
+    parallel_tool_calls: Optional[bool] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
+    deployment_id=None,
+    # set api_base, api_version, api_key
+    base_url: Optional[str] = None,
+    api_version: Optional[str] = None,
+    api_key: Optional[str] = None,
+    model_list: Optional[list] = None,  # pass in a list of api_base,keys, etc.
+    extra_headers: Optional[dict] = None,
+    # Optional liteLLM function params
+    **kwargs,
+) -> CustomStreamWrapper:
+    ...
+
+
+@overload
+async def acompletion(
+    model: str,
+    # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
+    stream: Literal[False] = False,
+    messages: List = [],
+    timeout: Optional[Union[float, int]] = None,
+    temperature: Optional[float] = None,
+    top_p: Optional[float] = None,
+    n: Optional[int] = None,
+    stream_options: Optional[dict] = None,
+    stop=None,
+    max_tokens: Optional[int] = None,
+    max_completion_tokens: Optional[int] = None,
+    modalities: Optional[List[ChatCompletionModality]] = None,
+    prediction: Optional[ChatCompletionPredictionContentParam] = None,
+    audio: Optional[ChatCompletionAudioParam] = None,
+    presence_penalty: Optional[float] = None,
+    frequency_penalty: Optional[float] = None,
+    logit_bias: Optional[dict] = None,
+    user: Optional[str] = None,
+    # openai v1.0+ new params
+    response_format: Optional[Union[dict, Type[BaseModel]]] = None,
+    seed: Optional[int] = None,
+    tools: Optional[List] = None,
+    tool_choice: Optional[str] = None,
+    parallel_tool_calls: Optional[bool] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
+    deployment_id=None,
+    # set api_base, api_version, api_key
+    base_url: Optional[str] = None,
+    api_version: Optional[str] = None,
+    api_key: Optional[str] = None,
+    model_list: Optional[list] = None,  # pass in a list of api_base,keys, etc.
+    extra_headers: Optional[dict] = None,
+    # Optional liteLLM function params
+    **kwargs,
+) -> ModelResponse:
+    ...
+
 @client
 async def acompletion(
     model: str,
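Caller-side, the payoff of these overloads looks roughly like the sketch below, assuming litellm's public acompletion API and the usual async iteration over CustomStreamWrapper (model name and messages are placeholders):

import asyncio

from litellm import acompletion

async def main() -> None:
    # stream=True selects the Literal[True] overload, so the checker types
    # the result as CustomStreamWrapper and permits `async for`.
    stream_resp = await acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        stream=True,
    )
    async for chunk in stream_resp:
        print(chunk)

    # stream=False selects the Literal[False] overload -> ModelResponse.
    resp = await acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        stream=False,
    )
    print(resp)

asyncio.run(main())

One caveat worth noting: since both overloads give stream a default, a call that omits stream resolves to the first matching overload (the streaming one, under mypy's in-order resolution), so passing stream explicitly keeps the inferred return type unambiguous.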
@@ -690,6 +774,95 @@ def mock_completion(
         raise e
     raise Exception("Mock completion response failed")
+
+@overload
+def completion(  # type: ignore # noqa: PLR0915
+    model: str,
+    # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
+    stream: Literal[True] = True,
+    messages: List = [],
+    timeout: Optional[Union[float, str, httpx.Timeout]] = None,
+    temperature: Optional[float] = None,
+    top_p: Optional[float] = None,
+    n: Optional[int] = None,
+    stream_options: Optional[dict] = None,
+    stop=None,
+    max_completion_tokens: Optional[int] = None,
+    max_tokens: Optional[int] = None,
+    modalities: Optional[List[ChatCompletionModality]] = None,
+    prediction: Optional[ChatCompletionPredictionContentParam] = None,
+    audio: Optional[ChatCompletionAudioParam] = None,
+    presence_penalty: Optional[float] = None,
+    frequency_penalty: Optional[float] = None,
+    logit_bias: Optional[dict] = None,
+    user: Optional[str] = None,
+    # openai v1.0+ new params
+    response_format: Optional[Union[dict, Type[BaseModel]]] = None,
+    seed: Optional[int] = None,
+    tools: Optional[List] = None,
+    tool_choice: Optional[Union[str, dict]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
+    parallel_tool_calls: Optional[bool] = None,
+    deployment_id=None,
+    extra_headers: Optional[dict] = None,
+    # soon to be deprecated params by OpenAI
+    functions: Optional[List] = None,
+    function_call: Optional[str] = None,
+    # set api_base, api_version, api_key
+    base_url: Optional[str] = None,
+    api_version: Optional[str] = None,
+    api_key: Optional[str] = None,
+    model_list: Optional[list] = None,  # pass in a list of api_base,keys, etc.
+    # Optional liteLLM function params
+    **kwargs,
+) -> CustomStreamWrapper:
+    ...
+
+
+@overload
+def completion(  # type: ignore # noqa: PLR0915
+    model: str,
+    # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
+    stream: Literal[False] = False,
+    messages: List = [],
+    timeout: Optional[Union[float, str, httpx.Timeout]] = None,
+    temperature: Optional[float] = None,
+    top_p: Optional[float] = None,
+    n: Optional[int] = None,
+    stream_options: Optional[dict] = None,
+    stop=None,
+    max_completion_tokens: Optional[int] = None,
+    max_tokens: Optional[int] = None,
+    modalities: Optional[List[ChatCompletionModality]] = None,
+    prediction: Optional[ChatCompletionPredictionContentParam] = None,
+    audio: Optional[ChatCompletionAudioParam] = None,
+    presence_penalty: Optional[float] = None,
+    frequency_penalty: Optional[float] = None,
+    logit_bias: Optional[dict] = None,
+    user: Optional[str] = None,
+    # openai v1.0+ new params
+    response_format: Optional[Union[dict, Type[BaseModel]]] = None,
+    seed: Optional[int] = None,
+    tools: Optional[List] = None,
+    tool_choice: Optional[Union[str, dict]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
+    parallel_tool_calls: Optional[bool] = None,
+    deployment_id=None,
+    extra_headers: Optional[dict] = None,
+    # soon to be deprecated params by OpenAI
+    functions: Optional[List] = None,
+    function_call: Optional[str] = None,
+    # set api_base, api_version, api_key
+    base_url: Optional[str] = None,
+    api_version: Optional[str] = None,
+    api_key: Optional[str] = None,
+    model_list: Optional[list] = None,  # pass in a list of api_base,keys, etc.
+    # Optional liteLLM function params
+    **kwargs,
+) -> ModelResponse:
+    ...
 
 
 @client
 def completion(  # type: ignore # noqa: PLR0915
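The same narrowing can be checked statically with reveal_type, which mypy understands without any import; a sketch (guarded so it is a no-op at runtime, and note that actually executing the calls would require provider credentials):

from typing import TYPE_CHECKING

from litellm import completion

streaming = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    stream=True,
)
non_streaming = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    stream=False,
)

if TYPE_CHECKING:
    # Running `mypy` on this file reports the overload-selected types:
    reveal_type(streaming)      # expected: CustomStreamWrapper
    reveal_type(non_streaming)  # expected: ModelResponse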