all fixes to linting

ishaan-jaff 2023-08-18 11:56:44 -07:00
parent 5e7d22512d
commit 1bb2aefea1
8 changed files with 44 additions and 34 deletions

View file

@@ -1,24 +1,25 @@
 import threading
+from typing import Callable, List, Optional

-success_callback = []
-failure_callback = []
+success_callback: List[str] = []
+failure_callback: List[str] = []
 set_verbose = False
 telemetry = True
 max_tokens = 256 # OpenAI Defaults
 retry = True
-api_key = None
-openai_key = None
-azure_key = None
-anthropic_key = None
-replicate_key = None
-cohere_key = None
-openrouter_key = None
-huggingface_key = None
-vertex_project = None
-vertex_location = None
+api_key: Optional[str] = None
+openai_key: Optional[str] = None
+azure_key: Optional[str] = None
+anthropic_key: Optional[str] = None
+replicate_key: Optional[str] = None
+cohere_key: Optional[str] = None
+openrouter_key: Optional[str] = None
+huggingface_key: Optional[str] = None
+vertex_project: Optional[str] = None
+vertex_location: Optional[str] = None
+hugging_api_token: Optional[str] = None
+togetherai_api_key: Optional[str] = None
 caching = False
-hugging_api_token = None
-togetherai_api_key = None
 model_cost = {
     "gpt-3.5-turbo": {
         "max_tokens": 4000,
@@ -223,7 +224,7 @@ from .utils import (
     completion_cost,
     get_litellm_params,
 )
-from .main import * # Import all the symbols from main.py
+from .main import * # type: ignore
 from .integrations import *
 from openai.error import (
     AuthenticationError,
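The annotation changes above are what a mypy pass typically demands of module-level state: an empty list needs an element type, and attributes initialized to `None` need `Optional[...]` so later assignments type-check. A minimal standalone sketch (names borrowed from the diff for illustration) of the errors being silenced:

```python
from typing import List, Optional

# Unannotated `success_callback = []` makes mypy report roughly:
#   Need type annotation for "success_callback"
#   (hint: "success_callback: List[<type>] = ...")
success_callback: List[str] = []

# Unannotated `api_key = None` is inferred as just None, so a later
# `litellm.api_key = "sk-..."` assignment would be an incompatible-type error.
api_key: Optional[str] = None

api_key = "sk-example"              # OK: str fits Optional[str]
success_callback.append("posthog")  # OK: element type is known to be str
```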

View file

@@ -8,7 +8,7 @@ from openai.error import (
 )


-class AuthenticationError(AuthenticationError): # type: ignore
+class AuthenticationError(AuthenticationError):  # type: ignore
     def __init__(self, message, llm_provider):
         self.status_code = 401
         self.message = message
@@ -18,7 +18,7 @@ class AuthenticationError(AuthenticationError):  # type: ignore
         ) # Call the base class constructor with the parameters it needs


-class InvalidRequestError(InvalidRequestError): # type: ignore
+class InvalidRequestError(InvalidRequestError):  # type: ignore
     def __init__(self, message, model, llm_provider):
         self.status_code = 400
         self.message = message
@@ -29,7 +29,7 @@ class InvalidRequestError(InvalidRequestError):  # type: ignore
         ) # Call the base class constructor with the parameters it needs


-class RateLimitError(RateLimitError): # type: ignore
+class RateLimitError(RateLimitError):  # type: ignore
     def __init__(self, message, llm_provider):
         self.status_code = 429
         self.message = message
@@ -39,7 +39,7 @@ class RateLimitError(RateLimitError):  # type: ignore
         ) # Call the base class constructor with the parameters it needs


-class ServiceUnavailableError(ServiceUnavailableError): # type: ignore
+class ServiceUnavailableError(ServiceUnavailableError):  # type: ignore
     def __init__(self, message, llm_provider):
         self.status_code = 500
         self.message = message
@@ -49,7 +49,7 @@ class ServiceUnavailableError(ServiceUnavailableError):  # type: ignore
         ) # Call the base class constructor with the parameters it needs


-class OpenAIError(OpenAIError): # type: ignore
+class OpenAIError(OpenAIError):  # type: ignore
     def __init__(self, original_exception):
         self.status_code = original_exception.http_status
         super().__init__(
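Each hunk here touches a class line that deliberately shadows the `openai.error` exception it subclasses; the shadowing is what forces the `# type: ignore`, since mypy rejects redefining an imported name. A hypothetical minimal reproduction, assuming the pre-1.0 `openai` package that still ships `openai.error`:

```python
from openai.error import AuthenticationError

# Without the trailing comment, mypy reports something like:
#   Name "AuthenticationError" already defined (possibly by an import)  [no-redef]
class AuthenticationError(AuthenticationError):  # type: ignore
    def __init__(self, message, llm_provider):
        self.status_code = 401
        self.message = message
        self.llm_provider = llm_provider
        super().__init__(self.message)
```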

View file

@@ -141,5 +141,7 @@ class AnthropicLLM:
         }
         return model_response

-    def embedding(self): # logic for parsing in - calling - parsing out model embedding calls
+    def embedding(
+        self,
+    ): # logic for parsing in - calling - parsing out model embedding calls
         pass

View file

@@ -5,8 +5,12 @@ class BaseLLM:
     def validate_environment(self): # set up the environment required to run the model
         pass

-    def completion(self): # logic for parsing in - calling - parsing out model completion calls
+    def completion(
+        self,
+    ): # logic for parsing in - calling - parsing out model completion calls
         pass

-    def embedding(self): # logic for parsing in - calling - parsing out model embedding calls
+    def embedding(
+        self,
+    ): # logic for parsing in - calling - parsing out model embedding calls
         pass
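This file, like the matching hunks in the Anthropic and Hugging Face wrappers, is pure formatter output: a one-line `def` with a long trailing comment exceeds the line-length limit (79 for stock flake8, 88 for Black), so the signature is exploded across lines with a trailing comma. A sketch of the before/after:

```python
class BaseLLM:
    # Before (flagged as E501, line too long):
    # def embedding(self): # logic for parsing in - calling - parsing out model embedding calls

    # After: every physical line stays under the limit, and the comment
    # hangs off the closing parenthesis instead of the whole definition.
    def embedding(
        self,
    ):  # logic for parsing in - calling - parsing out model embedding calls
        pass
```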

View file

@@ -45,10 +45,11 @@ class HuggingfaceRestAPILLM:
         litellm_params=None,
         logger_fn=None,
     ): # logic for parsing in - calling - parsing out model completion calls
+        completion_url: str = ""
         if custom_api_base:
-            completion_url: Optional[str] = custom_api_base
+            completion_url = custom_api_base
         elif "HF_API_BASE" in os.environ:
-            completion_url = os.getenv("HF_API_BASE")
+            completion_url = os.getenv("HF_API_BASE", "")
         else:
             completion_url = f"https://api-inference.huggingface.co/models/{model}"
         prompt = ""
@@ -137,5 +138,7 @@ class HuggingfaceRestAPILLM:
         return model_response
         pass

-    def embedding(self): # logic for parsing in - calling - parsing out model embedding calls
+    def embedding(
+        self,
+    ): # logic for parsing in - calling - parsing out model embedding calls
         pass
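The `completion_url` rewrite is the standard way to keep a branch-assigned variable a plain `str` for mypy: declare it once up front, and give `os.getenv` a default so its return type narrows from `Optional[str]` to `str`. A standalone sketch:

```python
import os

# os.getenv("X") is typed Optional[str]; supplying a default rules out None.
maybe_base = os.getenv("HF_API_BASE")     # Optional[str]
base: str = os.getenv("HF_API_BASE", "")  # str

# Declaring completion_url once with a str annotation replaces the old
# per-branch Optional[str] annotation, so every branch assigns the same type.
completion_url: str = ""
if base:
    completion_url = base
else:
    completion_url = "https://api-inference.huggingface.co/models/gpt2"  # hypothetical model
```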

View file

@@ -4,7 +4,7 @@ from functools import partial
 import dotenv, traceback, random, asyncio, time
 from copy import deepcopy
 import litellm
-from litellm import (
+from litellm import ( # type: ignore
     client,
     logging,
     exception_type,
@@ -55,7 +55,7 @@ async def acompletion(*args, **kwargs):

 @client
 # @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(2), reraise=True, retry_error_callback=lambda retry_state: setattr(retry_state.outcome, 'retry_variable', litellm.retry)) # retry call, turn this off by setting `litellm.retry = False`
-@timeout(
+@timeout( # type: ignore
     600
 ) ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
 def completion(
@@ -266,7 +266,7 @@ def completion(
             or litellm.replicate_key
         )
         # set replicate key
-        os.environ["REPLICATE_API_TOKEN"]: str = replicate_key
+        os.environ["REPLICATE_API_TOKEN"] = str(replicate_key)
         prompt = " ".join([message["content"] for message in messages])
         input = {"prompt": prompt}
         if "max_tokens" in optional_params:
@@ -807,7 +807,7 @@ def batch_completion(*args, **kwargs):

 ### EMBEDDING ENDPOINTS ####################
 @client
-@timeout(
+@timeout( # type: ignore
     60
 ) ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
 def embedding(model, input=[], azure=False, force_timeout=60, logger_fn=None):
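The `REPLICATE_API_TOKEN` line fixes two complaints at once: type annotations belong on simple names rather than subscript targets, and `os.environ` only accepts `str` values. A minimal sketch (the token value is a hypothetical placeholder):

```python
import os
from typing import Optional

replicate_key: Optional[str] = "r8_example_token"  # hypothetical value

# `os.environ["X"]: str = value` annotates a subscript assignment, which
# mypy flags rather than tracks; and since os.environ values must be str,
# a plain assignment with an explicit str() conversion satisfies both.
os.environ["REPLICATE_API_TOKEN"] = str(replicate_key)
```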

View file

@@ -21,8 +21,8 @@ user_message = "Hello, how are you?"
 messages = [{"content": user_message, "role": "user"}]

 ## Test 1: Setting key dynamically
-temp_key = os.environ.get("ANTHROPIC_API_KEY")
-os.environ["ANTHROPIC_API_KEY"]: str = "bad-key"
+temp_key = os.environ.get("ANTHROPIC_API_KEY", "")
+os.environ["ANTHROPIC_API_KEY"] = "bad-key"
 # test on openai completion call
 try:
     response = completion(
@@ -39,7 +39,7 @@ os.environ["ANTHROPIC_API_KEY"] = temp_key

 ## Test 2: Setting key via __init__ params
-litellm.anthropic_key: str = os.environ.get("ANTHROPIC_API_KEY")
+litellm.anthropic_key = os.environ.get("ANTHROPIC_API_KEY", "")
 os.environ.pop("ANTHROPIC_API_KEY")
 # test on openai completion call
 try:
     response = completion(

View file

@@ -50,4 +50,4 @@ try:
 except:
     print(f"error occurred: {traceback.format_exc()}")
     pass
-os.environ["OPENAI_API_KEY"] = temp_key
+os.environ["OPENAI_API_KEY"] = str(temp_key)