mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
feat(proxy_server.py): expose new /model_group/info
endpoint
returns model-group level info on supported params, max tokens, pricing, etc.
This commit is contained in:
parent
bec13d465a
commit
22b6b99b34
6 changed files with 191 additions and 16 deletions
|
@ -60,17 +60,20 @@ def get_complete_model_list(
|
|||
- If team list is empty -> defer to proxy model list
|
||||
"""
|
||||
|
||||
if len(key_models) > 0:
|
||||
return key_models
|
||||
unique_models = set()
|
||||
|
||||
if len(team_models) > 0:
|
||||
return team_models
|
||||
if key_models:
|
||||
unique_models.update(key_models)
|
||||
elif team_models:
|
||||
unique_models.update(team_models)
|
||||
else:
|
||||
unique_models.update(proxy_model_list)
|
||||
|
||||
returned_models = proxy_model_list
|
||||
if user_model is not None: # set via `litellm --model ollama/llam3`
|
||||
returned_models.append(user_model)
|
||||
if user_model:
|
||||
unique_models.add(user_model)
|
||||
|
||||
if infer_model_from_keys is not None and infer_model_from_keys == True:
|
||||
valid_models = get_valid_models()
|
||||
returned_models.extend(valid_models)
|
||||
return returned_models
|
||||
if infer_model_from_keys:
|
||||
valid_models = get_valid_models()
|
||||
unique_models.update(valid_models)
|
||||
|
||||
return list(unique_models)
|
||||
|
|
|
@ -2,7 +2,7 @@ import sys, os, platform, time, copy, re, asyncio, inspect
|
|||
import threading, ast
|
||||
import shutil, random, traceback, requests
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, List, Callable, get_args
|
||||
from typing import Optional, List, Callable, get_args, Set
|
||||
import secrets, subprocess
|
||||
import hashlib, uuid
|
||||
import warnings
|
||||
|
@ -106,7 +106,7 @@ import pydantic
|
|||
from litellm.proxy._types import *
|
||||
from litellm.caching import DualCache, RedisCache
|
||||
from litellm.proxy.health_check import perform_health_check
|
||||
from litellm.router import LiteLLM_Params, Deployment, updateDeployment
|
||||
from litellm.router import LiteLLM_Params, Deployment, updateDeployment, ModelGroupInfo
|
||||
from litellm.router import ModelInfo as RouterModelInfo
|
||||
from litellm._logging import verbose_router_logger, verbose_proxy_logger
|
||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||
|
@ -9730,6 +9730,58 @@ async def model_info_v1(
|
|||
return {"data": all_models}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/model_group/info",
|
||||
description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)",
|
||||
tags=["model management"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def model_group_info(
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Returns model info at the model group level.
|
||||
"""
|
||||
global llm_model_list, general_settings, user_config_file_path, proxy_config, llm_router
|
||||
|
||||
if llm_model_list is None:
|
||||
raise HTTPException(
|
||||
status_code=500, detail={"error": "LLM Model List not loaded in"}
|
||||
)
|
||||
if llm_router is None:
|
||||
raise HTTPException(
|
||||
status_code=500, detail={"error": "LLM Router is not loaded in"}
|
||||
)
|
||||
all_models: List[dict] = []
|
||||
## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
|
||||
if llm_model_list is None:
|
||||
proxy_model_list = []
|
||||
else:
|
||||
proxy_model_list = [m["model_name"] for m in llm_model_list]
|
||||
key_models = get_key_models(
|
||||
user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
|
||||
)
|
||||
team_models = get_team_models(
|
||||
user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
|
||||
)
|
||||
all_models_str = get_complete_model_list(
|
||||
key_models=key_models,
|
||||
team_models=team_models,
|
||||
proxy_model_list=proxy_model_list,
|
||||
user_model=user_model,
|
||||
infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
|
||||
)
|
||||
|
||||
model_groups: List[ModelGroupInfo] = []
|
||||
for model in all_models_str:
|
||||
|
||||
_model_group_info = llm_router.get_model_group_info(model_group=model)
|
||||
if _model_group_info is not None:
|
||||
model_groups.append(_model_group_info)
|
||||
|
||||
return {"data": model_groups}
|
||||
|
||||
|
||||
#### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964
|
||||
@router.post(
|
||||
"/model/delete",
|
||||
|
|
|
@ -48,6 +48,7 @@ from litellm.types.router import (
|
|||
RetryPolicy,
|
||||
AlertingConfig,
|
||||
DeploymentTypedDict,
|
||||
ModelGroupInfo,
|
||||
)
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.llms.azure import get_azure_ad_token_from_oidc
|
||||
|
@ -3045,6 +3046,100 @@ class Router:
|
|||
return model
|
||||
return None
|
||||
|
||||
def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:
|
||||
"""
|
||||
For a given model group name, return the combined model info
|
||||
|
||||
Returns:
|
||||
- ModelGroupInfo if able to construct a model group
|
||||
- None if error constructing model group info
|
||||
"""
|
||||
|
||||
model_group_info: Optional[ModelGroupInfo] = None
|
||||
|
||||
for model in self.model_list:
|
||||
if "model_name" in model and model["model_name"] == model_group:
|
||||
# model in model group found #
|
||||
litellm_params = LiteLLM_Params(**model["litellm_params"])
|
||||
# get model info
|
||||
try:
|
||||
model_info = litellm.get_model_info(model=litellm_params.model)
|
||||
except Exception as e:
|
||||
continue
|
||||
# get llm provider
|
||||
try:
|
||||
model, llm_provider, _, _ = litellm.get_llm_provider(
|
||||
model=litellm_params.model,
|
||||
custom_llm_provider=litellm_params.custom_llm_provider,
|
||||
)
|
||||
except Exception as e:
|
||||
continue
|
||||
|
||||
if model_group_info is None:
|
||||
model_group_info = ModelGroupInfo(
|
||||
model_group=model_group, providers=[llm_provider], **model_info # type: ignore
|
||||
)
|
||||
else:
|
||||
# if max_input_tokens > curr
|
||||
# if max_output_tokens > curr
|
||||
# if input_cost_per_token > curr
|
||||
# if output_cost_per_token > curr
|
||||
# supports_parallel_function_calling == True
|
||||
# supports_vision == True
|
||||
# supports_function_calling == True
|
||||
if llm_provider not in model_group_info.providers:
|
||||
model_group_info.providers.append(llm_provider)
|
||||
if model_info.get("max_input_tokens", None) is not None and (
|
||||
model_group_info.max_input_tokens is None
|
||||
or model_info["max_input_tokens"]
|
||||
> model_group_info.max_input_tokens
|
||||
):
|
||||
model_group_info.max_input_tokens = model_info[
|
||||
"max_input_tokens"
|
||||
]
|
||||
if model_info.get("max_output_tokens", None) is not None and (
|
||||
model_group_info.max_output_tokens is None
|
||||
or model_info["max_output_tokens"]
|
||||
> model_group_info.max_output_tokens
|
||||
):
|
||||
model_group_info.max_output_tokens = model_info[
|
||||
"max_output_tokens"
|
||||
]
|
||||
if model_info.get("input_cost_per_token", None) is not None and (
|
||||
model_group_info.input_cost_per_token is None
|
||||
or model_info["input_cost_per_token"]
|
||||
> model_group_info.input_cost_per_token
|
||||
):
|
||||
model_group_info.input_cost_per_token = model_info[
|
||||
"input_cost_per_token"
|
||||
]
|
||||
if model_info.get("output_cost_per_token", None) is not None and (
|
||||
model_group_info.output_cost_per_token is None
|
||||
or model_info["output_cost_per_token"]
|
||||
> model_group_info.output_cost_per_token
|
||||
):
|
||||
model_group_info.output_cost_per_token = model_info[
|
||||
"output_cost_per_token"
|
||||
]
|
||||
if (
|
||||
model_info.get("supports_parallel_function_calling", None)
|
||||
is not None
|
||||
and model_info["supports_parallel_function_calling"] == True # type: ignore
|
||||
):
|
||||
model_group_info.supports_parallel_function_calling = True
|
||||
if (
|
||||
model_info.get("supports_vision", None) is not None
|
||||
and model_info["supports_vision"] == True # type: ignore
|
||||
):
|
||||
model_group_info.supports_vision = True
|
||||
if (
|
||||
model_info.get("supports_function_calling", None) is not None
|
||||
and model_info["supports_function_calling"] == True # type: ignore
|
||||
):
|
||||
model_group_info.supports_function_calling = True
|
||||
|
||||
return model_group_info
|
||||
|
||||
def get_model_ids(self) -> List[str]:
|
||||
"""
|
||||
Returns list of model id's.
|
||||
|
|
|
@ -411,3 +411,18 @@ class AlertingConfig(BaseModel):
|
|||
|
||||
webhook_url: str
|
||||
alerting_threshold: Optional[float] = 300
|
||||
|
||||
|
||||
class ModelGroupInfo(BaseModel):
|
||||
model_group: str
|
||||
providers: List[str]
|
||||
max_input_tokens: Optional[float] = None
|
||||
max_output_tokens: Optional[float] = None
|
||||
input_cost_per_token: Optional[float] = None
|
||||
output_cost_per_token: Optional[float] = None
|
||||
mode: Literal[
|
||||
"chat", "embedding", "completion", "image_generation", "audio_transcription"
|
||||
]
|
||||
supports_parallel_function_calling: bool = Field(default=False)
|
||||
supports_vision: bool = Field(default=False)
|
||||
supports_function_calling: bool = Field(default=False)
|
||||
|
|
|
@ -12,3 +12,13 @@ class ProviderField(TypedDict):
|
|||
field_type: Literal["string"]
|
||||
field_description: str
|
||||
field_value: str
|
||||
|
||||
|
||||
class ModelInfo(TypedDict):
|
||||
max_tokens: int
|
||||
max_input_tokens: int
|
||||
max_output_tokens: int
|
||||
input_cost_per_token: float
|
||||
output_cost_per_token: float
|
||||
litellm_provider: str
|
||||
mode: str
|
||||
|
|
|
@ -34,7 +34,7 @@ from dataclasses import (
|
|||
import litellm._service_logger # for storing API inputs, outputs, and metadata
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
from litellm.caching import DualCache
|
||||
from litellm.types.utils import CostPerToken, ProviderField
|
||||
from litellm.types.utils import CostPerToken, ProviderField, ModelInfo
|
||||
|
||||
oidc_cache = DualCache()
|
||||
|
||||
|
@ -7092,7 +7092,7 @@ def get_max_tokens(model: str):
|
|||
)
|
||||
|
||||
|
||||
def get_model_info(model: str):
|
||||
def get_model_info(model: str) -> ModelInfo:
|
||||
"""
|
||||
Get a dict for the maximum tokens (context window),
|
||||
input_cost_per_token, output_cost_per_token for a given model.
|
||||
|
@ -7154,7 +7154,7 @@ def get_model_info(model: str):
|
|||
if custom_llm_provider == "huggingface":
|
||||
max_tokens = _get_max_position_embeddings(model_name=model)
|
||||
return {
|
||||
"max_tokens": max_tokens,
|
||||
"max_tokens": max_tokens, # type: ignore
|
||||
"input_cost_per_token": 0,
|
||||
"output_cost_per_token": 0,
|
||||
"litellm_provider": "huggingface",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue