diff --git a/litellm/__init__.py b/litellm/__init__.py
index 1bff66e6c..f607998e1 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -1,24 +1,25 @@
 import threading
+from typing import Callable, List, Optional

-success_callback = []
-failure_callback = []
+success_callback: List[str] = []
+failure_callback: List[str] = []
 set_verbose = False
 telemetry = True
 max_tokens = 256  # OpenAI Defaults
 retry = True
-api_key = None
-openai_key = None
-azure_key = None
-anthropic_key = None
-replicate_key = None
-cohere_key = None
-openrouter_key = None
-huggingface_key = None
-vertex_project = None
-vertex_location = None
+api_key: Optional[str] = None
+openai_key: Optional[str] = None
+azure_key: Optional[str] = None
+anthropic_key: Optional[str] = None
+replicate_key: Optional[str] = None
+cohere_key: Optional[str] = None
+openrouter_key: Optional[str] = None
+huggingface_key: Optional[str] = None
+vertex_project: Optional[str] = None
+vertex_location: Optional[str] = None
+hugging_api_token: Optional[str] = None
+togetherai_api_key: Optional[str] = None
 caching = False
-hugging_api_token = None
-togetherai_api_key = None
 model_cost = {
     "gpt-3.5-turbo": {
         "max_tokens": 4000,
@@ -223,7 +224,7 @@ from .utils import (
     completion_cost,
     get_litellm_params,
 )
-from .main import * # Import all the symbols from main.py
+from .main import *  # type: ignore
 from .integrations import *
 from openai.error import (
     AuthenticationError,
diff --git a/litellm/exceptions.py b/litellm/exceptions.py
index c440d7eab..7b48a343d 100644
--- a/litellm/exceptions.py
+++ b/litellm/exceptions.py
@@ -8,7 +8,7 @@ from openai.error import (
 )


-class AuthenticationError(AuthenticationError): # type: ignore
+class AuthenticationError(AuthenticationError):  # type: ignore
     def __init__(self, message, llm_provider):
         self.status_code = 401
         self.message = message
@@ -18,7 +18,7 @@ class AuthenticationError(AuthenticationError): # type: ignore
         )  # Call the base class constructor with the parameters it needs


-class InvalidRequestError(InvalidRequestError): # type: ignore
+class InvalidRequestError(InvalidRequestError):  # type: ignore
     def __init__(self, message, model, llm_provider):
         self.status_code = 400
         self.message = message
@@ -29,7 +29,7 @@ class InvalidRequestError(InvalidRequestError): # type: ignore
         )  # Call the base class constructor with the parameters it needs


-class RateLimitError(RateLimitError): # type: ignore
+class RateLimitError(RateLimitError):  # type: ignore
     def __init__(self, message, llm_provider):
         self.status_code = 429
         self.message = message
@@ -39,7 +39,7 @@ class RateLimitError(RateLimitError): # type: ignore
         )  # Call the base class constructor with the parameters it needs


-class ServiceUnavailableError(ServiceUnavailableError): # type: ignore
+class ServiceUnavailableError(ServiceUnavailableError):  # type: ignore
     def __init__(self, message, llm_provider):
         self.status_code = 500
         self.message = message
@@ -49,7 +49,7 @@ class ServiceUnavailableError(ServiceUnavailableError): # type: ignore
         )  # Call the base class constructor with the parameters it needs


-class OpenAIError(OpenAIError): # type: ignore
+class OpenAIError(OpenAIError):  # type: ignore
     def __init__(self, original_exception):
         self.status_code = original_exception.http_status
         super().__init__(
diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py
index 2ea07215b..5ebbc640a 100644
--- a/litellm/llms/anthropic.py
+++ b/litellm/llms/anthropic.py
@@ -141,5 +141,7 @@ class AnthropicLLM:
         }
         return model_response

-    def embedding(self): # logic for parsing in - calling - parsing out model embedding calls
+    def embedding(
+        self,
+    ):  # logic for parsing in - calling - parsing out model embedding calls
         pass
diff --git a/litellm/llms/base.py b/litellm/llms/base.py
index bf6a3dd3a..2c05a00c1 100644
--- a/litellm/llms/base.py
+++ b/litellm/llms/base.py
@@ -5,8 +5,12 @@ class BaseLLM:
     def validate_environment(self):  # set up the environment required to run the model
         pass

-    def completion(self): # logic for parsing in - calling - parsing out model completion calls
+    def completion(
+        self,
+    ):  # logic for parsing in - calling - parsing out model completion calls
         pass

