mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
accept model alias
This commit is contained in:
parent
fd59ff12d5
commit
67d6527217
8 changed files with 35 additions and 15 deletions
|
@ -161,14 +161,18 @@ class Logging:
|
|||
"litellm_params": self.litellm_params,
|
||||
}
|
||||
|
||||
def pre_call(self, input, api_key, additional_args={}):
|
||||
def pre_call(self, input, api_key, model=None, additional_args={}):
|
||||
try:
|
||||
print_verbose(f"logging pre call for model: {self.model}")
|
||||
self.model_call_details["input"] = input
|
||||
self.model_call_details["api_key"] = api_key
|
||||
self.model_call_details["additional_args"] = additional_args
|
||||
|
||||
if model: # if model name was changes pre-call, overwrite the initial model call name with the new one
|
||||
self.model_call_details["model"] = model
|
||||
|
||||
# User Logging -> if you pass in a custom logging function
|
||||
print_verbose(f"model call details: {self.model_call_details}")
|
||||
print_verbose(
|
||||
f"Logging Details: logger_fn - {self.logger_fn} | callable(logger_fn) - {callable(self.logger_fn)}"
|
||||
)
|
||||
|
@ -187,8 +191,8 @@ class Logging:
|
|||
try:
|
||||
if callback == "supabase":
|
||||
print_verbose("reaches supabase for logging!")
|
||||
model = self.model
|
||||
messages = self.messages
|
||||
model = self.model_call_details["model"]
|
||||
messages = self.model_call_details["input"]
|
||||
print(f"supabaseClient: {supabaseClient}")
|
||||
supabaseClient.input_log_event(
|
||||
model=model,
|
||||
|
@ -201,8 +205,8 @@ class Logging:
|
|||
|
||||
elif callback == "lite_debugger":
|
||||
print_verbose("reaches litedebugger for logging!")
|
||||
model = self.model
|
||||
messages = self.messages
|
||||
model = self.model_call_details["model"]
|
||||
messages = self.model_call_details["input"]
|
||||
print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
|
||||
liteDebuggerClient.input_log_event(
|
||||
model=model,
|
||||
|
@ -1119,6 +1123,7 @@ def get_all_keys(llm_provider=None):
|
|||
try:
|
||||
global last_fetched_at
|
||||
# if user is using hosted product -> instantiate their env with their hosted api keys - refresh every 5 minutes
|
||||
print_verbose(f"Reaches get all keys, llm_provider: {llm_provider}")
|
||||
user_email = os.getenv("LITELLM_EMAIL") or litellm.email or litellm.token or os.getenv("LITELLM_TOKEN")
|
||||
if user_email:
|
||||
time_delta = 0
|
||||
|
@ -1135,9 +1140,11 @@ def get_all_keys(llm_provider=None):
|
|||
# update model list
|
||||
for key, value in data["model_keys"].items(): # follows the LITELLM API KEY format - <UPPERCASE_PROVIDER_NAME>_API_KEY - e.g. HUGGINGFACE_API_KEY
|
||||
os.environ[key] = value
|
||||
# set model alias map
|
||||
for model_alias, value in data["model_alias_map"].items():
|
||||
litellm.model_alias_map[model_alias] = value
|
||||
return "it worked!"
|
||||
return None
|
||||
# return None by default
|
||||
return None
|
||||
except:
|
||||
print_verbose(f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}")
|
||||
|
@ -1149,12 +1156,6 @@ def get_model_list():
|
|||
# if user is using hosted product -> get their updated model list
|
||||
user_email = os.getenv("LITELLM_EMAIL") or litellm.email or litellm.token or os.getenv("LITELLM_TOKEN")
|
||||
if user_email:
|
||||
# Commented out the section checking time delta
|
||||
# time_delta = 0
|
||||
# if last_fetched_at != None:
|
||||
# current_time = time.time()
|
||||
# time_delta = current_time - last_fetched_at
|
||||
# if time_delta > 300 or last_fetched_at == None:
|
||||
# make the api call
|
||||
last_fetched_at = time.time()
|
||||
print(f"last_fetched_at: {last_fetched_at}")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue