helper function to check if user is allowed to call model

Krrish Dholakia 2023-08-21 12:36:58 -07:00
parent 6f82392983
commit 3375caf307
2 changed files with 15 additions and 5 deletions

litellm/utils.py

@@ -243,9 +243,6 @@ class Logging:
                 f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}"
             )
             pass
-    # Add more methods as needed
 def exception_logging(
     additional_args={},
@@ -1026,6 +1023,17 @@ def prompt_token_calculator(model, messages):
     return num_tokens
 
+def valid_model(model):
+    try:
+        # for a given model name, check if the user has the right permissions to access the model
+        if model in litellm.open_ai_chat_completion_models or model in litellm.open_ai_text_completion_models:
+            openai.Model.retrieve(model)
+        else:
+            messages = [{"role": "user", "content": "Hello World"}]
+            litellm.completion(model=model, messages=messages)
+    except:
+        raise InvalidRequestError(message="", model=model, llm_provider="")
+
 # integration helper function
 def modify_integration(integration_name, integration_params):
     global supabaseClient
@@ -1034,6 +1042,7 @@ def modify_integration(integration_name, integration_params):
         Supabase.supabase_table_name = integration_params["table_name"]
 
+####### EXCEPTION MAPPING ################
 def exception_type(model, original_exception, custom_llm_provider):
     global user_logger_fn, liteDebuggerClient
     exception_mapping_worked = False
@@ -1175,6 +1184,7 @@ def exception_type(model, original_exception, custom_llm_provider):
         raise original_exception
 
+####### CRASH REPORTING ################
 def safe_crash_reporting(model=None, exception=None, custom_llm_provider=None):
     data = {
         "model": model,
@@ -1373,7 +1383,7 @@ async def stream_to_string(generator):
     return response
 
-########## Together AI streaming #############################
+########## Together AI streaming ############################# [TODO] move together ai to it's own llm class
 async def together_ai_completion_streaming(json_data, headers):
     session = aiohttp.ClientSession()
     url = "https://api.together.xyz/inference"

pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.442"
+version = "0.1.443"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"