-    def embedding(self): # logic for parsing in - calling - parsing out model embedding calls
+    def embedding(
+        self,
+    ):  # logic for parsing in - calling - parsing out model embedding calls
         pass
diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py
index 23ac16bef..709c2347e 100644
--- a/litellm/llms/huggingface_restapi.py
+++ b/litellm/llms/huggingface_restapi.py
@@ -45,10 +45,11 @@ class HuggingfaceRestAPILLM:
         litellm_params=None,
         logger_fn=None,
     ):  # logic for parsing in - calling - parsing out model completion calls
+        completion_url: str = ""
         if custom_api_base:
-            completion_url: Optional[str] = custom_api_base
+            completion_url = custom_api_base
         elif "HF_API_BASE" in os.environ:
-            completion_url = os.getenv("HF_API_BASE")
+            completion_url = os.getenv("HF_API_BASE", "")
         else:
             completion_url = f"https://api-inference.huggingface.co/models/{model}"
         prompt = ""
@@ -137,5 +138,7 @@ class HuggingfaceRestAPILLM:
         return model_response
         pass

-    def embedding(self): # logic for parsing in - calling - parsing out model embedding calls
+    def embedding(
+        self,
+    ):  # logic for parsing in - calling - parsing out model embedding calls
         pass
diff --git a/litellm/main.py b/litellm/main.py
index 1fc89e2e7..cbf2bde4a 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -4,7 +4,7 @@ from functools import partial
 import dotenv, traceback, random, asyncio, time
 from copy import deepcopy
 import litellm
-from litellm import (
+from litellm import (  # type: ignore
     client,
     logging,
     exception_type,
@@ -55,7 +55,7 @@ async def acompletion(*args, **kwargs):

 @client
 # @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(2), reraise=True, retry_error_callback=lambda retry_state: setattr(retry_state.outcome, 'retry_variable', litellm.retry)) # retry call, turn this off by setting `litellm.retry = False`
-@timeout(
+@timeout(  # type: ignore
     600
 )  ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
 def completion(
@@ -266,7 +266,7 @@ def completion(
                 or litellm.replicate_key
             )
             # set replicate key
-            os.environ["REPLICATE_API_TOKEN"]: str = replicate_key
+            os.environ["REPLICATE_API_TOKEN"] = str(replicate_key)
             prompt = " ".join([message["content"] for message in messages])
             input = {"prompt": prompt}
             if "max_tokens" in optional_params:
@@ -807,7 +807,7 @@ def batch_completion(*args, **kwargs):

 ### EMBEDDING ENDPOINTS ####################
 @client
-@timeout(
+@timeout(  # type: ignore
     60
 )  ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
 def embedding(model, input=[], azure=False, force_timeout=60, logger_fn=None):
diff --git a/litellm/tests/test_api_key_param.py b/litellm/tests/test_api_key_param.py
index c444b3904..40f7a12b0 100644
--- a/litellm/tests/test_api_key_param.py
+++ b/litellm/tests/test_api_key_param.py
@@ -21,8 +21,8 @@ user_message = "Hello, how are you?"
 messages = [{"content": user_message, "role": "user"}]

 ## Test 1: Setting key dynamically
-temp_key = os.environ.get("ANTHROPIC_API_KEY")
-os.environ["ANTHROPIC_API_KEY"]: str = "bad-key"
+temp_key = os.environ.get("ANTHROPIC_API_KEY", "")
+os.environ["ANTHROPIC_API_KEY"] = "bad-key"
 # test on openai completion call
 try:
     response = completion(
@@ -39,7 +39,7 @@ os.environ["ANTHROPIC_API_KEY"] = temp_key


 ## Test 2: Setting key via __init__ params
-litellm.anthropic_key: str = os.environ.get("ANTHROPIC_API_KEY")
+litellm.anthropic_key = os.environ.get("ANTHROPIC_API_KEY", "")
 os.environ.pop("ANTHROPIC_API_KEY")
 # test on openai completion call
 try:
diff --git a/litellm/tests/test_bad_params.py b/litellm/tests/test_bad_params.py
index 71cbffe56..a85613a50 100644
--- a/litellm/tests/test_bad_params.py
+++ b/litellm/tests/test_bad_params.py
@@ -50,4 +50,4 @@ try:
 except:
     print(f"error occurred: {traceback.format_exc()}")
     pass
-os.environ["OPENAI_API_KEY"] = temp_key
+os.environ["OPENAI_API_KEY"] = str(temp_key)