diff --git a/litellm/proxy/auth/model_checks.py b/litellm/proxy/auth/model_checks.py
index 3c874ff0e9..ccbf907864 100644
--- a/litellm/proxy/auth/model_checks.py
+++ b/litellm/proxy/auth/model_checks.py
@@ -60,17 +60,20 @@ def get_complete_model_list(
     - If team list is empty -> defer to proxy model list
     """

-    if len(key_models) > 0:
-        return key_models
+    unique_models = set()

-    if len(team_models) > 0:
-        return team_models
+    if key_models:
+        unique_models.update(key_models)
+    elif team_models:
+        unique_models.update(team_models)
+    else:
+        unique_models.update(proxy_model_list)

-    returned_models = proxy_model_list
-    if user_model is not None:  # set via `litellm --model ollama/llam3`
-        returned_models.append(user_model)
+        if user_model:
+            unique_models.add(user_model)

-    if infer_model_from_keys is not None and infer_model_from_keys == True:
-        valid_models = get_valid_models()
-        returned_models.extend(valid_models)
-    return returned_models
+        if infer_model_from_keys:
+            valid_models = get_valid_models()
+            unique_models.update(valid_models)
+
+    return list(unique_models)
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index e225169f97..6f1a3e557b 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -2,7 +2,7 @@ import sys, os, platform, time, copy, re, asyncio, inspect
 import threading, ast
 import shutil, random, traceback, requests
 from datetime import datetime, timedelta, timezone
-from typing import Optional, List, Callable, get_args
+from typing import Optional, List, Callable, get_args, Set
 import secrets, subprocess
 import hashlib, uuid
 import warnings
@@ -106,7 +106,7 @@ import pydantic
 from litellm.proxy._types import *
 from litellm.caching import DualCache, RedisCache
 from litellm.proxy.health_check import perform_health_check
-from litellm.router import LiteLLM_Params, Deployment, updateDeployment
+from litellm.router import LiteLLM_Params, Deployment, updateDeployment, ModelGroupInfo
 from litellm.router import ModelInfo as RouterModelInfo
 from litellm._logging import verbose_router_logger, verbose_proxy_logger
 from litellm.proxy.auth.handle_jwt import JWTHandler
@@ -9730,6 +9730,58 @@ async def model_info_v1(
     return {"data": all_models}


+@router.get(
+    "/model_group/info",
+    description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)",
+    tags=["model management"],
+    dependencies=[Depends(user_api_key_auth)],
+)
+async def model_group_info(
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    Returns model info at the model group level.
+    """
+    global llm_model_list, general_settings, user_config_file_path, proxy_config, llm_router
+
+    if llm_model_list is None:
+        raise HTTPException(
+            status_code=500, detail={"error": "LLM Model List not loaded in"}
+        )
+    if llm_router is None:
+        raise HTTPException(
+            status_code=500, detail={"error": "LLM Router is not loaded in"}
+        )
+    all_models: List[dict] = []
+    ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
+    if llm_model_list is None:
+        proxy_model_list = []
+    else:
+        proxy_model_list = [m["model_name"] for m in llm_model_list]
+    key_models = get_key_models(
+        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
+    )
+    team_models = get_team_models(
+        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
+    )
+    all_models_str = get_complete_model_list(
+        key_models=key_models,
+        team_models=team_models,
+        proxy_model_list=proxy_model_list,
+        user_model=user_model,
+        infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
+    )
+
+    model_groups: List[ModelGroupInfo] = []
+    for model in all_models_str:
+
+        _model_group_info = llm_router.get_model_group_info(model_group=model)
+        if _model_group_info is not None:
+            model_groups.append(_model_group_info)
+
+    return {"data": model_groups}
+
+
 #### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964
 @router.post(
     "/model/delete",
diff --git a/litellm/router.py b/litellm/router.py
index 3c486a7478..3243e09fa4 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -48,6 +48,7 @@ from litellm.types.router import (
     RetryPolicy,
     AlertingConfig,
     DeploymentTypedDict,
+    ModelGroupInfo,
 )
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.llms.azure import get_azure_ad_token_from_oidc
@@ -3045,6 +3046,100 @@ class Router:
                 return model
         return None

+    def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:
+        """
+        For a given model group name, return the combined model info
+
+        Returns:
+        - ModelGroupInfo if able to construct a model group
+        - None if error constructing model group info
+        """
+
+        model_group_info: Optional[ModelGroupInfo] = None
+
+        for model in self.model_list:
+            if "model_name" in model and model["model_name"] == model_group:
+                # model in model group found #
+                litellm_params = LiteLLM_Params(**model["litellm_params"])
+                # get model info
+                try:
+                    model_info = litellm.get_model_info(model=litellm_params.model)
+                except Exception as e:
+                    continue
+                # get llm provider
+                try:
+                    model, llm_provider, _, _ = litellm.get_llm_provider(
+                        model=litellm_params.model,
+                        custom_llm_provider=litellm_params.custom_llm_provider,
+                    )
+                except Exception as e:
+                    continue
+
+                if model_group_info is None:
+                    model_group_info = ModelGroupInfo(
+                        model_group=model_group, providers=[llm_provider], **model_info  # type: ignore
+                    )
+                else:
+                    # if max_input_tokens > curr
+                    # if max_output_tokens > curr
+                    # if input_cost_per_token > curr
+                    # if output_cost_per_token > curr
+                    # supports_parallel_function_calling == True
+                    # supports_vision == True
+                    # supports_function_calling == True
+                    if llm_provider not in model_group_info.providers:
+                        model_group_info.providers.append(llm_provider)
+                    if model_info.get("max_input_tokens", None) is not None and (
+                        model_group_info.max_input_tokens is None
+                        or model_info["max_input_tokens"]
+                        > model_group_info.max_input_tokens
+                    ):
+                        model_group_info.max_input_tokens = model_info[
+                            "max_input_tokens"
+                        ]
+                    if model_info.get("max_output_tokens", None) is not None and (
+                        model_group_info.max_output_tokens is None
+                        or model_info["max_output_tokens"]
+                        > model_group_info.max_output_tokens
+                    ):
+                        model_group_info.max_output_tokens = model_info[
+                            "max_output_tokens"
+                        ]
+                    if model_info.get("input_cost_per_token", None) is not None and (
+                        model_group_info.input_cost_per_token is None
+                        or model_info["input_cost_per_token"]
+                        > model_group_info.input_cost_per_token
+                    ):
+                        model_group_info.input_cost_per_token = model_info[
+                            "input_cost_per_token"
+                        ]
+                    if model_info.get("output_cost_per_token", None) is not None and (
+                        model_group_info.output_cost_per_token is None
+                        or model_info["output_cost_per_token"]
+                        > model_group_info.output_cost_per_token
+                    ):
+                        model_group_info.output_cost_per_token = model_info[
+                            "output_cost_per_token"
+                        ]
+                    if (
+                        model_info.get("supports_parallel_function_calling", None)
+                        is not None
+                        and model_info["supports_parallel_function_calling"] == True  # type: ignore
+                    ):
+                        model_group_info.supports_parallel_function_calling = True
+                    if (
+                        model_info.get("supports_vision", None) is not None
+                        and model_info["supports_vision"] == True  # type: ignore
+                    ):
+                        model_group_info.supports_vision = True
+                    if (
+                        model_info.get("supports_function_calling", None) is not None
+                        and model_info["supports_function_calling"] == True  # type: ignore
+                    ):
+                        model_group_info.supports_function_calling = True
+
+        return model_group_info
+
     def get_model_ids(self) -> List[str]:
         """
         Returns list of model id's.
diff --git a/litellm/types/router.py b/litellm/types/router.py
index a61e551a70..5e6f2c1483 100644
--- a/litellm/types/router.py
+++ b/litellm/types/router.py
@@ -411,3 +411,18 @@ class AlertingConfig(BaseModel):

     webhook_url: str
     alerting_threshold: Optional[float] = 300
+
+
+class ModelGroupInfo(BaseModel):
+    model_group: str
+    providers: List[str]
+    max_input_tokens: Optional[float] = None
+    max_output_tokens: Optional[float] = None
+    input_cost_per_token: Optional[float] = None
+    output_cost_per_token: Optional[float] = None
+    mode: Literal[
+        "chat", "embedding", "completion", "image_generation", "audio_transcription"
+    ]
+    supports_parallel_function_calling: bool = Field(default=False)
+    supports_vision: bool = Field(default=False)
+    supports_function_calling: bool = Field(default=False)
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index 21823cc1fa..5c730cca8c 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -12,3 +12,13 @@ class ProviderField(TypedDict):
     field_type: Literal["string"]
     field_description: str
     field_value: str
+
+
+class ModelInfo(TypedDict):
+    max_tokens: int
+    max_input_tokens: int
+    max_output_tokens: int
+    input_cost_per_token: float
+    output_cost_per_token: float
+    litellm_provider: str
+    mode: str
diff --git a/litellm/utils.py b/litellm/utils.py
index 5da01c764f..b777819e52 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -34,7 +34,7 @@ from dataclasses import (
 import litellm._service_logger  # for storing API inputs, outputs, and metadata
 from litellm.llms.custom_httpx.http_handler import HTTPHandler
 from litellm.caching import DualCache
-from litellm.types.utils import CostPerToken, ProviderField
+from litellm.types.utils import CostPerToken, ProviderField, ModelInfo

 oidc_cache = DualCache()

@@ -7092,7 +7092,7 @@ def get_max_tokens(model: str):
         )


-def get_model_info(model: str):
+def get_model_info(model: str) -> ModelInfo:
     """
     Get a dict for the maximum tokens (context window), input_cost_per_token,
     output_cost_per_token for a given model.
@@ -7154,7 +7154,7 @@ def get_model_info(model: str):
         if custom_llm_provider == "huggingface":
             max_tokens = _get_max_position_embeddings(model_name=model)
             return {
-                "max_tokens": max_tokens,
+                "max_tokens": max_tokens,  # type: ignore
                 "input_cost_per_token": 0,
                 "output_cost_per_token": 0,
                 "litellm_provider": "huggingface",
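
Usage sketch (reviewer note, not part of the patch): a minimal example of how the new GET /model_group/info route could be exercised against a proxy running this branch. The base URL and virtual key below are placeholder assumptions, not values taken from this diff; the response shape follows the {"data": [ModelGroupInfo, ...]} structure returned by model_group_info above.

import requests

PROXY_BASE_URL = "http://0.0.0.0:4000"  # assumption: locally running litellm proxy
VIRTUAL_KEY = "sk-1234"  # assumption: placeholder key accepted by user_api_key_auth

# GET /model_group/info returns {"data": [<ModelGroupInfo>, ...]}
resp = requests.get(
    f"{PROXY_BASE_URL}/model_group/info",
    headers={"Authorization": f"Bearer {VIRTUAL_KEY}"},
)
resp.raise_for_status()

for group in resp.json()["data"]:
    # fields mirror ModelGroupInfo in litellm/types/router.py
    print(
        group["model_group"],
        group["providers"],
        group.get("max_input_tokens"),
        group.get("supports_function_calling"),
    )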