forked from phoenix/litellm-mirror
accept model alias
This commit is contained in:
parent
fd59ff12d5
commit
67d6527217
8 changed files with 35 additions and 15 deletions
|
@ -23,6 +23,7 @@ vertex_location: Optional[str] = None
|
||||||
togetherai_api_key: Optional[str] = None
|
togetherai_api_key: Optional[str] = None
|
||||||
caching = False
|
caching = False
|
||||||
caching_with_models = False # if you want the caching key to be model + prompt
|
caching_with_models = False # if you want the caching key to be model + prompt
|
||||||
|
model_alias_map = {}
|
||||||
debugger = False
|
debugger = False
|
||||||
model_cost = {
|
model_cost = {
|
||||||
"babbage-002": {
|
"babbage-002": {
|
||||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -58,7 +58,7 @@ async def acompletion(*args, **kwargs):
|
||||||
# @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(2), reraise=True, retry_error_callback=lambda retry_state: setattr(retry_state.outcome, 'retry_variable', litellm.retry)) # retry call, turn this off by setting `litellm.retry = False`
|
# @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(2), reraise=True, retry_error_callback=lambda retry_state: setattr(retry_state.outcome, 'retry_variable', litellm.retry)) # retry call, turn this off by setting `litellm.retry = False`
|
||||||
@timeout( # type: ignore
|
@timeout( # type: ignore
|
||||||
600
|
600
|
||||||
) ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
|
) ## set timeouts, in case calls hang (e.g. Azure) - default is 600s, override with `force_timeout`
|
||||||
def completion(
|
def completion(
|
||||||
model,
|
model,
|
||||||
messages, # required params
|
messages, # required params
|
||||||
|
@ -97,6 +97,8 @@ def completion(
|
||||||
try:
|
try:
|
||||||
if fallbacks != []:
|
if fallbacks != []:
|
||||||
return completion_with_fallbacks(**args)
|
return completion_with_fallbacks(**args)
|
||||||
|
if litellm.model_alias_map and model in litellm.model_alias_map:
|
||||||
|
model = litellm.model_alias_map[model] # update the model to the actual value if an alias has been passed in
|
||||||
model_response = ModelResponse()
|
model_response = ModelResponse()
|
||||||
if azure: # this flag is deprecated, remove once notebooks are also updated.
|
if azure: # this flag is deprecated, remove once notebooks are also updated.
|
||||||
custom_llm_provider = "azure"
|
custom_llm_provider = "azure"
|
||||||
|
@ -686,7 +688,7 @@ def completion(
|
||||||
|
|
||||||
prompt = " ".join([message["content"] for message in messages])
|
prompt = " ".join([message["content"] for message in messages])
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging.pre_call(input=prompt, api_key=base_ten_key)
|
logging.pre_call(input=prompt, api_key=base_ten_key, model=model)
|
||||||
|
|
||||||
base_ten__model = baseten.deployed_model_version_id(model)
|
base_ten__model = baseten.deployed_model_version_id(model)
|
||||||
|
|
||||||
|
|
16
litellm/tests/test_model_alias_map.py
Normal file
16
litellm/tests/test_model_alias_map.py
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
#### What this tests ####
|
||||||
|
# This tests the model alias mapping - if user passes in an alias, and has set an alias, set it to the actual value
|
||||||
|
|
||||||
|
import sys, os
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
import litellm
|
||||||
|
from litellm import embedding, completion
|
||||||
|
|
||||||
|
litellm.set_verbose = True
|
||||||
|
|
||||||
|
# Test: Check if the alias created via LiteDebugger is mapped correctly
|
||||||
|
print(completion("wizard-lm", messages=[{"role": "user", "content": "Hey, how's it going?"}]))
|
|
@ -161,14 +161,18 @@ class Logging:
|
||||||
"litellm_params": self.litellm_params,
|
"litellm_params": self.litellm_params,
|
||||||
}
|
}
|
||||||
|
|
||||||
def pre_call(self, input, api_key, additional_args={}):
|
def pre_call(self, input, api_key, model=None, additional_args={}):
|
||||||
try:
|
try:
|
||||||
print_verbose(f"logging pre call for model: {self.model}")
|
print_verbose(f"logging pre call for model: {self.model}")
|
||||||
self.model_call_details["input"] = input
|
self.model_call_details["input"] = input
|
||||||
self.model_call_details["api_key"] = api_key
|
self.model_call_details["api_key"] = api_key
|
||||||
self.model_call_details["additional_args"] = additional_args
|
self.model_call_details["additional_args"] = additional_args
|
||||||
|
|
||||||
|
if model: # if model name was changes pre-call, overwrite the initial model call name with the new one
|
||||||
|
self.model_call_details["model"] = model
|
||||||
|
|
||||||
# User Logging -> if you pass in a custom logging function
|
# User Logging -> if you pass in a custom logging function
|
||||||
|
print_verbose(f"model call details: {self.model_call_details}")
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"Logging Details: logger_fn - {self.logger_fn} | callable(logger_fn) - {callable(self.logger_fn)}"
|
f"Logging Details: logger_fn - {self.logger_fn} | callable(logger_fn) - {callable(self.logger_fn)}"
|
||||||
)
|
)
|
||||||
|
@ -187,8 +191,8 @@ class Logging:
|
||||||
try:
|
try:
|
||||||
if callback == "supabase":
|
if callback == "supabase":
|
||||||
print_verbose("reaches supabase for logging!")
|
print_verbose("reaches supabase for logging!")
|
||||||
model = self.model
|
model = self.model_call_details["model"]
|
||||||
messages = self.messages
|
messages = self.model_call_details["input"]
|
||||||
print(f"supabaseClient: {supabaseClient}")
|
print(f"supabaseClient: {supabaseClient}")
|
||||||
supabaseClient.input_log_event(
|
supabaseClient.input_log_event(
|
||||||
model=model,
|
model=model,
|
||||||
|
@ -201,8 +205,8 @@ class Logging:
|
||||||
|
|
||||||
elif callback == "lite_debugger":
|
elif callback == "lite_debugger":
|
||||||
print_verbose("reaches litedebugger for logging!")
|
print_verbose("reaches litedebugger for logging!")
|
||||||
model = self.model
|
model = self.model_call_details["model"]
|
||||||
messages = self.messages
|
messages = self.model_call_details["input"]
|
||||||
print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
|
print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
|
||||||
liteDebuggerClient.input_log_event(
|
liteDebuggerClient.input_log_event(
|
||||||
model=model,
|
model=model,
|
||||||
|
@ -1119,6 +1123,7 @@ def get_all_keys(llm_provider=None):
|
||||||
try:
|
try:
|
||||||
global last_fetched_at
|
global last_fetched_at
|
||||||
# if user is using hosted product -> instantiate their env with their hosted api keys - refresh every 5 minutes
|
# if user is using hosted product -> instantiate their env with their hosted api keys - refresh every 5 minutes
|
||||||
|
print_verbose(f"Reaches get all keys, llm_provider: {llm_provider}")
|
||||||
user_email = os.getenv("LITELLM_EMAIL") or litellm.email or litellm.token or os.getenv("LITELLM_TOKEN")
|
user_email = os.getenv("LITELLM_EMAIL") or litellm.email or litellm.token or os.getenv("LITELLM_TOKEN")
|
||||||
if user_email:
|
if user_email:
|
||||||
time_delta = 0
|
time_delta = 0
|
||||||
|
@ -1135,9 +1140,11 @@ def get_all_keys(llm_provider=None):
|
||||||
# update model list
|
# update model list
|
||||||
for key, value in data["model_keys"].items(): # follows the LITELLM API KEY format - <UPPERCASE_PROVIDER_NAME>_API_KEY - e.g. HUGGINGFACE_API_KEY
|
for key, value in data["model_keys"].items(): # follows the LITELLM API KEY format - <UPPERCASE_PROVIDER_NAME>_API_KEY - e.g. HUGGINGFACE_API_KEY
|
||||||
os.environ[key] = value
|
os.environ[key] = value
|
||||||
|
# set model alias map
|
||||||
|
for model_alias, value in data["model_alias_map"].items():
|
||||||
|
litellm.model_alias_map[model_alias] = value
|
||||||
return "it worked!"
|
return "it worked!"
|
||||||
return None
|
return None
|
||||||
# return None by default
|
|
||||||
return None
|
return None
|
||||||
except:
|
except:
|
||||||
print_verbose(f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}")
|
print_verbose(f"[Non-Blocking Error] get_all_keys error - {traceback.format_exc()}")
|
||||||
|
@ -1149,12 +1156,6 @@ def get_model_list():
|
||||||
# if user is using hosted product -> get their updated model list
|
# if user is using hosted product -> get their updated model list
|
||||||
user_email = os.getenv("LITELLM_EMAIL") or litellm.email or litellm.token or os.getenv("LITELLM_TOKEN")
|
user_email = os.getenv("LITELLM_EMAIL") or litellm.email or litellm.token or os.getenv("LITELLM_TOKEN")
|
||||||
if user_email:
|
if user_email:
|
||||||
# Commented out the section checking time delta
|
|
||||||
# time_delta = 0
|
|
||||||
# if last_fetched_at != None:
|
|
||||||
# current_time = time.time()
|
|
||||||
# time_delta = current_time - last_fetched_at
|
|
||||||
# if time_delta > 300 or last_fetched_at == None:
|
|
||||||
# make the api call
|
# make the api call
|
||||||
last_fetched_at = time.time()
|
last_fetched_at = time.time()
|
||||||
print(f"last_fetched_at: {last_fetched_at}")
|
print(f"last_fetched_at: {last_fetched_at}")
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "litellm"
|
name = "litellm"
|
||||||
version = "0.1.475"
|
version = "0.1.476"
|
||||||
description = "Library to easily interface with LLM API providers"
|
description = "Library to easily interface with LLM API providers"
|
||||||
authors = ["BerriAI"]
|
authors = ["BerriAI"]
|
||||||
license = "MIT License"
|
license = "MIT License"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue