Merge pull request #1980 from BerriAI/litellm_fix_model_access_groups_bug

[FIX] Model Access Groups
Ishaan Jaff 2024-02-14 20:51:53 -08:00 committed by GitHub
commit 78d0da6aa5
2 changed files with 29 additions and 33 deletions


@@ -9,14 +9,19 @@ model_list:
       mode: chat
       max_tokens: 4096
       base_model: azure/gpt-4-1106-preview
+      access_groups: ["public"]
   - model_name: openai-gpt-3.5
     litellm_params:
       model: gpt-3.5-turbo
       api_key: os.environ/OPENAI_API_KEY
+    model_info:
+      access_groups: ["public"]
   - model_name: anthropic-claude-v2.1
     litellm_params:
       model: bedrock/anthropic.claude-v2:1
       timeout: 300 # sets a 5 minute timeout
+    model_info:
+      access_groups: ["private"]
   - model_name: anthropic-claude-v2
     litellm_params:
       model: bedrock/anthropic.claude-v2
@@ -39,16 +44,6 @@ model_list:
 litellm_settings:
   fallbacks: [{"openai-gpt-3.5": ["azure-gpt-3.5"]}]
   success_callback: ['langfuse']
-  max_budget: 50 # global budget for proxy
-  max_user_budget: 0.0001
-  budget_duration: 30d # global budget duration, will reset after 30d
-  default_key_generate_params:
-    max_budget: 1.5000
-    models: ["azure-gpt-3.5"]
-    duration: None
-  upperbound_key_generate_params:
-    max_budget: 100
-    duration: "30d"
 
 # setting callback class
 # callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
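
Note on usage: the access_groups values added above are meant to be referenced from virtual keys in place of concrete model names. A minimal sketch against the proxy's /key/generate endpoint, assuming a locally running proxy; the URL and master key below are placeholder values for illustration:

import requests

# Placeholder proxy URL and master key -- substitute your own deployment values.
PROXY_URL = "http://0.0.0.0:4000"
MASTER_KEY = "sk-1234"

resp = requests.post(
    f"{PROXY_URL}/key/generate",
    headers={"Authorization": f"Bearer {MASTER_KEY}"},
    # "public" is an access group from the config above, not a model name;
    # the generated key should resolve to openai-gpt-3.5 (and any other
    # model tagged "public") but not to the "private" bedrock model.
    json={"models": ["public"]},
)
print(resp.json())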


@@ -403,34 +403,43 @@ async def user_api_key_auth(
             verbose_proxy_logger.debug(
                 f"LLM Model List pre access group check: {llm_model_list}"
             )
-            access_groups = []
+            from collections import defaultdict
+
+            access_groups = defaultdict(list)
             if llm_model_list is not None:
                 for m in llm_model_list:
                     for group in m.get("model_info", {}).get("access_groups", []):
-                        access_groups.append((m["model_name"], group))
+                        model_name = m["model_name"]
+                        access_groups[group].append(model_name)
 
-            allowed_models = valid_token.models
-            access_group_idx = set()
+            models_in_current_access_groups = []
             if (
                 len(access_groups) > 0
             ):  # check if token contains any model access groups
-                for idx, m in enumerate(valid_token.models):
-                    for model_name, group in access_groups:
-                        if m == group:
-                            access_group_idx.add(idx)
-                            allowed_models.append(model_name)
+                for idx, m in enumerate(
+                    valid_token.models
+                ):  # loop token models, if any of them are an access group add the access group
+                    if m in access_groups:
+                        # if it is an access group we need to remove it from valid_token.models
+                        models_in_group = access_groups[m]
+                        models_in_current_access_groups.extend(models_in_group)
 
+            # Filter out models that are access_groups
+            filtered_models = [
+                m for m in valid_token.models if m not in access_groups
+            ]
+            filtered_models += models_in_current_access_groups
             verbose_proxy_logger.debug(
-                f"model: {model}; allowed_models: {allowed_models}"
+                f"model: {model}; allowed_models: {filtered_models}"
             )
-            if model is not None and model not in allowed_models:
+            if model is not None and model not in filtered_models:
                 raise ValueError(
                     f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}"
                 )
-            for val in access_group_idx:
-                allowed_models.pop(val)
-            valid_token.models = allowed_models
+            valid_token.models = filtered_models
             verbose_proxy_logger.debug(
-                f"filtered allowed_models: {allowed_models}; valid_token.models: {valid_token.models}"
+                f"filtered allowed_models: {filtered_models}; valid_token.models: {valid_token.models}"
             )
         # Check 2. If user_id for this token is in budget
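
The rewritten check builds a group -> member-models mapping up front, so membership tests become direct dict lookups instead of the old tuple scan that appended to allowed_models mid-iteration and then popped entries by index (the bug this PR fixes). A standalone sketch of the same resolution logic with made-up data, runnable outside the proxy:

from collections import defaultdict

# Made-up model list and token models, mirroring the config in this PR.
llm_model_list = [
    {"model_name": "openai-gpt-3.5", "model_info": {"access_groups": ["public"]}},
    {"model_name": "anthropic-claude-v2.1", "model_info": {"access_groups": ["private"]}},
]
token_models = ["public", "azure-gpt-3.5"]  # a key can mix group names and plain models

# Build group -> member models once.
access_groups = defaultdict(list)
for m in llm_model_list:
    for group in m.get("model_info", {}).get("access_groups", []):
        access_groups[group].append(m["model_name"])

# Expand any group names on the token into their member models.
models_in_current_access_groups = []
for m in token_models:
    if m in access_groups:
        models_in_current_access_groups.extend(access_groups[m])

# Keep plain model names, drop the group names, append the expansion.
filtered_models = [m for m in token_models if m not in access_groups]
filtered_models += models_in_current_access_groups
print(filtered_models)  # ['azure-gpt-3.5', 'openai-gpt-3.5']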
@@ -2087,14 +2096,6 @@ def model_list(
     if user_model is not None:
         all_models += [user_model]
     verbose_proxy_logger.debug(f"all_models: {all_models}")
-    ### CHECK OLLAMA MODELS ###
-    try:
-        response = requests.get("http://0.0.0.0:11434/api/tags")
-        models = response.json()["models"]
-        ollama_models = ["ollama/" + m["name"].replace(":latest", "") for m in models]
-        all_models.extend(ollama_models)
-    except Exception as e:
-        pass
     return dict(
         data=[
             {