feat(proxy_server.py): expose new /model_group/info endpoint

Returns model-group-level info on supported params, max tokens, pricing, etc.
Krrish Dholakia 2024-05-26 14:07:35 -07:00
parent bec13d465a
commit 22b6b99b34
6 changed files with 191 additions and 16 deletions
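
Example (hypothetical) of querying the new endpoint once a proxy is running; the base URL and `sk-1234` key below are placeholders, not values from this commit:

# Query /model_group/info on a locally running proxy (placeholder URL + key).
import requests

resp = requests.get(
    "http://0.0.0.0:4000/model_group/info",
    headers={"Authorization": "Bearer sk-1234"},
)
resp.raise_for_status()

for group in resp.json()["data"]:
    print(group["model_group"], group["providers"], group.get("max_input_tokens"))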


@@ -60,17 +60,20 @@ def get_complete_model_list(
     - If team list is empty -> defer to proxy model list
     """
-    if len(key_models) > 0:
-        return key_models
-    if len(team_models) > 0:
-        return team_models
+    unique_models = set()
 
-    returned_models = proxy_model_list
-    if user_model is not None:  # set via `litellm --model ollama/llam3`
-        returned_models.append(user_model)
+    if key_models:
+        unique_models.update(key_models)
+    elif team_models:
+        unique_models.update(team_models)
+    else:
+        unique_models.update(proxy_model_list)
 
-    if infer_model_from_keys is not None and infer_model_from_keys == True:
-        valid_models = get_valid_models()
-        returned_models.extend(valid_models)
+    if user_model:  # set via `litellm --model ollama/llam3`
+        unique_models.add(user_model)
 
-    return returned_models
+    if infer_model_from_keys:
+        valid_models = get_valid_models()
+        unique_models.update(valid_models)
+
+    return list(unique_models)

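For reference, a standalone sketch of the merge behavior the rewritten helper now implements (key-level models win over team-level models, which win over the proxy list, all deduplicated); the inputs are made up and the function below is illustrative, not the proxy helper itself:

# Illustrative re-implementation of the set-based merge above (hypothetical inputs).
def merge_model_lists(key_models, team_models, proxy_model_list, user_model=None):
    unique_models = set()
    if key_models:                       # key-level restriction takes precedence
        unique_models.update(key_models)
    elif team_models:                    # then team-level restriction
        unique_models.update(team_models)
    else:                                # otherwise fall back to the proxy list
        unique_models.update(proxy_model_list)
    if user_model:                       # e.g. set via `litellm --model ...`
        unique_models.add(user_model)
    return list(unique_models)

print(merge_model_lists([], [], ["gpt-3.5-turbo", "gpt-3.5-turbo"], "ollama/llama3"))
# -> ['gpt-3.5-turbo', 'ollama/llama3'] (deduplicated; set order not guaranteed)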

@@ -2,7 +2,7 @@ import sys, os, platform, time, copy, re, asyncio, inspect
 import threading, ast
 import shutil, random, traceback, requests
 from datetime import datetime, timedelta, timezone
-from typing import Optional, List, Callable, get_args
+from typing import Optional, List, Callable, get_args, Set
 import secrets, subprocess
 import hashlib, uuid
 import warnings
@@ -106,7 +106,7 @@ import pydantic
 from litellm.proxy._types import *
 from litellm.caching import DualCache, RedisCache
 from litellm.proxy.health_check import perform_health_check
-from litellm.router import LiteLLM_Params, Deployment, updateDeployment
+from litellm.router import LiteLLM_Params, Deployment, updateDeployment, ModelGroupInfo
 from litellm.router import ModelInfo as RouterModelInfo
 from litellm._logging import verbose_router_logger, verbose_proxy_logger
 from litellm.proxy.auth.handle_jwt import JWTHandler
@@ -9730,6 +9730,58 @@ async def model_info_v1(
     return {"data": all_models}
 
 
+@router.get(
+    "/model_group/info",
+    description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)",
+    tags=["model management"],
+    dependencies=[Depends(user_api_key_auth)],
+)
+async def model_group_info(
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    Returns model info at the model group level.
+    """
+    global llm_model_list, general_settings, user_config_file_path, proxy_config, llm_router
+
+    if llm_model_list is None:
+        raise HTTPException(
+            status_code=500, detail={"error": "LLM Model List not loaded in"}
+        )
+    if llm_router is None:
+        raise HTTPException(
+            status_code=500, detail={"error": "LLM Router is not loaded in"}
+        )
+    all_models: List[dict] = []
+    ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
+    if llm_model_list is None:
+        proxy_model_list = []
+    else:
+        proxy_model_list = [m["model_name"] for m in llm_model_list]
+    key_models = get_key_models(
+        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
+    )
+    team_models = get_team_models(
+        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
+    )
+    all_models_str = get_complete_model_list(
+        key_models=key_models,
+        team_models=team_models,
+        proxy_model_list=proxy_model_list,
+        user_model=user_model,
+        infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
+    )
+
+    model_groups: List[ModelGroupInfo] = []
+    for model in all_models_str:
+        _model_group_info = llm_router.get_model_group_info(model_group=model)
+        if _model_group_info is not None:
+            model_groups.append(_model_group_info)
+
+    return {"data": model_groups}
+
+
 #### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964
 @router.post(
     "/model/delete",

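For illustration, roughly the payload shape the handler above returns; the group name, providers, limits, and prices are placeholder values, while the field names come from the ModelGroupInfo type added further down:

# Made-up example of a GET /model_group/info response body.
example_response = {
    "data": [
        {
            "model_group": "gpt-3.5-turbo",
            "providers": ["openai", "azure"],
            "max_input_tokens": 16385.0,
            "max_output_tokens": 4096.0,
            "input_cost_per_token": 5e-07,
            "output_cost_per_token": 1.5e-06,
            "mode": "chat",
            "supports_parallel_function_calling": True,
            "supports_vision": False,
            "supports_function_calling": True,
        }
    ]
}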

@@ -48,6 +48,7 @@ from litellm.types.router import (
     RetryPolicy,
     AlertingConfig,
     DeploymentTypedDict,
+    ModelGroupInfo,
 )
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.llms.azure import get_azure_ad_token_from_oidc
@@ -3045,6 +3046,100 @@ class Router:
                     return model
         return None
 
+    def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:
+        """
+        For a given model group name, return the combined model info
+
+        Returns:
+        - ModelGroupInfo if able to construct a model group
+        - None if error constructing model group info
+        """
+        model_group_info: Optional[ModelGroupInfo] = None
+
+        for model in self.model_list:
+            if "model_name" in model and model["model_name"] == model_group:
+                # model in model group found #
+                litellm_params = LiteLLM_Params(**model["litellm_params"])
+                # get model info
+                try:
+                    model_info = litellm.get_model_info(model=litellm_params.model)
+                except Exception as e:
+                    continue
+                # get llm provider
+                try:
+                    model, llm_provider, _, _ = litellm.get_llm_provider(
+                        model=litellm_params.model,
+                        custom_llm_provider=litellm_params.custom_llm_provider,
+                    )
+                except Exception as e:
+                    continue
+
+                if model_group_info is None:
+                    model_group_info = ModelGroupInfo(
+                        model_group=model_group, providers=[llm_provider], **model_info  # type: ignore
+                    )
+                else:
+                    # if max_input_tokens > curr
+                    # if max_output_tokens > curr
+                    # if input_cost_per_token > curr
+                    # if output_cost_per_token > curr
+                    # supports_parallel_function_calling == True
+                    # supports_vision == True
+                    # supports_function_calling == True
+                    if llm_provider not in model_group_info.providers:
+                        model_group_info.providers.append(llm_provider)
+                    if model_info.get("max_input_tokens", None) is not None and (
+                        model_group_info.max_input_tokens is None
+                        or model_info["max_input_tokens"]
+                        > model_group_info.max_input_tokens
+                    ):
+                        model_group_info.max_input_tokens = model_info[
+                            "max_input_tokens"
+                        ]
+                    if model_info.get("max_output_tokens", None) is not None and (
+                        model_group_info.max_output_tokens is None
+                        or model_info["max_output_tokens"]
+                        > model_group_info.max_output_tokens
+                    ):
+                        model_group_info.max_output_tokens = model_info[
+                            "max_output_tokens"
+                        ]
+                    if model_info.get("input_cost_per_token", None) is not None and (
+                        model_group_info.input_cost_per_token is None
+                        or model_info["input_cost_per_token"]
+                        > model_group_info.input_cost_per_token
+                    ):
+                        model_group_info.input_cost_per_token = model_info[
+                            "input_cost_per_token"
+                        ]
+                    if model_info.get("output_cost_per_token", None) is not None and (
+                        model_group_info.output_cost_per_token is None
+                        or model_info["output_cost_per_token"]
+                        > model_group_info.output_cost_per_token
+                    ):
+                        model_group_info.output_cost_per_token = model_info[
+                            "output_cost_per_token"
+                        ]
+                    if (
+                        model_info.get("supports_parallel_function_calling", None)
+                        is not None
+                        and model_info["supports_parallel_function_calling"] == True  # type: ignore
+                    ):
+                        model_group_info.supports_parallel_function_calling = True
+                    if (
+                        model_info.get("supports_vision", None) is not None
+                        and model_info["supports_vision"] == True  # type: ignore
+                    ):
+                        model_group_info.supports_vision = True
+                    if (
+                        model_info.get("supports_function_calling", None) is not None
+                        and model_info["supports_function_calling"] == True  # type: ignore
+                    ):
+                        model_group_info.supports_function_calling = True
+
+        return model_group_info
+
     def get_model_ids(self) -> List[str]:
         """
         Returns list of model id's.

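A hypothetical sketch of calling the new method directly: two deployments share one `model_name`, so their providers are unioned and the larger token limits / higher per-token costs win. The model list and keys below are placeholders:

# Two placeholder deployments grouped under one alias to exercise the aggregation above.
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "my-chat-model",
            "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-placeholder"},
        },
        {
            "model_name": "my-chat-model",
            "litellm_params": {"model": "gpt-4-turbo", "api_key": "sk-placeholder"},
        },
    ]
)

info = router.get_model_group_info(model_group="my-chat-model")
if info is not None:
    # providers is the union across deployments; token limits and costs take the max seen
    print(info.providers, info.max_input_tokens, info.supports_function_calling)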

@@ -411,3 +411,18 @@ class AlertingConfig(BaseModel):
     webhook_url: str
     alerting_threshold: Optional[float] = 300
+
+
+class ModelGroupInfo(BaseModel):
+    model_group: str
+    providers: List[str]
+    max_input_tokens: Optional[float] = None
+    max_output_tokens: Optional[float] = None
+    input_cost_per_token: Optional[float] = None
+    output_cost_per_token: Optional[float] = None
+    mode: Literal[
+        "chat", "embedding", "completion", "image_generation", "audio_transcription"
+    ]
+    supports_parallel_function_calling: bool = Field(default=False)
+    supports_vision: bool = Field(default=False)
+    supports_function_calling: bool = Field(default=False)

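A quick sketch of constructing the model directly with placeholder values; per the class above, only model_group, providers, and mode are required, the rest default to None/False:

# Placeholder instantiation of the new pydantic model.
from litellm.types.router import ModelGroupInfo

info = ModelGroupInfo(model_group="my-chat-model", providers=["openai"], mode="chat")
print(info.supports_vision)    # False -- Field(default=False)
print(info.max_input_tokens)   # None until populated from a deployment's model info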

@@ -12,3 +12,13 @@ class ProviderField(TypedDict):
     field_type: Literal["string"]
     field_description: str
     field_value: str
+
+
+class ModelInfo(TypedDict):
+    max_tokens: int
+    max_input_tokens: int
+    max_output_tokens: int
+    input_cost_per_token: float
+    output_cost_per_token: float
+    litellm_provider: str
+    mode: str

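This TypedDict describes the dict that litellm.get_model_info returns (the change further down adds the matching return annotation). A small sketch of reading it; the model name is just an example:

# Read a few ModelInfo keys for an example model from litellm's cost map.
import litellm

info = litellm.get_model_info(model="gpt-3.5-turbo")
print(info["max_tokens"], info["input_cost_per_token"], info["mode"])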

@@ -34,7 +34,7 @@ from dataclasses import (
 import litellm._service_logger  # for storing API inputs, outputs, and metadata
 from litellm.llms.custom_httpx.http_handler import HTTPHandler
 from litellm.caching import DualCache
-from litellm.types.utils import CostPerToken, ProviderField
+from litellm.types.utils import CostPerToken, ProviderField, ModelInfo
 
 oidc_cache = DualCache()
@@ -7092,7 +7092,7 @@ def get_max_tokens(model: str):
     )
 
 
-def get_model_info(model: str):
+def get_model_info(model: str) -> ModelInfo:
     """
     Get a dict for the maximum tokens (context window),
     input_cost_per_token, output_cost_per_token for a given model.
@@ -7154,7 +7154,7 @@ def get_model_info(model: str):
         if custom_llm_provider == "huggingface":
             max_tokens = _get_max_position_embeddings(model_name=model)
             return {
-                "max_tokens": max_tokens,
+                "max_tokens": max_tokens,  # type: ignore
                 "input_cost_per_token": 0,
                 "output_cost_per_token": 0,
                 "litellm_provider": "huggingface",