Add pyright to ci/cd + Fix remaining type-checking errors (#6082)

* fix: fix type-checking errors

* fix: fix additional type-checking errors

* fix: additional type-checking error fixes

* fix: fix additional type-checking errors

* fix: additional type-check fixes

* fix: fix all type-checking errors + add pyright to ci/cd

* fix: fix incorrect import

* ci(config.yml): use mypy on ci/cd

* fix: fix type-checking errors in utils.py

* fix: fix all type-checking errors on main.py

* fix: fix mypy linting errors

* fix(anthropic/cost_calculator.py): fix linting errors

* fix: fix mypy linting errors

* fix: fix linting errors
Krish Dholakia, 2024-10-05 17:04:00 -04:00 (committed by GitHub)
parent f7ce1173f3
commit fac3b2ee42
65 changed files with 619 additions and 522 deletions
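
Most of the changes below follow a few recurring patterns for satisfying pyright and mypy without changing behaviour: narrowing Optional values before they reach str parameters (`model or ""`, `function_id or ""`), adding explicit `is not None` guards before touching optional containers (`litellm.cache.supported_call_types`), and isinstance checks before accessing attributes that only exist on one member of a union (`StreamingChoices.delta`). The sketch below illustrates those patterns with stand-in classes; the names are hypothetical, not litellm's actual types.

from typing import List, Optional, Union


class Delta:
    """Stand-in for a streaming delta; illustrative only."""

    def __init__(self, content: Optional[str] = None) -> None:
        self.content = content

    def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
        return getattr(self, key, default)


class StreamingChoices:
    """Stand-in for a streaming choice that carries a delta."""

    def __init__(self, delta: Delta) -> None:
        self.delta = delta


class Choices:
    """Stand-in for a non-streaming choice; has no delta attribute."""


def get_api_base(model: str) -> str:
    return f"https://example.invalid/{model}"


def resolve_api_base(model: Optional[str]) -> str:
    # Pattern 1: an Optional[str] is narrowed with `or ""` before being passed
    # to a parameter annotated as str (see `model=model or ""` in the diff).
    return get_api_base(model or "")


def accumulate_content(choice: Union[StreamingChoices, Choices]) -> str:
    # Pattern 2: isinstance() narrowing before touching attributes that only
    # exist on one member of the union (see the StreamingChoices checks).
    if isinstance(choice, StreamingChoices):
        return choice.delta.get("content", "") or ""
    return ""


def should_cache(function_name: str, supported_call_types: Optional[List[str]]) -> bool:
    # Pattern 3: an explicit `is not None` guard before using an Optional
    # container (see the `litellm.cache.supported_call_types` checks).
    return supported_call_types is not None and function_name in supported_call_types

Removing any of these guards and re-running mypy or pyright reproduces the kind of errors this commit is cleaning up.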

litellm/utils.py

@@ -322,6 +322,9 @@ def function_setup(
original_function: str, rules_obj, start_time, *args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
### NOTICES ###
from litellm import Logging as LiteLLMLogging
from litellm.litellm_core_utils.litellm_logging import set_callbacks
if litellm.set_verbose is True:
verbose_logger.warning(
"`litellm.set_verbose` is deprecated. Please set `os.environ['LITELLM_LOG'] = 'DEBUG'` for debug logs."
@@ -333,7 +336,7 @@ def function_setup(
custom_llm_setup()
## LOGGING SETUP
function_id = kwargs["id"] if "id" in kwargs else None
function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None
if len(litellm.callbacks) > 0:
for callback in litellm.callbacks:
@@ -375,9 +378,7 @@ def function_setup(
+ litellm.failure_callback
)
)
litellm.litellm_core_utils.litellm_logging.set_callbacks(
callback_list=callback_list, function_id=function_id
)
set_callbacks(callback_list=callback_list, function_id=function_id)
## ASYNC CALLBACKS
if len(litellm.input_callback) > 0:
removed_async_items = []
@@ -560,12 +561,12 @@ def function_setup(
else:
messages = "default-message-value"
stream = True if "stream" in kwargs and kwargs["stream"] is True else False
logging_obj = litellm.litellm_core_utils.litellm_logging.Logging(
logging_obj = LiteLLMLogging(
model=model,
messages=messages,
stream=stream,
litellm_call_id=kwargs["litellm_call_id"],
function_id=function_id,
function_id=function_id or "",
call_type=call_type,
start_time=start_time,
dynamic_success_callbacks=dynamic_success_callbacks,
@@ -655,10 +656,8 @@ def client(original_function):
json_response_format = optional_params[
"response_format"
]
elif (
_parsing._completions.is_basemodel_type(
optional_params["response_format"]
)
elif _parsing._completions.is_basemodel_type(
optional_params["response_format"] # type: ignore
):
json_response_format = (
type_to_response_format_param(
@@ -827,6 +826,7 @@ def client(original_function):
print_verbose("INSIDE CHECKING CACHE")
if (
litellm.cache is not None
and litellm.cache.supported_call_types is not None
and str(original_function.__name__)
in litellm.cache.supported_call_types
):
@@ -879,7 +879,7 @@ def client(original_function):
dynamic_api_key,
api_base,
) = litellm.get_llm_provider(
model=model,
model=model or "",
custom_llm_provider=kwargs.get(
"custom_llm_provider", None
),
@@ -949,6 +949,8 @@ def client(original_function):
base_model=base_model,
messages=messages,
user_max_tokens=user_max_tokens,
buffer_num=None,
buffer_perc=None,
)
kwargs["max_tokens"] = modified_max_tokens
except Exception as e:
@@ -990,6 +992,7 @@ def client(original_function):
# [OPTIONAL] ADD TO CACHE
if (
litellm.cache is not None
and litellm.cache.supported_call_types is not None
and str(original_function.__name__)
in litellm.cache.supported_call_types
) and (kwargs.get("cache", {}).get("no-store", False) is not True):
@@ -1006,7 +1009,7 @@ def client(original_function):
"id", None
)
result._hidden_params["api_base"] = get_api_base(
model=model,
model=model or "",
optional_params=getattr(logging_obj, "optional_params", {}),
)
result._hidden_params["response_cost"] = (
@@ -1053,7 +1056,7 @@ def client(original_function):
and not _is_litellm_router_call
):
if len(args) > 0:
args[0] = context_window_fallback_dict[model]
args[0] = context_window_fallback_dict[model] # type: ignore
else:
kwargs["model"] = context_window_fallback_dict[model]
return original_function(*args, **kwargs)
@@ -1065,12 +1068,6 @@ def client(original_function):
logging_obj.failure_handler(
e, traceback_exception, start_time, end_time
) # DO NOT MAKE THREADED - router retry fallback relies on this!
if hasattr(e, "message"):
if (
liteDebuggerClient
and liteDebuggerClient.dashboard_url is not None
): # make it easy to get to the debugger logs if you've initialized it
e.message += f"\n Check the log in your dashboard - {liteDebuggerClient.dashboard_url}"
raise e
@wraps(original_function)
@@ -1126,6 +1123,7 @@ def client(original_function):
print_verbose("INSIDE CHECKING CACHE")
if (
litellm.cache is not None
and litellm.cache.supported_call_types is not None
and str(original_function.__name__)
in litellm.cache.supported_call_types
):
@@ -1287,7 +1285,11 @@ def client(original_function):
args=(cached_result, start_time, end_time, cache_hit),
).start()
cache_key = kwargs.get("preset_cache_key", None)
cached_result._hidden_params["cache_key"] = cache_key
if (
isinstance(cached_result, BaseModel)
or isinstance(cached_result, CustomStreamWrapper)
) and hasattr(cached_result, "_hidden_params"):
cached_result._hidden_params["cache_key"] = cache_key # type: ignore
return cached_result
elif (
call_type == CallTypes.aembedding.value
@@ -1447,6 +1449,7 @@ def client(original_function):
# [OPTIONAL] ADD TO CACHE
if (
(litellm.cache is not None)
and litellm.cache.supported_call_types is not None
and (
str(original_function.__name__)
in litellm.cache.supported_call_types
@@ -1504,11 +1507,12 @@ def client(original_function):
if (
isinstance(result, EmbeddingResponse)
and final_embedding_cached_response is not None
and final_embedding_cached_response.data is not None
):
idx = 0
final_data_list = []
for item in final_embedding_cached_response.data:
if item is None:
if item is None and result.data is not None:
final_data_list.append(result.data[idx])
idx += 1
else:
@@ -1575,7 +1579,7 @@ def client(original_function):
and model in context_window_fallback_dict
):
if len(args) > 0:
args[0] = context_window_fallback_dict[model]
args[0] = context_window_fallback_dict[model] # type: ignore
else:
kwargs["model"] = context_window_fallback_dict[model]
return await original_function(*args, **kwargs)
@@ -2945,13 +2949,19 @@ def get_optional_params(
response_format=non_default_params["response_format"]
)
# # clean out 'additionalProperties = False'. Causes vertexai/gemini OpenAI API Schema errors - https://github.com/langchain-ai/langchainjs/issues/5240
if non_default_params["response_format"].get("json_schema", {}).get(
"schema"
) is not None and custom_llm_provider in [
"gemini",
"vertex_ai",
"vertex_ai_beta",
]:
if (
non_default_params["response_format"] is not None
and non_default_params["response_format"]
.get("json_schema", {})
.get("schema")
is not None
and custom_llm_provider
in [
"gemini",
"vertex_ai",
"vertex_ai_beta",
]
):
old_schema = copy.deepcopy(
non_default_params["response_format"]
.get("json_schema", {})
@@ -3754,7 +3764,11 @@ def get_optional_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=drop_params,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
elif custom_llm_provider == "openrouter":
supported_params = get_supported_openai_params(
@@ -3863,7 +3877,11 @@ def get_optional_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=drop_params,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
elif custom_llm_provider == "azure":
supported_params = get_supported_openai_params(
@@ -4889,7 +4907,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
try:
split_model, custom_llm_provider, _, _ = get_llm_provider(model=model)
except Exception:
pass
split_model = model
combined_model_name = model
stripped_model_name = _strip_model_name(model=model)
combined_stripped_model_name = stripped_model_name
@@ -5865,6 +5883,8 @@ def convert_to_model_response_object(
for idx, choice in enumerate(response_object["choices"]):
## HANDLE JSON MODE - anthropic returns single function call]
tool_calls = choice["message"].get("tool_calls", None)
message: Optional[Message] = None
finish_reason: Optional[str] = None
if (
convert_tool_call_to_json_mode
and tool_calls is not None
@@ -5877,7 +5897,7 @@
if json_mode_content_str is not None:
message = litellm.Message(content=json_mode_content_str)
finish_reason = "stop"
else:
if message is None:
message = Message(
content=choice["message"].get("content", None),
role=choice["message"]["role"] or "assistant",
@@ -6066,7 +6086,7 @@ def valid_model(model):
model in litellm.open_ai_chat_completion_models
or model in litellm.open_ai_text_completion_models
):
openai.Model.retrieve(model)
openai.models.retrieve(model)
else:
messages = [{"role": "user", "content": "Hello World"}]
litellm.completion(model=model, messages=messages)
@@ -6386,8 +6406,8 @@ class CustomStreamWrapper:
self,
completion_stream,
model,
custom_llm_provider=None,
logging_obj=None,
logging_obj: Any,
custom_llm_provider: Optional[str] = None,
stream_options=None,
make_call: Optional[Callable] = None,
_response_headers: Optional[dict] = None,
@@ -6633,36 +6653,6 @@ class CustomStreamWrapper:
"completion_tokens": completion_tokens,
}
def handle_together_ai_chunk(self, chunk):
chunk = chunk.decode("utf-8")
text = ""
is_finished = False
finish_reason = None
if "text" in chunk:
text_index = chunk.find('"text":"') # this checks if text: exists
text_start = text_index + len('"text":"')
text_end = chunk.find('"}', text_start)
if text_index != -1 and text_end != -1:
extracted_text = chunk[text_start:text_end]
text = extracted_text
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
}
elif "[DONE]" in chunk:
return {"text": text, "is_finished": True, "finish_reason": "stop"}
elif "error" in chunk:
raise litellm.together_ai.TogetherAIError(
status_code=422, message=f"{str(chunk)}"
)
else:
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
}
def handle_predibase_chunk(self, chunk):
try:
if not isinstance(chunk, str):
@@ -7264,12 +7254,17 @@ class CustomStreamWrapper:
try:
if isinstance(chunk, dict):
parsed_response = chunk
if isinstance(chunk, (str, bytes)):
elif isinstance(chunk, (str, bytes)):
if isinstance(chunk, bytes):
parsed_response = chunk.decode("utf-8")
else:
parsed_response = chunk
data_json = json.loads(parsed_response)
else:
raise ValueError("Unable to parse streaming chunk")
if isinstance(parsed_response, dict):
data_json = parsed_response
else:
data_json = json.loads(parsed_response)
text = (
data_json.get("outputs", "")[0]
.get("data", "")
@@ -7331,8 +7326,7 @@ class CustomStreamWrapper:
if (
len(model_response.choices) > 0
and hasattr(model_response.choices[0], "delta")
and model_response.choices[0].delta is not None
and getattr(model_response.choices[0], "delta") is not None
):
# do nothing, if object instantiated
pass
@@ -7350,7 +7344,7 @@ class CustomStreamWrapper:
is_empty = False
return is_empty
def chunk_creator(self, chunk):
def chunk_creator(self, chunk): # type: ignore
model_response = self.model_response_creator()
response_obj = {}
try:
@@ -7422,11 +7416,6 @@ class CustomStreamWrapper:
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider and self.custom_llm_provider == "together_ai":
response_obj = self.handle_together_ai_chunk(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider and self.custom_llm_provider == "huggingface":
response_obj = self.handle_huggingface_chunk(chunk)
completion_obj["content"] = response_obj["text"]
@@ -7475,51 +7464,6 @@ class CustomStreamWrapper:
if self.sent_first_chunk is False:
raise Exception("An unknown error occurred with the stream")
self.received_finish_reason = "stop"
elif self.custom_llm_provider == "gemini":
if hasattr(chunk, "parts") is True:
try:
if len(chunk.parts) > 0:
completion_obj["content"] = chunk.parts[0].text
if len(chunk.parts) > 0 and hasattr(
chunk.parts[0], "finish_reason"
):
self.received_finish_reason = chunk.parts[
0
].finish_reason.name
except Exception:
if chunk.parts[0].finish_reason.name == "SAFETY":
raise Exception(
f"The response was blocked by VertexAI. {str(chunk)}"
)
else:
completion_obj["content"] = str(chunk)
elif self.custom_llm_provider and (
self.custom_llm_provider == "vertex_ai_beta"
):
from litellm.types.utils import (
GenericStreamingChunk as UtilsStreamingChunk,
)
if self.received_finish_reason is not None:
raise StopIteration
response_obj: UtilsStreamingChunk = chunk
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
if (
self.stream_options
and self.stream_options.get("include_usage", False) is True
and response_obj["usage"] is not None
):
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"]["prompt_tokens"],
completion_tokens=response_obj["usage"]["completion_tokens"],
total_tokens=response_obj["usage"]["total_tokens"],
)
if "tool_use" in response_obj and response_obj["tool_use"] is not None:
completion_obj["tool_calls"] = [response_obj["tool_use"]]
elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"):
import proto # type: ignore
@@ -7624,53 +7568,7 @@ class CustomStreamWrapper:
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "bedrock":
from litellm.types.llms.bedrock import GenericStreamingChunk
if self.received_finish_reason is not None:
raise StopIteration
response_obj: GenericStreamingChunk = chunk
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
if (
self.stream_options
and self.stream_options.get("include_usage", False) is True
and response_obj["usage"] is not None
):
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"]["inputTokens"],
completion_tokens=response_obj["usage"]["outputTokens"],
total_tokens=response_obj["usage"]["totalTokens"],
)
if "tool_use" in response_obj and response_obj["tool_use"] is not None:
completion_obj["tool_calls"] = [response_obj["tool_use"]]
elif self.custom_llm_provider == "sagemaker":
from litellm.types.llms.bedrock import GenericStreamingChunk
if self.received_finish_reason is not None:
raise StopIteration
response_obj: GenericStreamingChunk = chunk
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
if (
self.stream_options
and self.stream_options.get("include_usage", False) is True
and response_obj["usage"] is not None
):
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"]["inputTokens"],
completion_tokens=response_obj["usage"]["outputTokens"],
total_tokens=response_obj["usage"]["totalTokens"],
)
if "tool_use" in response_obj and response_obj["tool_use"] is not None:
completion_obj["tool_calls"] = [response_obj["tool_use"]]
elif self.custom_llm_provider == "petals":
if len(self.completion_stream) == 0:
if self.received_finish_reason is not None:
@@ -8181,9 +8079,11 @@ class CustomStreamWrapper:
target=self.run_success_logging_in_thread,
args=(response, cache_hit),
).start() # log response
self.response_uptil_now += (
response.choices[0].delta.get("content", "") or ""
)
choice = response.choices[0]
if isinstance(choice, StreamingChoices):
self.response_uptil_now += choice.delta.get("content", "") or ""
else:
self.response_uptil_now += ""
self.rules.post_call_rules(
input=self.response_uptil_now, model=self.model
)
@@ -8223,8 +8123,11 @@ class CustomStreamWrapper:
)
response = self.model_response_creator()
if complete_streaming_response is not None:
response.usage = complete_streaming_response.usage
response._hidden_params["usage"] = complete_streaming_response.usage # type: ignore
setattr(
response,
"usage",
getattr(complete_streaming_response, "usage"),
)
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler,
@@ -8349,9 +8252,11 @@ class CustomStreamWrapper:
processed_chunk, cache_hit=cache_hit
)
)
self.response_uptil_now += (
processed_chunk.choices[0].delta.get("content", "") or ""
)
choice = processed_chunk.choices[0]
if isinstance(choice, StreamingChoices):
self.response_uptil_now += choice.delta.get("content", "") or ""
else:
self.response_uptil_now += ""
self.rules.post_call_rules(
input=self.response_uptil_now, model=self.model
)
@@ -8401,9 +8306,13 @@ class CustomStreamWrapper:
)
)
self.response_uptil_now += (
processed_chunk.choices[0].delta.get("content", "") or ""
)
choice = processed_chunk.choices[0]
if isinstance(choice, StreamingChoices):
self.response_uptil_now += (
choice.delta.get("content", "") or ""
)
else:
self.response_uptil_now += ""
self.rules.post_call_rules(
input=self.response_uptil_now, model=self.model
)
@@ -8423,7 +8332,11 @@ class CustomStreamWrapper:
)
response = self.model_response_creator()
if complete_streaming_response is not None:
setattr(response, "usage", complete_streaming_response.usage)
setattr(
response,
"usage",
getattr(complete_streaming_response, "usage"),
)
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler,
@@ -8464,7 +8377,11 @@ class CustomStreamWrapper:
)
response = self.model_response_creator()
if complete_streaming_response is not None:
response.usage = complete_streaming_response.usage
setattr(
response,
"usage",
getattr(complete_streaming_response, "usage"),
)
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler,
@@ -8898,7 +8815,7 @@ def trim_messages(
if len(tool_messages):
messages = messages[: -len(tool_messages)]
current_tokens = token_counter(model=model, messages=messages)
current_tokens = token_counter(model=model or "", messages=messages)
print_verbose(f"Current tokens: {current_tokens}, max tokens: {max_tokens}")
# Do nothing if current tokens under messages
@@ -8909,6 +8826,7 @@ def trim_messages(
print_verbose(
f"Need to trim input messages: {messages}, current_tokens{current_tokens}, max_tokens: {max_tokens}"
)
system_message_event: Optional[dict] = None
if system_message:
system_message_event, max_tokens = process_system_message(
system_message=system_message, max_tokens=max_tokens, model=model
@@ -8926,7 +8844,7 @@ def trim_messages(
)
# Add system message to the beginning of the final messages
if system_message:
if system_message_event:
final_messages = [system_message_event] + final_messages
if len(tool_messages) > 0:
@@ -9214,6 +9132,8 @@ def is_cached_message(message: AllMessageValues) -> bool:
Follows the anthropic format {"cache_control": {"type": "ephemeral"}}
"""
if "content" not in message:
return False
if message["content"] is None or isinstance(message["content"], str):
return False
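
A second pattern shows up in the streaming hunks above: direct attribute assignment (`response.usage = complete_streaming_response.usage`) is replaced with `setattr(response, "usage", getattr(complete_streaming_response, "usage"))`, because `usage` is not declared on the annotated response type. A rough sketch of the idea with hypothetical stand-in classes, not litellm's actual models:

from typing import Any, Optional


class Usage:
    def __init__(self, prompt_tokens: int = 0, completion_tokens: int = 0) -> None:
        self.prompt_tokens = prompt_tokens
        self.completion_tokens = completion_tokens
        self.total_tokens = prompt_tokens + completion_tokens


class ModelResponse:
    # `usage` is deliberately not declared here: some responses carry it and
    # some do not, which is what makes `response.usage = ...` an
    # attribute-access error under pyright/mypy.
    def __init__(self, response_id: str) -> None:
        self.id = response_id


def attach_usage(response: ModelResponse, source: Any) -> None:
    # Mirrors the diff's setattr()/getattr() change: runtime behaviour is the
    # same as direct assignment, but the checker no longer flags an attribute
    # the class never declared. The alternative used elsewhere in the diff is
    # a targeted `# type: ignore`.
    usage: Optional[Usage] = getattr(source, "usage", None)
    if usage is not None:
        setattr(response, "usage", usage)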