all fixes to linting

ishaan-jaff 2023-08-18 11:56:44 -07:00
parent 5e7d22512d
commit 1bb2aefea1
8 changed files with 44 additions and 34 deletions

litellm/__init__.py

@@ -1,24 +1,25 @@
 import threading
+from typing import Callable, List, Optional
-success_callback = []
-failure_callback = []
+success_callback: List[str] = []
+failure_callback: List[str] = []
 set_verbose = False
 telemetry = True
 max_tokens = 256  # OpenAI Defaults
 retry = True
-api_key = None
-openai_key = None
-azure_key = None
-anthropic_key = None
-replicate_key = None
-cohere_key = None
-openrouter_key = None
-huggingface_key = None
-vertex_project = None
-vertex_location = None
+api_key: Optional[str] = None
+openai_key: Optional[str] = None
+azure_key: Optional[str] = None
+anthropic_key: Optional[str] = None
+replicate_key: Optional[str] = None
+cohere_key: Optional[str] = None
+openrouter_key: Optional[str] = None
+huggingface_key: Optional[str] = None
+vertex_project: Optional[str] = None
+vertex_location: Optional[str] = None
+hugging_api_token: Optional[str] = None
+togetherai_api_key: Optional[str] = None
 caching = False
-hugging_api_token = None
-togetherai_api_key = None
 model_cost = {
     "gpt-3.5-turbo": {
         "max_tokens": 4000,
@@ -223,7 +224,7 @@ from .utils import (
     completion_cost,
     get_litellm_params,
 )
-from .main import *  # Import all the symbols from main.py
+from .main import *  # type: ignore
 from .integrations import *
 from openai.error import (
     AuthenticationError,

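A minimal sketch of the mypy behavior these annotations address (illustrative, not from the commit): without an annotation, a module-level `x = None` is pinned to type None, so assigning a string to it from another scope fails type checking.

    from typing import Optional

    api_key = None  # without an annotation, mypy pins this global to type None

    def set_key() -> None:
        global api_key
        api_key = "sk-example"  # mypy error: str is not assignable to None

    anthropic_key: Optional[str] = None  # the annotated form accepts str or None

    def set_anthropic_key() -> None:
        global anthropic_key
        anthropic_key = "sk-example"  # accepted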
litellm/exceptions.py

@@ -8,7 +8,7 @@ from openai.error import (
 )
 
 
-class AuthenticationError(AuthenticationError): # type: ignore
+class AuthenticationError(AuthenticationError):  # type: ignore
     def __init__(self, message, llm_provider):
         self.status_code = 401
         self.message = message
@@ -18,7 +18,7 @@ class AuthenticationError(AuthenticationError): # type: ignore
         )  # Call the base class constructor with the parameters it needs
 
 
-class InvalidRequestError(InvalidRequestError): # type: ignore
+class InvalidRequestError(InvalidRequestError):  # type: ignore
     def __init__(self, message, model, llm_provider):
         self.status_code = 400
         self.message = message
@@ -29,7 +29,7 @@ class InvalidRequestError(InvalidRequestError): # type: ignore
         )  # Call the base class constructor with the parameters it needs
 
 
-class RateLimitError(RateLimitError): # type: ignore
+class RateLimitError(RateLimitError):  # type: ignore
     def __init__(self, message, llm_provider):
         self.status_code = 429
         self.message = message
@@ -39,7 +39,7 @@ class RateLimitError(RateLimitError): # type: ignore
         )  # Call the base class constructor with the parameters it needs
 
 
-class ServiceUnavailableError(ServiceUnavailableError): # type: ignore
+class ServiceUnavailableError(ServiceUnavailableError):  # type: ignore
     def __init__(self, message, llm_provider):
         self.status_code = 500
         self.message = message
@@ -49,7 +49,7 @@ class ServiceUnavailableError(ServiceUnavailableError): # type: ignore
         )  # Call the base class constructor with the parameters it needs
 
 
-class OpenAIError(OpenAIError): # type: ignore
+class OpenAIError(OpenAIError):  # type: ignore
     def __init__(self, original_exception):
         self.status_code = original_exception.http_status
         super().__init__(

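The `# type: ignore` comments are needed because each class deliberately shadows the same-named class imported from openai.error, which mypy reports as a name redefinition. A minimal sketch of the pattern, with a local stub standing in for the openai import:

    class AuthenticationError(Exception):  # stand-in for openai.error.AuthenticationError
        pass

    class AuthenticationError(AuthenticationError):  # type: ignore  # mypy: name already defined
        def __init__(self, message, llm_provider):
            self.status_code = 401
            self.message = message
            self.llm_provider = llm_provider
            super().__init__(self.message)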
litellm/llms/anthropic.py

@@ -141,5 +141,7 @@ class AnthropicLLM:
         }
         return model_response
 
-    def embedding(self):  # logic for parsing in - calling - parsing out model embedding calls
+    def embedding(
+        self,
+    ):  # logic for parsing in - calling - parsing out model embedding calls
         pass

litellm/llms/base.py

@@ -5,8 +5,12 @@ class BaseLLM:
     def validate_environment(self):  # set up the environment required to run the model
         pass
 
-    def completion(self):  # logic for parsing in - calling - parsing out model completion calls
+    def completion(
+        self,
+    ):  # logic for parsing in - calling - parsing out model completion calls
         pass
 
-    def embedding(self):  # logic for parsing in - calling - parsing out model embedding calls
+    def embedding(
+        self,
+    ):  # logic for parsing in - calling - parsing out model embedding calls
         pass

litellm/llms/huggingface_restapi.py

@@ -45,10 +45,11 @@ class HuggingfaceRestAPILLM:
         litellm_params=None,
         logger_fn=None,
     ):  # logic for parsing in - calling - parsing out model completion calls
+        completion_url: str = ""
         if custom_api_base:
-            completion_url: Optional[str] = custom_api_base
+            completion_url = custom_api_base
         elif "HF_API_BASE" in os.environ:
-            completion_url = os.getenv("HF_API_BASE")
+            completion_url = os.getenv("HF_API_BASE", "")
         else:
             completion_url = f"https://api-inference.huggingface.co/models/{model}"
         prompt = ""
@@ -137,5 +138,7 @@ class HuggingfaceRestAPILLM:
         return model_response
         pass
 
-    def embedding(self):  # logic for parsing in - calling - parsing out model embedding calls
+    def embedding(
+        self,
+    ):  # logic for parsing in - calling - parsing out model embedding calls
         pass

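The fix declares completion_url once with a concrete type before branching, instead of annotating it inside one branch (mypy allows only a single annotated declaration per variable), and gives os.getenv a default so the result is str rather than Optional[str]. A minimal standalone sketch of that pattern (resolve_url is a hypothetical helper, not from the commit):

    import os
    from typing import Optional

    def resolve_url(custom_api_base: Optional[str], model: str) -> str:
        completion_url: str = ""  # single annotated declaration, visible to all branches
        if custom_api_base:
            completion_url = custom_api_base
        elif "HF_API_BASE" in os.environ:
            completion_url = os.getenv("HF_API_BASE", "")  # default keeps the type str
        else:
            completion_url = f"https://api-inference.huggingface.co/models/{model}"
        return completion_url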
litellm/main.py

@@ -4,7 +4,7 @@ from functools import partial
 import dotenv, traceback, random, asyncio, time
 from copy import deepcopy
 import litellm
-from litellm import (
+from litellm import (  # type: ignore
     client,
     logging,
     exception_type,
@@ -55,7 +55,7 @@ async def acompletion(*args, **kwargs):
 @client
 # @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(2), reraise=True, retry_error_callback=lambda retry_state: setattr(retry_state.outcome, 'retry_variable', litellm.retry)) # retry call, turn this off by setting `litellm.retry = False`
-@timeout(
+@timeout(  # type: ignore
     600
 )  ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
 def completion(
@@ -266,7 +266,7 @@ def completion(
             or litellm.replicate_key
         )
         # set replicate key
-        os.environ["REPLICATE_API_TOKEN"]: str = replicate_key
+        os.environ["REPLICATE_API_TOKEN"] = str(replicate_key)
         prompt = " ".join([message["content"] for message in messages])
         input = {"prompt": prompt}
         if "max_tokens" in optional_params:
@@ -807,7 +807,7 @@ def batch_completion(*args, **kwargs):
 ### EMBEDDING ENDPOINTS ####################
 @client
-@timeout(
+@timeout(  # type: ignore
     60
 )  ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
 def embedding(model, input=[], azure=False, force_timeout=60, logger_fn=None):

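Two separate lint errors are fixed in the REPLICATE_API_TOKEN line: an annotation on a subscript target (os.environ["..."]: str = ...) is flagged by mypy, and os.environ values must be str while the key variable may be Optional[str]. A minimal sketch (illustrative, not from the commit):

    import os

    replicate_key = os.environ.get("REPLICATE_API_TOKEN")  # Optional[str]

    # os.environ["REPLICATE_API_TOKEN"]: str = replicate_key   # flagged by the type checker
    os.environ["REPLICATE_API_TOKEN"] = str(replicate_key)  # plain assignment, always a str
    # caveat: str(None) becomes the literal "None", so the key should be validated upstream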
(test file: dynamic API key handling; filename not shown)

@@ -21,8 +21,8 @@ user_message = "Hello, how are you?"
 messages = [{"content": user_message, "role": "user"}]
 
 ## Test 1: Setting key dynamically
-temp_key = os.environ.get("ANTHROPIC_API_KEY")
-os.environ["ANTHROPIC_API_KEY"]: str = "bad-key"
+temp_key = os.environ.get("ANTHROPIC_API_KEY", "")
+os.environ["ANTHROPIC_API_KEY"] = "bad-key"
 # test on openai completion call
 try:
     response = completion(
@@ -39,7 +39,7 @@ os.environ["ANTHROPIC_API_KEY"] = temp_key
 
 ## Test 2: Setting key via __init__ params
-litellm.anthropic_key: str = os.environ.get("ANTHROPIC_API_KEY")
+litellm.anthropic_key = os.environ.get("ANTHROPIC_API_KEY", "")
 os.environ.pop("ANTHROPIC_API_KEY")
 # test on openai completion call
 try:

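The test changes follow the same typing rule: os.environ.get(key) returns Optional[str], so restoring the saved value into os.environ later would not type-check; passing a "" default keeps everything str. A minimal sketch of the save/restore pattern (EXAMPLE_API_KEY is a hypothetical key name):

    import os

    temp_key = os.environ.get("EXAMPLE_API_KEY", "")  # str, never None
    os.environ["EXAMPLE_API_KEY"] = "bad-key"  # exercise the failure path
    try:
        pass  # ... call under test ...
    finally:
        os.environ["EXAMPLE_API_KEY"] = temp_key  # restoring requires a str value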
(test file: OPENAI_API_KEY save/restore; filename not shown)

@@ -50,4 +50,4 @@ try:
 except:
     print(f"error occurred: {traceback.format_exc()}")
     pass
-os.environ["OPENAI_API_KEY"] = temp_key
+os.environ["OPENAI_API_KEY"] = str(temp_key)