diff --git a/litellm/__init__.py b/litellm/__init__.py
index 7d88e9b11..c5c2cd68d 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -23,6 +23,7 @@ vertex_location: Optional[str] = None
 togetherai_api_key: Optional[str] = None
 caching = False
 caching_with_models = False # if you want the caching key to be model + prompt
+model_alias_map = {}
 debugger = False
 model_cost = {
     "babbage-002": {
diff --git a/litellm/__pycache__/__init__.cpython-311.pyc b/litellm/__pycache__/__init__.cpython-311.pyc
index 2d7525526..3b65dc585 100644
Binary files a/litellm/__pycache__/__init__.cpython-311.pyc and b/litellm/__pycache__/__init__.cpython-311.pyc differ
diff --git a/litellm/__pycache__/main.cpython-311.pyc b/litellm/__pycache__/main.cpython-311.pyc
index 932f8294d..5c7773959 100644
Binary files a/litellm/__pycache__/main.cpython-311.pyc and b/litellm/__pycache__/main.cpython-311.pyc differ
diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc
index 643e64add..b695df71a 100644
Binary files a/litellm/__pycache__/utils.cpython-311.pyc and b/litellm/__pycache__/utils.cpython-311.pyc differ
diff --git a/litellm/main.py b/litellm/main.py
index a17dfc8a5..5062ace36 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -58,7 +58,7 @@ async def acompletion(*args, **kwargs):
 # @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(2), reraise=True, retry_error_callback=lambda retry_state: setattr(retry_state.outcome, 'retry_variable', litellm.retry)) # retry call, turn this off by setting `litellm.retry = False`
 @timeout( # type: ignore
     600
-) ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
+) ## set timeouts, in case calls hang (e.g. Azure) - default is 600s, override with `force_timeout`
 def completion(
     model,
     messages, # required params
@@ -97,6 +97,8 @@ def completion(
     try:
         if fallbacks != []:
             return completion_with_fallbacks(**args)
+        if litellm.model_alias_map and model in litellm.model_alias_map:
+            model = litellm.model_alias_map[model] # update the model to the actual value if an alias has been passed in
         model_response = ModelResponse()
         if azure: # this flag is deprecated, remove once notebooks are also updated.
             custom_llm_provider = "azure"
@@ -686,7 +688,7 @@ def completion(
             prompt = " ".join([message["content"] for message in messages])
             ## LOGGING
-            logging.pre_call(input=prompt, api_key=base_ten_key)
+            logging.pre_call(input=prompt, api_key=base_ten_key, model=model)
             base_ten__model = baseten.deployed_model_version_id(model)
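The new `litellm.model_alias_map` lets callers register shorthand model names once and use them in every `completion` call; `completion` resolves the alias to the real model name before any provider routing or logging runs. A minimal usage sketch follows — the `wizard-lm` alias mirrors the new test below, and the Hugging Face target string is a hypothetical example, not a shipped default:

```python
import litellm
from litellm import completion

# Register alias -> actual model mappings once, process-wide.
# Both names here are illustrative; litellm ships no default aliases.
litellm.model_alias_map = {
    "wizard-lm": "huggingface/WizardLM/WizardLM-70B-V1.0",
}

# completion() swaps "wizard-lm" for the mapped value before dispatch,
# so the provider call and all logging see the resolved model name.
response = completion(
    model="wizard-lm",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
```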
diff --git a/litellm/tests/test_model_alias_map.py b/litellm/tests/test_model_alias_map.py
new file mode 100644
index 000000000..368f02b1a
--- /dev/null
+++ b/litellm/tests/test_model_alias_map.py
@@ -0,0 +1,16 @@
+#### What this tests ####
+# This tests the model alias mapping - if a user passes in an alias that has been set, it is mapped to the actual model name
+
+import sys, os
+import traceback
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm
+from litellm import embedding, completion
+
+litellm.set_verbose = True
+
+# Test: Check if the alias created via LiteDebugger is mapped correctly
+print(completion("wizard-lm", messages=[{"role": "user", "content": "Hey, how's it going?"}]))
\ No newline at end of file
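As written, the test depends on a `wizard-lm` alias having been registered remotely (created via LiteDebugger and pulled down by `get_all_keys`), and it only prints the response. A self-contained variant would set the map locally before calling `completion` — a sketch, with a hypothetical target model string:

```python
import sys, os

sys.path.insert(0, os.path.abspath("../.."))  # run against the local checkout

import litellm
from litellm import completion

litellm.set_verbose = True

# Register the alias locally so the test does not depend on hosted state;
# the target model string is a placeholder for illustration.
litellm.model_alias_map = {"wizard-lm": "huggingface/WizardLM/WizardLM-70B-V1.0"}

print(completion("wizard-lm", messages=[{"role": "user", "content": "Hey, how's it going?"}]))
```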
diff --git a/litellm/utils.py b/litellm/utils.py
index d7c48aa3b..8819abd31 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -161,14 +161,18 @@ class Logging:
             "litellm_params": self.litellm_params,
         }
 
-    def pre_call(self, input, api_key, additional_args={}):
+    def pre_call(self, input, api_key, model=None, additional_args={}):
         try:
             print_verbose(f"logging pre call for model: {self.model}")
             self.model_call_details["input"] = input
             self.model_call_details["api_key"] = api_key
             self.model_call_details["additional_args"] = additional_args
+            if model: # if the model name was changed pre-call, overwrite the initial model call name with the new one
+                self.model_call_details["model"] = model
+
             # User Logging -> if you pass in a custom logging function
+            print_verbose(f"model call details: {self.model_call_details}")
             print_verbose(
                 f"Logging Details: logger_fn - {self.logger_fn} | callable(logger_fn) - {callable(self.logger_fn)}"
             )
@@ -187,8 +191,8 @@
             try:
                 if callback == "supabase":
                     print_verbose("reaches supabase for logging!")
-                    model = self.model
-                    messages = self.messages
+                    model = self.model_call_details["model"]
+                    messages = self.model_call_details["input"]
                     print(f"supabaseClient: {supabaseClient}")
                     supabaseClient.input_log_event(
                         model=model,
@@ -201,8 +205,8 @@
                 elif callback == "lite_debugger":
                     print_verbose("reaches litedebugger for logging!")
-                    model = self.model
-                    messages = self.messages
+                    model = self.model_call_details["model"]
+                    messages = self.model_call_details["input"]
                     print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
                     liteDebuggerClient.input_log_event(
                         model=model,
@@ -1119,6 +1123,7 @@ def get_all_keys(llm_provider=None):
     try:
         global last_fetched_at
         # if user is using hosted product -> instantiate their env with their hosted api keys - refresh every 5 minutes
+        print_verbose(f"Reaches get all keys, llm_provider: {llm_provider}")
         user_email = os.getenv("LITELLM_EMAIL") or litellm.email or litellm.token or os.getenv("LITELLM_TOKEN")
         if user_email:
             time_delta = 0
@@ -1135,9 +1140,11 @@
                 # update model list
                 for key, value in data["model_keys"].items(): # follows the LITELLM API KEY format - <PROVIDER>_API_KEY - e.g. HUGGINGFACE_API_KEY
                     os.environ[key] = value
+                # set model alias map
+                for model_alias, value in data["model_alias_map"].items():
+                    litellm.model_alias_map[model_alias] = value
                 return "it worked!"
             return None
-        # return None by default
         return None
     except:
         print_verbose(f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}")
@@ -1149,12 +1156,6 @@ def get_model_list():
     # if user is using hosted product -> get their updated model list
     user_email = os.getenv("LITELLM_EMAIL") or litellm.email or litellm.token or os.getenv("LITELLM_TOKEN")
     if user_email:
-        # Commented out the section checking time delta
-        # time_delta = 0
-        # if last_fetched_at != None:
-        #     current_time = time.time()
-        #     time_delta = current_time - last_fetched_at
-        # if time_delta > 300 or last_fetched_at == None:
         # make the api call
         last_fetched_at = time.time()
         print(f"last_fetched_at: {last_fetched_at}")
diff --git a/pyproject.toml b/pyproject.toml
index 5bcbd942c..ea08499fc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.475"
+version = "0.1.476"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
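For reference, the `get_all_keys` change assumes the hosted key-sync response carries a `model_alias_map` object alongside `model_keys`; because the new loop indexes `data["model_alias_map"]` directly, a response missing that field falls into the function's non-blocking `except` branch. A sketch of the assumed payload shape (all values illustrative, not real credentials):

```python
# Assumed shape of the hosted response consumed by get_all_keys().
data = {
    "model_keys": {
        "HUGGINGFACE_API_KEY": "hf_xxx",  # each entry is exported into os.environ
    },
    "model_alias_map": {
        "wizard-lm": "huggingface/WizardLM/WizardLM-70B-V1.0",  # merged into litellm.model_alias_map
    },
}
```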