feat(proxy_server.py): expose new /model_group/info endpoint

returns model-group level info on supported params, max tokens, pricing, etc.
This commit is contained in:
Krrish Dholakia 2024-05-26 14:07:35 -07:00
parent 90edb1d46e
commit 8e9a3fef81
6 changed files with 191 additions and 16 deletions

View file

@ -60,17 +60,20 @@ def get_complete_model_list(
- If team list is empty -> defer to proxy model list
"""
if len(key_models) > 0:
return key_models
unique_models = set()
if len(team_models) > 0:
return team_models
if key_models:
unique_models.update(key_models)
elif team_models:
unique_models.update(team_models)
else:
unique_models.update(proxy_model_list)
returned_models = proxy_model_list
if user_model is not None: # set via `litellm --model ollama/llam3`
returned_models.append(user_model)
if user_model:
unique_models.add(user_model)
if infer_model_from_keys is not None and infer_model_from_keys == True:
if infer_model_from_keys:
valid_models = get_valid_models()
returned_models.extend(valid_models)
return returned_models
unique_models.update(valid_models)
return list(unique_models)

View file

@ -2,7 +2,7 @@ import sys, os, platform, time, copy, re, asyncio, inspect
import threading, ast
import shutil, random, traceback, requests
from datetime import datetime, timedelta, timezone
from typing import Optional, List, Callable, get_args
from typing import Optional, List, Callable, get_args, Set
import secrets, subprocess
import hashlib, uuid
import warnings
@ -106,7 +106,7 @@ import pydantic
from litellm.proxy._types import *
from litellm.caching import DualCache, RedisCache
from litellm.proxy.health_check import perform_health_check
from litellm.router import LiteLLM_Params, Deployment, updateDeployment
from litellm.router import LiteLLM_Params, Deployment, updateDeployment, ModelGroupInfo
from litellm.router import ModelInfo as RouterModelInfo
from litellm._logging import verbose_router_logger, verbose_proxy_logger
from litellm.proxy.auth.handle_jwt import JWTHandler
@ -9730,6 +9730,58 @@ async def model_info_v1(
return {"data": all_models}
@router.get(
"/model_group/info",
description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)",
tags=["model management"],
dependencies=[Depends(user_api_key_auth)],
)
async def model_group_info(
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Returns model info at the model group level.
"""
global llm_model_list, general_settings, user_config_file_path, proxy_config, llm_router
if llm_model_list is None:
raise HTTPException(
status_code=500, detail={"error": "LLM Model List not loaded in"}
)
if llm_router is None:
raise HTTPException(
status_code=500, detail={"error": "LLM Router is not loaded in"}
)
all_models: List[dict] = []
## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
if llm_model_list is None:
proxy_model_list = []
else:
proxy_model_list = [m["model_name"] for m in llm_model_list]
key_models = get_key_models(
user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
)
team_models = get_team_models(
user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
)
all_models_str = get_complete_model_list(
key_models=key_models,
team_models=team_models,
proxy_model_list=proxy_model_list,
user_model=user_model,
infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
)
model_groups: List[ModelGroupInfo] = []
for model in all_models_str:
_model_group_info = llm_router.get_model_group_info(model_group=model)
if _model_group_info is not None:
model_groups.append(_model_group_info)
return {"data": model_groups}
#### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964
@router.post(
"/model/delete",

View file

@ -48,6 +48,7 @@ from litellm.types.router import (
RetryPolicy,
AlertingConfig,
DeploymentTypedDict,
ModelGroupInfo,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc
@ -3045,6 +3046,100 @@ class Router:
return model
return None
def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:
"""
For a given model group name, return the combined model info
Returns:
- ModelGroupInfo if able to construct a model group
- None if error constructing model group info
"""
model_group_info: Optional[ModelGroupInfo] = None
for model in self.model_list:
if "model_name" in model and model["model_name"] == model_group:
# model in model group found #
litellm_params = LiteLLM_Params(**model["litellm_params"])
# get model info
try:
model_info = litellm.get_model_info(model=litellm_params.model)
except Exception as e:
continue
# get llm provider
try:
model, llm_provider, _, _ = litellm.get_llm_provider(
model=litellm_params.model,
custom_llm_provider=litellm_params.custom_llm_provider,
)
except Exception as e:
continue
if model_group_info is None:
model_group_info = ModelGroupInfo(
model_group=model_group, providers=[llm_provider], **model_info # type: ignore
)
else:
# if max_input_tokens > curr
# if max_output_tokens > curr
# if input_cost_per_token > curr
# if output_cost_per_token > curr
# supports_parallel_function_calling == True
# supports_vision == True
# supports_function_calling == True
if llm_provider not in model_group_info.providers:
model_group_info.providers.append(llm_provider)
if model_info.get("max_input_tokens", None) is not None and (
model_group_info.max_input_tokens is None
or model_info["max_input_tokens"]
> model_group_info.max_input_tokens
):
model_group_info.max_input_tokens = model_info[
"max_input_tokens"
]
if model_info.get("max_output_tokens", None) is not None and (
model_group_info.max_output_tokens is None
or model_info["max_output_tokens"]
> model_group_info.max_output_tokens
):
model_group_info.max_output_tokens = model_info[
"max_output_tokens"
]
if model_info.get("input_cost_per_token", None) is not None and (
model_group_info.input_cost_per_token is None
or model_info["input_cost_per_token"]
> model_group_info.input_cost_per_token
):
model_group_info.input_cost_per_token = model_info[
"input_cost_per_token"
]
if model_info.get("output_cost_per_token", None) is not None and (
model_group_info.output_cost_per_token is None
or model_info["output_cost_per_token"]
> model_group_info.output_cost_per_token
):
model_group_info.output_cost_per_token = model_info[
"output_cost_per_token"
]
if (
model_info.get("supports_parallel_function_calling", None)
is not None
and model_info["supports_parallel_function_calling"] == True # type: ignore
):
model_group_info.supports_parallel_function_calling = True
if (
model_info.get("supports_vision", None) is not None
and model_info["supports_vision"] == True # type: ignore
):
model_group_info.supports_vision = True
if (
model_info.get("supports_function_calling", None) is not None
and model_info["supports_function_calling"] == True # type: ignore
):
model_group_info.supports_function_calling = True
return model_group_info
def get_model_ids(self) -> List[str]:
"""
Returns list of model id's.

View file

@ -411,3 +411,18 @@ class AlertingConfig(BaseModel):
webhook_url: str
alerting_threshold: Optional[float] = 300
class ModelGroupInfo(BaseModel):
model_group: str
providers: List[str]
max_input_tokens: Optional[float] = None
max_output_tokens: Optional[float] = None
input_cost_per_token: Optional[float] = None
output_cost_per_token: Optional[float] = None
mode: Literal[
"chat", "embedding", "completion", "image_generation", "audio_transcription"
]
supports_parallel_function_calling: bool = Field(default=False)
supports_vision: bool = Field(default=False)
supports_function_calling: bool = Field(default=False)

View file

@ -12,3 +12,13 @@ class ProviderField(TypedDict):
field_type: Literal["string"]
field_description: str
field_value: str
class ModelInfo(TypedDict):
max_tokens: int
max_input_tokens: int
max_output_tokens: int
input_cost_per_token: float
output_cost_per_token: float
litellm_provider: str
mode: str

View file

@ -34,7 +34,7 @@ from dataclasses import (
import litellm._service_logger # for storing API inputs, outputs, and metadata
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm.caching import DualCache
from litellm.types.utils import CostPerToken, ProviderField
from litellm.types.utils import CostPerToken, ProviderField, ModelInfo
oidc_cache = DualCache()
@ -7092,7 +7092,7 @@ def get_max_tokens(model: str):
)
def get_model_info(model: str):
def get_model_info(model: str) -> ModelInfo:
"""
Get a dict for the maximum tokens (context window),
input_cost_per_token, output_cost_per_token for a given model.
@ -7154,7 +7154,7 @@ def get_model_info(model: str):
if custom_llm_provider == "huggingface":
max_tokens = _get_max_position_embeddings(model_name=model)
return {
"max_tokens": max_tokens,
"max_tokens": max_tokens, # type: ignore
"input_cost_per_token": 0,
"output_cost_per_token": 0,
"litellm_provider": "huggingface",