forked from phoenix/litellm-mirror
feat(proxy_server.py): expose new /model_group/info endpoint

Returns model-group-level info on supported params, max tokens, pricing, etc.
This commit is contained in:
parent bec13d465a
commit 22b6b99b34
6 changed files with 191 additions and 16 deletions
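
For orientation (not part of the diff below), a minimal sketch of calling the new endpoint from Python. The proxy base URL and virtual key are placeholders, not values from this commit:

import requests

# Placeholders: point these at your own proxy deployment and a valid virtual key.
PROXY_BASE_URL = "http://localhost:4000"
VIRTUAL_KEY = "sk-1234"

resp = requests.get(
    f"{PROXY_BASE_URL}/model_group/info",
    headers={"Authorization": f"Bearer {VIRTUAL_KEY}"},
)
resp.raise_for_status()
# Per this commit, the response shape is {"data": [<ModelGroupInfo>, ...]}
for group in resp.json()["data"]:
    print(group["model_group"], group.get("providers"))
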
@@ -60,17 +60,20 @@ def get_complete_model_list(
     - If team list is empty -> defer to proxy model list
     """
 
-    if len(key_models) > 0:
-        return key_models
-
-    if len(team_models) > 0:
-        return team_models
-
-    returned_models = proxy_model_list
-    if user_model is not None:  # set via `litellm --model ollama/llam3`
-        returned_models.append(user_model)
-
-    if infer_model_from_keys is not None and infer_model_from_keys == True:
+    unique_models = set()
+
+    if key_models:
+        unique_models.update(key_models)
+    elif team_models:
+        unique_models.update(team_models)
+    else:
+        unique_models.update(proxy_model_list)
+
+    if user_model:
+        unique_models.add(user_model)
+
+    if infer_model_from_keys:
         valid_models = get_valid_models()
-        returned_models.extend(valid_models)
-    return returned_models
+        unique_models.update(valid_models)
+
+    return list(unique_models)
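
To summarize the change above: key-level models still take precedence over team-level models, which take precedence over the proxy model list, but results are now collected into a set, so duplicates (and user_model / inferred models) are returned once. A standalone sketch of the same precedence, for illustration only (function and parameter names below are mine, not from the commit):

from typing import List, Optional

def merge_model_lists(
    key_models: List[str],
    team_models: List[str],
    proxy_model_list: List[str],
    user_model: Optional[str] = None,
    inferred_models: Optional[List[str]] = None,
) -> List[str]:
    # Mirrors the precedence in get_complete_model_list: key > team > proxy.
    unique_models = set()
    if key_models:
        unique_models.update(key_models)
    elif team_models:
        unique_models.update(team_models)
    else:
        unique_models.update(proxy_model_list)
    if user_model:
        unique_models.add(user_model)
    if inferred_models:
        unique_models.update(inferred_models)
    return list(unique_models)

# Duplicates collapse: merge_model_lists(["gpt-4", "gpt-4"], [], [], user_model="gpt-4") -> ["gpt-4"]

Note that converting the set back to a list does not preserve the original ordering of the inputs.
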
@@ -2,7 +2,7 @@ import sys, os, platform, time, copy, re, asyncio, inspect
 import threading, ast
 import shutil, random, traceback, requests
 from datetime import datetime, timedelta, timezone
-from typing import Optional, List, Callable, get_args
+from typing import Optional, List, Callable, get_args, Set
 import secrets, subprocess
 import hashlib, uuid
 import warnings

@@ -106,7 +106,7 @@ import pydantic
 from litellm.proxy._types import *
 from litellm.caching import DualCache, RedisCache
 from litellm.proxy.health_check import perform_health_check
-from litellm.router import LiteLLM_Params, Deployment, updateDeployment
+from litellm.router import LiteLLM_Params, Deployment, updateDeployment, ModelGroupInfo
 from litellm.router import ModelInfo as RouterModelInfo
 from litellm._logging import verbose_router_logger, verbose_proxy_logger
 from litellm.proxy.auth.handle_jwt import JWTHandler
@@ -9730,6 +9730,58 @@ async def model_info_v1(
     return {"data": all_models}
 
 
+@router.get(
+    "/model_group/info",
+    description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)",
+    tags=["model management"],
+    dependencies=[Depends(user_api_key_auth)],
+)
+async def model_group_info(
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    Returns model info at the model group level.
+    """
+    global llm_model_list, general_settings, user_config_file_path, proxy_config, llm_router
+
+    if llm_model_list is None:
+        raise HTTPException(
+            status_code=500, detail={"error": "LLM Model List not loaded in"}
+        )
+    if llm_router is None:
+        raise HTTPException(
+            status_code=500, detail={"error": "LLM Router is not loaded in"}
+        )
+    all_models: List[dict] = []
+    ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
+    if llm_model_list is None:
+        proxy_model_list = []
+    else:
+        proxy_model_list = [m["model_name"] for m in llm_model_list]
+    key_models = get_key_models(
+        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
+    )
+    team_models = get_team_models(
+        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
+    )
+    all_models_str = get_complete_model_list(
+        key_models=key_models,
+        team_models=team_models,
+        proxy_model_list=proxy_model_list,
+        user_model=user_model,
+        infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
+    )
+
+    model_groups: List[ModelGroupInfo] = []
+    for model in all_models_str:
+
+        _model_group_info = llm_router.get_model_group_info(model_group=model)
+        if _model_group_info is not None:
+            model_groups.append(_model_group_info)
+
+    return {"data": model_groups}
+
+
 #### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964
 @router.post(
     "/model/delete",
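
For context, an illustrative response from GET /model_group/info, using the ModelGroupInfo fields defined later in this commit. Model names, providers, and numbers below are made up:

# Illustrative only; not actual output from this commit.
example_response = {
    "data": [
        {
            "model_group": "gpt-4",
            "providers": ["openai", "azure"],
            "max_input_tokens": 128000,
            "max_output_tokens": 4096,
            "input_cost_per_token": 1e-05,
            "output_cost_per_token": 3e-05,
            "mode": "chat",
            "supports_parallel_function_calling": True,
            "supports_vision": False,
            "supports_function_calling": True,
        }
    ]
}
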
@@ -48,6 +48,7 @@ from litellm.types.router import (
     RetryPolicy,
     AlertingConfig,
     DeploymentTypedDict,
+    ModelGroupInfo,
 )
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.llms.azure import get_azure_ad_token_from_oidc
@@ -3045,6 +3046,100 @@ class Router:
                 return model
         return None
 
+    def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:
+        """
+        For a given model group name, return the combined model info
+
+        Returns:
+        - ModelGroupInfo if able to construct a model group
+        - None if error constructing model group info
+        """
+
+        model_group_info: Optional[ModelGroupInfo] = None
+
+        for model in self.model_list:
+            if "model_name" in model and model["model_name"] == model_group:
+                # model in model group found #
+                litellm_params = LiteLLM_Params(**model["litellm_params"])
+                # get model info
+                try:
+                    model_info = litellm.get_model_info(model=litellm_params.model)
+                except Exception as e:
+                    continue
+                # get llm provider
+                try:
+                    model, llm_provider, _, _ = litellm.get_llm_provider(
+                        model=litellm_params.model,
+                        custom_llm_provider=litellm_params.custom_llm_provider,
+                    )
+                except Exception as e:
+                    continue
+
+                if model_group_info is None:
+                    model_group_info = ModelGroupInfo(
+                        model_group=model_group, providers=[llm_provider], **model_info  # type: ignore
+                    )
+                else:
+                    # if max_input_tokens > curr
+                    # if max_output_tokens > curr
+                    # if input_cost_per_token > curr
+                    # if output_cost_per_token > curr
+                    # supports_parallel_function_calling == True
+                    # supports_vision == True
+                    # supports_function_calling == True
+                    if llm_provider not in model_group_info.providers:
+                        model_group_info.providers.append(llm_provider)
+                    if model_info.get("max_input_tokens", None) is not None and (
+                        model_group_info.max_input_tokens is None
+                        or model_info["max_input_tokens"]
+                        > model_group_info.max_input_tokens
+                    ):
+                        model_group_info.max_input_tokens = model_info[
+                            "max_input_tokens"
+                        ]
+                    if model_info.get("max_output_tokens", None) is not None and (
+                        model_group_info.max_output_tokens is None
+                        or model_info["max_output_tokens"]
+                        > model_group_info.max_output_tokens
+                    ):
+                        model_group_info.max_output_tokens = model_info[
+                            "max_output_tokens"
+                        ]
+                    if model_info.get("input_cost_per_token", None) is not None and (
+                        model_group_info.input_cost_per_token is None
+                        or model_info["input_cost_per_token"]
+                        > model_group_info.input_cost_per_token
+                    ):
+                        model_group_info.input_cost_per_token = model_info[
+                            "input_cost_per_token"
+                        ]
+                    if model_info.get("output_cost_per_token", None) is not None and (
+                        model_group_info.output_cost_per_token is None
+                        or model_info["output_cost_per_token"]
+                        > model_group_info.output_cost_per_token
+                    ):
+                        model_group_info.output_cost_per_token = model_info[
+                            "output_cost_per_token"
+                        ]
+                    if (
+                        model_info.get("supports_parallel_function_calling", None)
+                        is not None
+                        and model_info["supports_parallel_function_calling"] == True  # type: ignore
+                    ):
+                        model_group_info.supports_parallel_function_calling = True
+                    if (
+                        model_info.get("supports_vision", None) is not None
+                        and model_info["supports_vision"] == True  # type: ignore
+                    ):
+                        model_group_info.supports_vision = True
+                    if (
+                        model_info.get("supports_function_calling", None) is not None
+                        and model_info["supports_function_calling"] == True  # type: ignore
+                    ):
+                        model_group_info.supports_function_calling = True
+
+        return model_group_info
+
     def get_model_ids(self) -> List[str]:
         """
         Returns list of model id's.
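
A hedged usage sketch of the new Router.get_model_group_info method. The deployment names, keys, and params below are illustrative placeholders, not values from this commit:

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-4",  # the model group name
            "litellm_params": {
                "model": "azure/my-gpt-4-deployment",  # placeholder deployment
                "api_key": "...",   # placeholder
                "api_base": "...",  # placeholder
            },
        },
        {
            "model_name": "gpt-4",
            "litellm_params": {"model": "gpt-4"},
        },
    ]
)

info = router.get_model_group_info(model_group="gpt-4")
if info is not None:
    # Aggregated across deployments: union of providers, max of token limits and
    # per-token costs, and the supports_* flags OR'd together.
    print(info.providers, info.max_input_tokens, info.supports_function_calling)
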
@@ -411,3 +411,18 @@ class AlertingConfig(BaseModel):
 
     webhook_url: str
     alerting_threshold: Optional[float] = 300
+
+
+class ModelGroupInfo(BaseModel):
+    model_group: str
+    providers: List[str]
+    max_input_tokens: Optional[float] = None
+    max_output_tokens: Optional[float] = None
+    input_cost_per_token: Optional[float] = None
+    output_cost_per_token: Optional[float] = None
+    mode: Literal[
+        "chat", "embedding", "completion", "image_generation", "audio_transcription"
+    ]
+    supports_parallel_function_calling: bool = Field(default=False)
+    supports_vision: bool = Field(default=False)
+    supports_function_calling: bool = Field(default=False)
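
For reference, constructing the model directly (field values here are illustrative):

from litellm.types.router import ModelGroupInfo

group_info = ModelGroupInfo(
    model_group="gpt-4",
    providers=["openai", "azure"],
    max_input_tokens=128000,
    max_output_tokens=4096,
    mode="chat",
    supports_function_calling=True,
)
# Unset optional fields default to None / False per the field definitions above.
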
@@ -12,3 +12,13 @@ class ProviderField(TypedDict):
     field_type: Literal["string"]
     field_description: str
     field_value: str
+
+
+class ModelInfo(TypedDict):
+    max_tokens: int
+    max_input_tokens: int
+    max_output_tokens: int
+    input_cost_per_token: float
+    output_cost_per_token: float
+    litellm_provider: str
+    mode: str
@@ -34,7 +34,7 @@ from dataclasses import (
 import litellm._service_logger  # for storing API inputs, outputs, and metadata
 from litellm.llms.custom_httpx.http_handler import HTTPHandler
 from litellm.caching import DualCache
-from litellm.types.utils import CostPerToken, ProviderField
+from litellm.types.utils import CostPerToken, ProviderField, ModelInfo
 
 oidc_cache = DualCache()
 
@@ -7092,7 +7092,7 @@ def get_max_tokens(model: str):
     )
 
 
-def get_model_info(model: str):
+def get_model_info(model: str) -> ModelInfo:
     """
     Get a dict for the maximum tokens (context window),
     input_cost_per_token, output_cost_per_token for a given model.
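
With the new return annotation, callers get a ModelInfo TypedDict. A hedged example; the model name is just an assumption:

import litellm

info = litellm.get_model_info(model="gpt-3.5-turbo")  # example model, not from this commit
# Keys follow the ModelInfo TypedDict defined earlier in this commit.
print(info["litellm_provider"], info["mode"], info.get("max_input_tokens"))
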
@@ -7154,7 +7154,7 @@ def get_model_info(model: str):
         if custom_llm_provider == "huggingface":
             max_tokens = _get_max_position_embeddings(model_name=model)
             return {
-                "max_tokens": max_tokens,
+                "max_tokens": max_tokens,  # type: ignore
                 "input_cost_per_token": 0,
                 "output_cost_per_token": 0,
                 "litellm_provider": "huggingface",