Add pyright to ci/cd + Fix remaining type-checking errors (#6082)

* fix: fix type-checking errors

* fix: fix additional type-checking errors

* fix: additional type-checking error fixes

* fix: fix additional type-checking errors

* fix: additional type-check fixes

* fix: fix all type-checking errors + add pyright to ci/cd

* fix: fix incorrect import

* ci(config.yml): use mypy on ci/cd

* fix: fix type-checking errors in utils.py

* fix: fix all type-checking errors on main.py

* fix: fix mypy linting errors

* fix(anthropic/cost_calculator.py): fix linting errors

* fix: fix mypy linting errors

* fix: fix linting errors
Krish Dholakia, 2024-10-05 17:04:00 -04:00 (committed by GitHub)
parent f7ce1173f3
commit fac3b2ee42
65 changed files with 619 additions and 522 deletions
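The bullets above describe wiring pyright and mypy into CI. The config.yml change itself is not shown in this excerpt, so the following is only an illustrative local equivalent in Python that invokes the same checkers; the litellm/ target path and script layout are assumptions, not the actual CircleCI step:

import subprocess
import sys

def run_type_checks() -> int:
    # run the same checkers the commit adds to CI; the real wiring lives in config.yml
    for cmd in (["mypy", "litellm/"], ["pyright", "litellm/"]):
        result = subprocess.run(cmd)
        if result.returncode != 0:
            return result.returncode
    return 0

if __name__ == "__main__":
    sys.exit(run_type_checks())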


@ -19,7 +19,8 @@ import threading
import time
import traceback
import uuid
from concurrent.futures import ThreadPoolExecutor
from concurrent import futures
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, List, Literal, Mapping, Optional, Type, Union
@ -647,7 +648,7 @@ def mock_completion(
@client
def completion(
def completion( # type: ignore
model: str,
# Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
messages: List = [],
@ -2940,16 +2941,16 @@ def completion_with_retries(*args, **kwargs):
num_retries = kwargs.pop("num_retries", 3)
retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop("retry_strategy", "constant_retry") # type: ignore
original_function = kwargs.pop("original_function", completion)
if retry_strategy == "constant_retry":
retryer = tenacity.Retrying(
stop=tenacity.stop_after_attempt(num_retries), reraise=True
)
elif retry_strategy == "exponential_backoff_retry":
if retry_strategy == "exponential_backoff_retry":
retryer = tenacity.Retrying(
wait=tenacity.wait_exponential(multiplier=1, max=10),
stop=tenacity.stop_after_attempt(num_retries),
reraise=True,
)
else:
retryer = tenacity.Retrying(
stop=tenacity.stop_after_attempt(num_retries), reraise=True
)
return retryer(original_function, *args, **kwargs)
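Not part of the diff: a minimal standalone sketch of the tenacity pattern the hunk above rearranges, so the exponential-backoff case is checked explicitly and constant retry becomes the else fallback. The flaky-call lambda and build_retryer helper are illustrative, not litellm code:

import tenacity

def build_retryer(retry_strategy: str, num_retries: int = 3) -> tenacity.Retrying:
    # exponential backoff is the special case; constant retry is the default fallback,
    # mirroring the if/else ordering introduced above
    if retry_strategy == "exponential_backoff_retry":
        return tenacity.Retrying(
            wait=tenacity.wait_exponential(multiplier=1, max=10),
            stop=tenacity.stop_after_attempt(num_retries),
            reraise=True,
        )
    return tenacity.Retrying(
        stop=tenacity.stop_after_attempt(num_retries), reraise=True
    )

retryer = build_retryer("exponential_backoff_retry")
print(retryer(lambda: "ok"))  # retries transparently on failure; prints "ok" on first success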
@ -2968,16 +2969,16 @@ async def acompletion_with_retries(*args, **kwargs):
num_retries = kwargs.pop("num_retries", 3)
retry_strategy = kwargs.pop("retry_strategy", "constant_retry")
original_function = kwargs.pop("original_function", completion)
if retry_strategy == "constant_retry":
retryer = tenacity.Retrying(
stop=tenacity.stop_after_attempt(num_retries), reraise=True
)
elif retry_strategy == "exponential_backoff_retry":
if retry_strategy == "exponential_backoff_retry":
retryer = tenacity.Retrying(
wait=tenacity.wait_exponential(multiplier=1, max=10),
stop=tenacity.stop_after_attempt(num_retries),
reraise=True,
)
else:
retryer = tenacity.Retrying(
stop=tenacity.stop_after_attempt(num_retries), reraise=True
)
return await retryer(original_function, *args, **kwargs)
@ -3045,7 +3046,7 @@ def batch_completion(
temperature=temperature,
top_p=top_p,
n=n,
stream=stream,
stream=stream or False,
stop=stop,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
@ -3124,7 +3125,7 @@ def batch_completion_models(*args, **kwargs):
models = kwargs["models"]
kwargs.pop("models")
futures = {}
with concurrent.futures.ThreadPoolExecutor(max_workers=len(models)) as executor:
with ThreadPoolExecutor(max_workers=len(models)) as executor:
for model in models:
futures[model] = executor.submit(
completion, *args, model=model, **kwargs
@ -3141,9 +3142,7 @@ def batch_completion_models(*args, **kwargs):
kwargs.pop("model_list")
nested_kwargs = kwargs.pop("kwargs", {})
futures = {}
with concurrent.futures.ThreadPoolExecutor(
max_workers=len(deployments)
) as executor:
with ThreadPoolExecutor(max_workers=len(deployments)) as executor:
for deployment in deployments:
for key in kwargs.keys():
if (
@ -3156,9 +3155,7 @@ def batch_completion_models(*args, **kwargs):
while futures:
# wait for the first returned future
print_verbose("\n\n waiting for next result\n\n")
done, _ = concurrent.futures.wait(
futures.values(), return_when=concurrent.futures.FIRST_COMPLETED
)
done, _ = wait(futures.values(), return_when=FIRST_COMPLETED)
print_verbose(f"done list\n{done}")
for future in done:
try:
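For context, a self-contained sketch of the concurrent.futures pattern the hunks above switch to (module-level ThreadPoolExecutor/wait/FIRST_COMPLETED imports instead of the concurrent.futures.* prefix). The two worker functions are stand-ins for per-model completion calls:

import time
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait

def slow_model(prompt: str) -> str:
    time.sleep(0.2)
    return f"slow: {prompt}"

def fast_model(prompt: str) -> str:
    return f"fast: {prompt}"

def first_response(prompt: str) -> str:
    # submit one future per "model" and return whichever finishes first,
    # the same race batch_completion_models runs across providers
    with ThreadPoolExecutor(max_workers=2) as executor:
        futures = {
            "slow": executor.submit(slow_model, prompt),
            "fast": executor.submit(fast_model, prompt),
        }
        done, _ = wait(futures.values(), return_when=FIRST_COMPLETED)
        return next(iter(done)).result()

print(first_response("hello"))  # typically prints "fast: hello"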
@ -3214,6 +3211,8 @@ def batch_completion_models_all_responses(*args, **kwargs):
if "models" in kwargs:
models = kwargs["models"]
kwargs.pop("models")
else:
raise Exception("'models' param not in kwargs")
responses = []
@ -3256,6 +3255,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
model=model, api_base=kwargs.get("api_base", None)
)
response: Optional[EmbeddingResponse] = None
if (
custom_llm_provider == "openai"
or custom_llm_provider == "azure"
@ -3294,12 +3294,21 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
elif isinstance(init_response, EmbeddingResponse): ## CACHING SCENARIO
response = init_response
elif asyncio.iscoroutine(init_response):
response = await init_response
response = await init_response # type: ignore
else:
# Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context)
if response is not None and hasattr(response, "_hidden_params"):
if (
response is not None
and isinstance(response, EmbeddingResponse)
and hasattr(response, "_hidden_params")
):
response._hidden_params["custom_llm_provider"] = custom_llm_provider
if response is None:
raise ValueError(
"Unable to get Embedding Response. Please pass a valid llm_provider."
)
return response
except Exception as e:
custom_llm_provider = custom_llm_provider or "openai"
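A minimal sketch of the coroutine-vs-sync dispatch used in aembedding above: await init_response when it is a coroutine, otherwise push the blocking call onto the default executor, and fail loudly if no branch produced a response. Both embedding functions here are made-up placeholders:

import asyncio
from functools import partial

def sync_embedding(text: str) -> dict:
    # placeholder for a blocking provider call
    return {"embedding": [0.0], "input": text}

async def async_embedding(text: str) -> dict:
    # placeholder for a provider that returns a coroutine
    return {"embedding": [1.0], "input": text}

async def get_embedding(text: str, prefer_async: bool) -> dict:
    loop = asyncio.get_running_loop()
    init_response = async_embedding(text) if prefer_async else None
    if asyncio.iscoroutine(init_response):
        # async provider: await the coroutine directly
        return await init_response
    # sync provider: run the blocking call in the default thread-pool executor
    return await loop.run_in_executor(None, partial(sync_embedding, text))

print(asyncio.run(get_embedding("hello", prefer_async=True)))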
@ -3329,7 +3338,6 @@ def embedding(
user: Optional[str] = None,
custom_llm_provider=None,
litellm_call_id=None,
litellm_logging_obj=None,
logger_fn=None,
**kwargs,
) -> EmbeddingResponse:
@ -3362,6 +3370,7 @@ def embedding(
client = kwargs.pop("client", None)
rpm = kwargs.pop("rpm", None)
tpm = kwargs.pop("tpm", None)
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
cooldown_time = kwargs.get("cooldown_time", None)
max_parallel_requests = kwargs.pop("max_parallel_requests", None)
model_info = kwargs.get("model_info", None)
@ -3491,7 +3500,7 @@ def embedding(
}
)
try:
response = None
response: Optional[EmbeddingResponse] = None
logging: Logging = litellm_logging_obj # type: ignore
logging.update_environment_variables(
model=model,
@ -3691,7 +3700,7 @@ def embedding(
raise ValueError(
"api_base is required for triton. Please pass `api_base`"
)
response = triton_chat_completions.embedding(
response = triton_chat_completions.embedding( # type: ignore
model=model,
input=input,
api_base=api_base,
@ -3783,6 +3792,7 @@ def embedding(
timeout=timeout,
aembedding=aembedding,
print_verbose=print_verbose,
api_key=api_key,
)
elif custom_llm_provider == "oobabooga":
response = oobabooga.embedding(
@ -3793,14 +3803,16 @@ def embedding(
logging_obj=logging,
optional_params=optional_params,
model_response=EmbeddingResponse(),
api_key=api_key,
)
elif custom_llm_provider == "ollama":
api_base = (
litellm.api_base
or api_base
or get_secret("OLLAMA_API_BASE")
or get_secret_str("OLLAMA_API_BASE")
or "http://localhost:11434"
) # type: ignore
if isinstance(input, str):
input = [input]
if not all(isinstance(item, str) for item in input):
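Most hunks in this file swap get_secret for get_secret_str inside these `or` fallback chains. A rough sketch of why that satisfies the type checker, using stand-in implementations rather than litellm's real ones: the wrapper narrows the union return type to Optional[str], so the chain resolves to a plain string after the literal default.

import os
from typing import Optional, Union

def get_secret(name: str) -> Union[str, bool, None]:
    # stand-in: secret managers may hand back non-string values
    return os.environ.get(name)

def get_secret_str(name: str) -> Optional[str]:
    # stand-in wrapper: only pass through values that really are strings
    value = get_secret(name)
    return value if isinstance(value, str) else None

# Optional[str] or "literal" narrows to str, which is what pyright expects for api_base
api_base: str = get_secret_str("OLLAMA_API_BASE") or "http://localhost:11434"
print(api_base)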
@ -3881,13 +3893,13 @@ def embedding(
api_key = (
api_key
or litellm.api_key
or get_secret("XINFERENCE_API_KEY")
or get_secret_str("XINFERENCE_API_KEY")
or "stub-xinference-key"
) # xinference does not need an api key, pass a stub key if user did not set one
api_base = (
api_base
or litellm.api_base
or get_secret("XINFERENCE_API_BASE")
or get_secret_str("XINFERENCE_API_BASE")
or "http://127.0.0.1:9997/v1"
)
response = openai_chat_completions.embedding(
@ -3911,19 +3923,20 @@ def embedding(
optional_params=optional_params,
model_response=EmbeddingResponse(),
aembedding=aembedding,
api_key=api_key,
)
elif custom_llm_provider == "azure_ai":
api_base = (
api_base # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
or litellm.api_base
or get_secret("AZURE_AI_API_BASE")
or get_secret_str("AZURE_AI_API_BASE")
)
# set API KEY
api_key = (
api_key
or litellm.api_key # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or get_secret("AZURE_AI_API_KEY")
or get_secret_str("AZURE_AI_API_KEY")
)
## EMBEDDING CALL
@ -3944,10 +3957,14 @@ def embedding(
raise ValueError(f"No valid embedding model args passed in - {args}")
if response is not None and hasattr(response, "_hidden_params"):
response._hidden_params["custom_llm_provider"] = custom_llm_provider
if response is None:
args = locals()
raise ValueError(f"No valid embedding model args passed in - {args}")
return response
except Exception as e:
## LOGGING
logging.post_call(
litellm_logging_obj.post_call(
input=input,
api_key=api_key,
original_response=str(e),
@ -4018,7 +4035,11 @@ async def atext_completion(
else:
# Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context)
if kwargs.get("stream", False) is True: # return an async generator
if (
kwargs.get("stream", False) is True
or isinstance(response, TextCompletionStreamWrapper)
or isinstance(response, CustomStreamWrapper)
): # return an async generator
return TextCompletionStreamWrapper(
completion_stream=_async_streaming(
response=response,
@ -4153,9 +4174,10 @@ def text_completion(
Your example of how to use this function goes here.
"""
if "engine" in kwargs:
if model is None:
_engine = kwargs["engine"]
if model is None and isinstance(_engine, str):
# only use engine when model not passed
model = kwargs["engine"]
model = _engine
kwargs.pop("engine")
text_completion_response = TextCompletionResponse()
@ -4223,7 +4245,7 @@ def text_completion(
def process_prompt(i, individual_prompt):
decoded_prompt = tokenizer.decode(individual_prompt)
all_params = {**kwargs, **optional_params}
response = text_completion(
response: TextCompletionResponse = text_completion( # type: ignore
model=model,
prompt=decoded_prompt,
num_retries=3, # ensure this does not fail for the batch
@ -4292,6 +4314,8 @@ def text_completion(
model = "text-completion-openai/" + _model
optional_params.pop("custom_llm_provider", None)
if model is None:
raise ValueError("model is not set. Set either via 'model' or 'engine' param.")
kwargs["text_completion"] = True
response = completion(
model=model,
@ -4302,7 +4326,11 @@ def text_completion(
)
if kwargs.get("acompletion", False) is True:
return response
if stream is True or kwargs.get("stream", False) is True:
if (
stream is True
or kwargs.get("stream", False) is True
or isinstance(response, CustomStreamWrapper)
):
response = TextCompletionStreamWrapper(
completion_stream=response,
model=model,
@ -4310,6 +4338,8 @@ def text_completion(
custom_llm_provider=custom_llm_provider,
)
return response
elif isinstance(response, TextCompletionStreamWrapper):
return response
transformed_logprobs = None
# only supported for TGI models
try:
@ -4424,7 +4454,10 @@ def moderation(
):
# only supports open ai for now
api_key = (
api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")
api_key
or litellm.api_key
or litellm.openai_key
or get_secret_str("OPENAI_API_KEY")
)
openai_client = kwargs.get("client", None)
@ -4433,7 +4466,10 @@ def moderation(
api_key=api_key,
)
response = openai_client.moderations.create(input=input, model=model)
if model is not None:
response = openai_client.moderations.create(input=input, model=model)
else:
response = openai_client.moderations.create(input=input)
return response
@ -4441,20 +4477,30 @@ def moderation(
async def amoderation(
input: str, model: Optional[str] = None, api_key: Optional[str] = None, **kwargs
):
from openai import AsyncOpenAI
# only supports open ai for now
api_key = (
api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")
api_key
or litellm.api_key
or litellm.openai_key
or get_secret_str("OPENAI_API_KEY")
)
openai_client = kwargs.get("client", None)
if openai_client is None:
if openai_client is None or not isinstance(openai_client, AsyncOpenAI):
# call helper to get OpenAI client
# _get_openai_client maintains in-memory caching logic for OpenAI clients
openai_client = openai_chat_completions._get_openai_client(
_openai_client: AsyncOpenAI = openai_chat_completions._get_openai_client( # type: ignore
is_async=True,
api_key=api_key,
)
response = await openai_client.moderations.create(input=input, model=model)
else:
_openai_client = openai_client
if model is not None:
response = await openai_client.moderations.create(input=input, model=model)
else:
response = await openai_client.moderations.create(input=input)
return response
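The amoderation change above pins the client type to AsyncOpenAI and only forwards `model` when the caller supplied one. A standalone sketch of that call against the OpenAI SDK; the environment-variable lookup and example text are illustrative:

import asyncio
import os
from typing import Optional

from openai import AsyncOpenAI

async def moderate(text: str, model: Optional[str] = None):
    client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
    if model is not None:
        return await client.moderations.create(input=text, model=model)
    # omit `model` so the API-side default is used rather than sending None
    return await client.moderations.create(input=text)

# example: asyncio.run(moderate("some user text"))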
@ -4497,7 +4543,7 @@ async def aimage_generation(*args, **kwargs) -> ImageResponse:
init_response = ImageResponse(**init_response)
response = init_response
elif asyncio.iscoroutine(init_response):
response = await init_response
response = await init_response # type: ignore
else:
# Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context)
@ -4527,7 +4573,6 @@ def image_generation(
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
litellm_logging_obj=None,
custom_llm_provider=None,
**kwargs,
) -> ImageResponse:
@ -4543,9 +4588,10 @@ def image_generation(
proxy_server_request = kwargs.get("proxy_server_request", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", {})
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
client = kwargs.get("client", None)
model_response = litellm.utils.ImageResponse()
model_response: ImageResponse = litellm.utils.ImageResponse()
if model is not None or custom_llm_provider is not None:
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
else:
@ -4651,25 +4697,27 @@ def image_generation(
if custom_llm_provider == "azure":
# azure configs
api_type = get_secret("AZURE_API_TYPE") or "azure"
api_type = get_secret_str("AZURE_API_TYPE") or "azure"
api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")
api_version = (
api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
api_version
or litellm.api_version
or get_secret_str("AZURE_API_VERSION")
)
api_key = (
api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
)
azure_ad_token = optional_params.pop("azure_ad_token", None) or get_secret(
"AZURE_AD_TOKEN"
)
azure_ad_token = optional_params.pop(
"azure_ad_token", None
) or get_secret_str("AZURE_AD_TOKEN")
model_response = azure_chat_completions.image_generation(
model=model,
@ -4714,18 +4762,18 @@ def image_generation(
optional_params.pop("vertex_project", None)
or optional_params.pop("vertex_ai_project", None)
or litellm.vertex_project
or get_secret("VERTEXAI_PROJECT")
or get_secret_str("VERTEXAI_PROJECT")
)
vertex_ai_location = (
optional_params.pop("vertex_location", None)
or optional_params.pop("vertex_ai_location", None)
or litellm.vertex_location
or get_secret("VERTEXAI_LOCATION")
or get_secret_str("VERTEXAI_LOCATION")
)
vertex_credentials = (
optional_params.pop("vertex_credentials", None)
or optional_params.pop("vertex_ai_credentials", None)
or get_secret("VERTEXAI_CREDENTIALS")
or get_secret_str("VERTEXAI_CREDENTIALS")
)
model_response = vertex_image_generation.image_generation(
model=model,
@ -4786,7 +4834,7 @@ async def atranscription(*args, **kwargs) -> TranscriptionResponse:
elif isinstance(init_response, TranscriptionResponse): ## CACHING SCENARIO
response = init_response
elif asyncio.iscoroutine(init_response):
response = await init_response
response = await init_response # type: ignore
else:
# Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context)
@ -4820,7 +4868,6 @@ def transcription(
api_base: Optional[str] = None,
api_version: Optional[str] = None,
max_retries: Optional[int] = None,
litellm_logging_obj: Optional[LiteLLMLoggingObj] = None,
custom_llm_provider=None,
**kwargs,
) -> TranscriptionResponse:
@ -4830,6 +4877,7 @@ def transcription(
Allows router to load balance between them
"""
atranscription = kwargs.get("atranscription", False)
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
kwargs.get("litellm_call_id", None)
kwargs.get("logger_fn", None)
kwargs.get("proxy_server_request", None)
@ -4869,22 +4917,17 @@ def transcription(
custom_llm_provider=custom_llm_provider,
drop_params=drop_params,
)
# optional_params = {
# "language": language,
# "prompt": prompt,
# "response_format": response_format,
# "temperature": None, # openai defaults this to 0
# }
response: Optional[TranscriptionResponse] = None
if custom_llm_provider == "azure":
# azure configs
api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")
api_version = (
api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
api_version or litellm.api_version or get_secret_str("AZURE_API_VERSION")
)
azure_ad_token = kwargs.pop("azure_ad_token", None) or get_secret(
azure_ad_token = kwargs.pop("azure_ad_token", None) or get_secret_str(
"AZURE_AD_TOKEN"
)
@ -4892,8 +4935,8 @@ def transcription(
api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_API_KEY")
) # type: ignore
or get_secret_str("AZURE_API_KEY")
)
response = azure_audio_transcriptions.audio_transcriptions(
model=model,
@ -4942,6 +4985,9 @@ def transcription(
api_base=api_base,
api_key=api_key,
)
if response is None:
raise ValueError("Unmapped provider passed in. Unable to get the response.")
return response
@ -5149,15 +5195,16 @@ def speech(
vertex_ai_project = (
generic_optional_params.vertex_project
or litellm.vertex_project
or get_secret("VERTEXAI_PROJECT")
or get_secret_str("VERTEXAI_PROJECT")
)
vertex_ai_location = (
generic_optional_params.vertex_location
or litellm.vertex_location
or get_secret("VERTEXAI_LOCATION")
or get_secret_str("VERTEXAI_LOCATION")
)
vertex_credentials = generic_optional_params.vertex_credentials or get_secret(
"VERTEXAI_CREDENTIALS"
vertex_credentials = (
generic_optional_params.vertex_credentials
or get_secret_str("VERTEXAI_CREDENTIALS")
)
if voice is not None and not isinstance(voice, dict):
@ -5234,20 +5281,25 @@ async def ahealth_check(
if custom_llm_provider == "azure":
api_key = (
model_params.get("api_key")
or get_secret("AZURE_API_KEY")
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
or get_secret_str("AZURE_OPENAI_API_KEY")
)
api_base = (
api_base: Optional[str] = (
model_params.get("api_base")
or get_secret("AZURE_API_BASE")
or get_secret("AZURE_OPENAI_API_BASE")
or get_secret_str("AZURE_API_BASE")
or get_secret_str("AZURE_OPENAI_API_BASE")
)
if api_base is None:
raise ValueError(
"Azure API Base cannot be None. Set via 'AZURE_API_BASE' in env var or `.completion(..., api_base=..)`"
)
api_version = (
model_params.get("api_version")
or get_secret("AZURE_API_VERSION")
or get_secret("AZURE_OPENAI_API_VERSION")
or get_secret_str("AZURE_API_VERSION")
or get_secret_str("AZURE_OPENAI_API_VERSION")
)
timeout = (
@ -5273,7 +5325,7 @@ async def ahealth_check(
custom_llm_provider == "openai"
or custom_llm_provider == "text-completion-openai"
):
api_key = model_params.get("api_key") or get_secret("OPENAI_API_KEY")
api_key = model_params.get("api_key") or get_secret_str("OPENAI_API_KEY")
organization = model_params.get("organization")
timeout = (
@ -5282,7 +5334,7 @@ async def ahealth_check(
or default_timeout
)
api_base = model_params.get("api_base") or get_secret("OPENAI_API_BASE")
api_base = model_params.get("api_base") or get_secret_str("OPENAI_API_BASE")
if custom_llm_provider == "text-completion-openai":
mode = "completion"
@ -5377,7 +5429,9 @@ def config_completion(**kwargs):
)
def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List] = None):
def stream_chunk_builder_text_completion(
chunks: list, messages: Optional[List] = None
) -> TextCompletionResponse:
id = chunks[0]["id"]
object = chunks[0]["object"]
created = chunks[0]["created"]
@ -5446,7 +5500,7 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]
response["usage"]["total_tokens"] = (
response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
)
return response
return TextCompletionResponse(**response)
def stream_chunk_builder(
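Many of the hunks above follow the same narrowing pattern: declare the result as Optional up front, assign it in provider-specific branches, and raise before returning if no branch filled it in, so the non-Optional return annotation holds. A generic sketch of that pattern; the provider names and response class are invented for illustration:

from typing import Optional

class EmbeddingResult:
    def __init__(self, provider: str) -> None:
        self.provider = provider

def embed(provider: str) -> EmbeddingResult:
    response: Optional[EmbeddingResult] = None
    if provider == "openai":
        response = EmbeddingResult("openai")
    elif provider == "azure":
        response = EmbeddingResult("azure")
    if response is None:
        # mirrors the "Unable to get ... Response" / "Unmapped provider" guards added above
        raise ValueError(f"Unmapped provider passed in: {provider}")
    return response  # the type checker now sees EmbeddingResult, not Optional

print(embed("openai").provider)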